Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
requires-python = ">=3.12"
dependencies = [
"h5py>=3.12.1",
"nvalchemi-toolkit-ops",
"nvalchemi-toolkit-ops>=0.2.0",
"numpy>=1.26,<3",
"tables>=3.10.2",
"torch>=2",
Expand Down
46 changes: 29 additions & 17 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ def test_neighbor_list_implementations(
neighbors.standard_nl,
]
+ ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else [])
+ ([neighbors.alchemiops_nl_n2] if neighbors.ALCHEMIOPS_AVAILABLE else []),
+ (
[neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list]
if neighbors.ALCHEMIOPS_AVAILABLE
else []
),
)
def test_torch_nl_implementations(
*,
Expand Down Expand Up @@ -477,16 +481,18 @@ def test_torchsim_nl_availability() -> None:

if neighbors.ALCHEMIOPS_AVAILABLE:
assert neighbors.alchemiops_nl_n2 is not None
assert neighbors.alchemiops_nl_cell_list is not None
else:
assert neighbors.alchemiops_nl_n2 is None
assert neighbors.alchemiops_nl_cell_list is None


@pytest.mark.skipif(
not neighbors.ALCHEMIOPS_AVAILABLE or not torch.cuda.is_available(),
reason="Alchemiops requires CUDA",
)
def test_alchemiops_nl_edge_cases() -> None:
"""Test edge cases for alchemiops_nl_n2 implementation (CUDA only)."""
"""Test edge cases for alchemiops implementations (CUDA only)."""
device = torch.device("cuda")
dtype = torch.float32

Expand All @@ -495,20 +501,24 @@ def test_alchemiops_nl_edge_cases() -> None:
cutoff = torch.tensor(1.5, device=device, dtype=dtype)
system_idx = torch.zeros(2, dtype=torch.long, device=device)

# Test alchemiops_nl_n2
for pbc in (
torch.tensor([True, True, True], device=device),
torch.tensor([False, False, False], device=device),
):
mapping, sys_map, _shifts = neighbors.alchemiops_nl_n2(
positions=pos,
cell=cell,
pbc=pbc,
cutoff=cutoff,
system_idx=system_idx,
)
assert len(mapping[0]) > 0 # Should find neighbors
assert (sys_map == 0).all() # All in system 0
# Test both implementations
for nl_impl, impl_name in [
(neighbors.alchemiops_nl_n2, "alchemiops_nl_n2"),
(neighbors.alchemiops_nl_cell_list, "alchemiops_nl_cell_list"),
]:
for pbc in (
torch.tensor([True, True, True], device=device),
torch.tensor([False, False, False], device=device),
):
mapping, sys_map, _shifts = nl_impl(
positions=pos,
cell=cell,
pbc=pbc,
cutoff=cutoff,
system_idx=system_idx,
)
assert len(mapping[0]) > 0, f"{impl_name} should find neighbors"
assert (sys_map == 0).all(), f"{impl_name}: All pairs should be in system 0"


def test_fallback_when_alchemiops_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -721,7 +731,9 @@ def test_neighbor_lists_time_and_memory() -> None:
]
)
if neighbors.ALCHEMIOPS_AVAILABLE and DEVICE.type == "cuda":
nl_implementations.append(neighbors.alchemiops_nl_n2)
nl_implementations.extend(
[neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list]
)

for nl_fn in nl_implementations:
# Get initial memory usage
Expand Down
8 changes: 7 additions & 1 deletion torch_sim/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ def _normalize_inputs(

# Try to import Alchemiops implementations (NVIDIA CUDA acceleration)
try:
from torch_sim.neighbors.alchemiops import ALCHEMIOPS_AVAILABLE, alchemiops_nl_n2
from torch_sim.neighbors.alchemiops import (
ALCHEMIOPS_AVAILABLE,
alchemiops_nl_cell_list,
alchemiops_nl_n2,
)
except ImportError:
ALCHEMIOPS_AVAILABLE = False
alchemiops_nl_n2 = None # type: ignore[assignment]
alchemiops_nl_cell_list = None # type: ignore[assignment]

# Try to import Vesin implementations
try:
Expand Down Expand Up @@ -145,6 +150,7 @@ def torchsim_nl(
"VESIN_AVAILABLE",
"VesinNeighborList",
"VesinNeighborListTorch",
"alchemiops_nl_cell_list",
"alchemiops_nl_n2",
"default_batched_nl",
"primitive_neighbor_list",
Expand Down
90 changes: 86 additions & 4 deletions torch_sim/neighbors/alchemiops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Alchemiops-based neighbor list implementations.

This module provides high-performance CUDA-accelerated neighbor list calculations
using the nvalchemiops library. Uses the naive N^2 implementation for reliability.
using the nvalchemiops library. Supports both naive N^2 and cell list algorithms.

nvalchemiops is available at: https://github.com/NVIDIA/nvalchemiops
"""
Expand All @@ -10,17 +10,19 @@


try:
from nvalchemiops.neighborlist import batch_naive_neighbor_list
from nvalchemiops.neighborlist import batch_cell_list, batch_naive_neighbor_list
from nvalchemiops.neighborlist.neighbor_utils import estimate_max_neighbors

ALCHEMIOPS_AVAILABLE = True
except ImportError:
ALCHEMIOPS_AVAILABLE = False
batch_naive_neighbor_list = None # type: ignore[assignment]
estimate_max_neighbors = None # type: ignore[assignment]
batch_cell_list = None # type: ignore[assignment]
estimate_max_neighbors = None # type: ignore[assignment, name-defined]

__all__ = [
"ALCHEMIOPS_AVAILABLE",
"alchemiops_nl_cell_list",
"alchemiops_nl_n2",
]

Expand Down Expand Up @@ -98,8 +100,79 @@ def alchemiops_nl_n2(

return mapping, system_mapping, shifts_idx

def alchemiops_nl_cell_list(
positions: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
cutoff: torch.Tensor,
system_idx: torch.Tensor,
self_interaction: bool = False, # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute neighbor lists using Alchemiops cell list algorithm.

Args:
positions: Atomic positions tensor [n_atoms, 3]
cell: Unit cell vectors [n_systems, 3, 3] or [3, 3]
pbc: Boolean tensor [n_systems, 3] or [3]
cutoff: Maximum distance (scalar tensor)
system_idx: Tensor [n_atoms] indicating system assignment
self_interaction: If True, include self-pairs

Returns:
(mapping, system_mapping, shifts_idx)
"""
from torch_sim.neighbors import _normalize_inputs

r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
n_systems = system_idx.max().item() + 1
cell, pbc = _normalize_inputs(cell, pbc, n_systems)

# Call alchemiops cell list
res = batch_cell_list(
positions=positions,
cutoff=r_max,
batch_idx=system_idx.to(torch.int32),
cell=cell,
pbc=pbc.to(torch.bool),
return_neighbor_list=True,
)

# Parse results: (neighbor_list, neighbor_ptr[, neighbor_list_shifts])
if len(res) == 3: # type: ignore[arg-type]
mapping, _, shifts_idx = res # type: ignore[misc]
else:
mapping, _ = res # type: ignore[misc]
shifts_idx = torch.zeros(
(mapping.shape[1], 3), dtype=positions.dtype, device=positions.device
)

# Convert dtypes
mapping = mapping.to(dtype=torch.long)
# Convert shifts_idx to floating point to match cell dtype (for einsum)
shifts_idx = shifts_idx.to(dtype=cell.dtype)

# Create system_mapping
system_mapping = system_idx[mapping[0]]

# Alchemiops does NOT include self-interactions by default
# Add them only if requested
if self_interaction:
n_atoms = positions.shape[0]
self_pairs = torch.arange(n_atoms, device=positions.device, dtype=torch.long)
self_mapping = torch.stack([self_pairs, self_pairs], dim=0)
# Self-shifts should match shifts_idx dtype
self_shifts = torch.zeros(
(n_atoms, 3), dtype=cell.dtype, device=positions.device
)

mapping = torch.cat([mapping, self_mapping], dim=1)
shifts_idx = torch.cat([shifts_idx, self_shifts], dim=0)
system_mapping = torch.cat([system_mapping, system_idx], dim=0)

return mapping, system_mapping, shifts_idx

else:
# Provide stub function that raises informative error
# Provide stub functions that raise informative errors
def alchemiops_nl_n2( # type: ignore[misc]
*args, # noqa: ARG001
**kwargs, # noqa: ARG001
Expand All @@ -108,3 +181,12 @@ def alchemiops_nl_n2( # type: ignore[misc]
raise ImportError(
"nvalchemiops is not installed. Install it with: pip install nvalchemiops"
)

def alchemiops_nl_cell_list( # type: ignore[misc]
*args, # noqa: ARG001
**kwargs, # noqa: ARG001
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Stub function when nvalchemiops is not available."""
raise ImportError(
"nvalchemiops is not installed. Install it with: pip install nvalchemiops"
)
Loading