From 62d11b25349b3369a3750cceec239681209355a9 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Tue, 16 Dec 2025 14:23:42 -0800 Subject: [PATCH 1/2] add batch cell list --- tests/test_neighbors.py | 46 ++++++++++------ torch_sim/neighbors/__init__.py | 8 ++- torch_sim/neighbors/alchemiops.py | 90 +++++++++++++++++++++++++++++-- 3 files changed, 122 insertions(+), 22 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index ce0adcb1..6d5deb5d 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -316,7 +316,11 @@ def test_neighbor_list_implementations( neighbors.standard_nl, ] + ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else []) - + ([neighbors.alchemiops_nl_n2] if neighbors.ALCHEMIOPS_AVAILABLE else []), + + ( + [neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list] + if neighbors.ALCHEMIOPS_AVAILABLE + else [] + ), ) def test_torch_nl_implementations( *, @@ -477,8 +481,10 @@ def test_torchsim_nl_availability() -> None: if neighbors.ALCHEMIOPS_AVAILABLE: assert neighbors.alchemiops_nl_n2 is not None + assert neighbors.alchemiops_nl_cell_list is not None else: assert neighbors.alchemiops_nl_n2 is None + assert neighbors.alchemiops_nl_cell_list is None @pytest.mark.skipif( @@ -486,7 +492,7 @@ def test_torchsim_nl_availability() -> None: reason="Alchemiops requires CUDA", ) def test_alchemiops_nl_edge_cases() -> None: - """Test edge cases for alchemiops_nl_n2 implementation (CUDA only).""" + """Test edge cases for alchemiops implementations (CUDA only).""" device = torch.device("cuda") dtype = torch.float32 @@ -495,20 +501,24 @@ def test_alchemiops_nl_edge_cases() -> None: cutoff = torch.tensor(1.5, device=device, dtype=dtype) system_idx = torch.zeros(2, dtype=torch.long, device=device) - # Test alchemiops_nl_n2 - for pbc in ( - torch.tensor([True, True, True], device=device), - torch.tensor([False, False, False], device=device), - ): - mapping, sys_map, _shifts = neighbors.alchemiops_nl_n2( - positions=pos, - cell=cell, - pbc=pbc, - cutoff=cutoff, - system_idx=system_idx, - ) - assert len(mapping[0]) > 0 # Should find neighbors - assert (sys_map == 0).all() # All in system 0 + # Test both implementations + for nl_impl, impl_name in [ + (neighbors.alchemiops_nl_n2, "alchemiops_nl_n2"), + (neighbors.alchemiops_nl_cell_list, "alchemiops_nl_cell_list"), + ]: + for pbc in ( + torch.tensor([True, True, True], device=device), + torch.tensor([False, False, False], device=device), + ): + mapping, sys_map, _shifts = nl_impl( + positions=pos, + cell=cell, + pbc=pbc, + cutoff=cutoff, + system_idx=system_idx, + ) + assert len(mapping[0]) > 0, f"{impl_name} should find neighbors" + assert (sys_map == 0).all(), f"{impl_name}: All pairs should be in system 0" def test_fallback_when_alchemiops_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: @@ -721,7 +731,9 @@ def test_neighbor_lists_time_and_memory() -> None: ] ) if neighbors.ALCHEMIOPS_AVAILABLE and DEVICE.type == "cuda": - nl_implementations.append(neighbors.alchemiops_nl_n2) + nl_implementations.extend( + [neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list] + ) for nl_fn in nl_implementations: # Get initial memory usage diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index 915beaee..c22e570c 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -57,10 +57,15 @@ def _normalize_inputs( # Try to import Alchemiops implementations (NVIDIA CUDA acceleration) try: - from torch_sim.neighbors.alchemiops import ALCHEMIOPS_AVAILABLE, alchemiops_nl_n2 + from torch_sim.neighbors.alchemiops import ( + ALCHEMIOPS_AVAILABLE, + alchemiops_nl_cell_list, + alchemiops_nl_n2, + ) except ImportError: ALCHEMIOPS_AVAILABLE = False alchemiops_nl_n2 = None # type: ignore[assignment] + alchemiops_nl_cell_list = None # type: ignore[assignment] # Try to import Vesin implementations try: @@ -145,6 +150,7 @@ def torchsim_nl( "VESIN_AVAILABLE", "VesinNeighborList", "VesinNeighborListTorch", + "alchemiops_nl_cell_list", "alchemiops_nl_n2", "default_batched_nl", "primitive_neighbor_list", diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py index 59d1ba8d..6cb211dd 100644 --- a/torch_sim/neighbors/alchemiops.py +++ b/torch_sim/neighbors/alchemiops.py @@ -1,7 +1,7 @@ """Alchemiops-based neighbor list implementations. This module provides high-performance CUDA-accelerated neighbor list calculations -using the nvalchemiops library. Uses the naive N^2 implementation for reliability. +using the nvalchemiops library. Supports both naive N^2 and cell list algorithms. nvalchemiops is available at: https://github.com/NVIDIA/nvalchemiops """ @@ -10,17 +10,19 @@ try: - from nvalchemiops.neighborlist import batch_naive_neighbor_list + from nvalchemiops.neighborlist import batch_cell_list, batch_naive_neighbor_list from nvalchemiops.neighborlist.neighbor_utils import estimate_max_neighbors ALCHEMIOPS_AVAILABLE = True except ImportError: ALCHEMIOPS_AVAILABLE = False batch_naive_neighbor_list = None # type: ignore[assignment] - estimate_max_neighbors = None # type: ignore[assignment] + batch_cell_list = None # type: ignore[assignment] + estimate_max_neighbors = None # type: ignore[assignment, name-defined] __all__ = [ "ALCHEMIOPS_AVAILABLE", + "alchemiops_nl_cell_list", "alchemiops_nl_n2", ] @@ -98,8 +100,79 @@ def alchemiops_nl_n2( return mapping, system_mapping, shifts_idx + def alchemiops_nl_cell_list( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute neighbor lists using Alchemiops cell list algorithm. + + Args: + positions: Atomic positions tensor [n_atoms, 3] + cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] + pbc: Boolean tensor [n_systems, 3] or [3] + cutoff: Maximum distance (scalar tensor) + system_idx: Tensor [n_atoms] indicating system assignment + self_interaction: If True, include self-pairs + + Returns: + (mapping, system_mapping, shifts_idx) + """ + from torch_sim.neighbors import _normalize_inputs + + r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff + n_systems = system_idx.max().item() + 1 + cell, pbc = _normalize_inputs(cell, pbc, n_systems) + + # Call alchemiops cell list + res = batch_cell_list( + positions=positions, + cutoff=r_max, + batch_idx=system_idx.to(torch.int32), + cell=cell, + pbc=pbc.to(torch.bool), + return_neighbor_list=True, + ) + + # Parse results: (neighbor_list, neighbor_ptr[, neighbor_list_shifts]) + if len(res) == 3: # type: ignore[arg-type] + mapping, _, shifts_idx = res # type: ignore[misc] + else: + mapping, _ = res # type: ignore[misc] + shifts_idx = torch.zeros( + (mapping.shape[1], 3), dtype=positions.dtype, device=positions.device + ) + + # Convert dtypes + mapping = mapping.to(dtype=torch.long) + # Convert shifts_idx to floating point to match cell dtype (for einsum) + shifts_idx = shifts_idx.to(dtype=cell.dtype) + + # Create system_mapping + system_mapping = system_idx[mapping[0]] + + # Alchemiops does NOT include self-interactions by default + # Add them only if requested + if self_interaction: + n_atoms = positions.shape[0] + self_pairs = torch.arange(n_atoms, device=positions.device, dtype=torch.long) + self_mapping = torch.stack([self_pairs, self_pairs], dim=0) + # Self-shifts should match shifts_idx dtype + self_shifts = torch.zeros( + (n_atoms, 3), dtype=cell.dtype, device=positions.device + ) + + mapping = torch.cat([mapping, self_mapping], dim=1) + shifts_idx = torch.cat([shifts_idx, self_shifts], dim=0) + system_mapping = torch.cat([system_mapping, system_idx], dim=0) + + return mapping, system_mapping, shifts_idx + else: - # Provide stub function that raises informative error + # Provide stub functions that raise informative errors def alchemiops_nl_n2( # type: ignore[misc] *args, # noqa: ARG001 **kwargs, # noqa: ARG001 @@ -108,3 +181,12 @@ def alchemiops_nl_n2( # type: ignore[misc] raise ImportError( "nvalchemiops is not installed. Install it with: pip install nvalchemiops" ) + + def alchemiops_nl_cell_list( # type: ignore[misc] + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stub function when nvalchemiops is not available.""" + raise ImportError( + "nvalchemiops is not installed. Install it with: pip install nvalchemiops" + ) From 1009bce3adae4cbe9371476defe4225ca8de73b7 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Fri, 19 Dec 2025 09:24:37 -0800 Subject: [PATCH 2/2] pin minimum nvalchemi-toolkit-ops version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 340ae664..fd0e67c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ requires-python = ">=3.12" dependencies = [ "h5py>=3.12.1", - "nvalchemi-toolkit-ops", + "nvalchemi-toolkit-ops>=0.2.0", "numpy>=1.26,<3", "tables>=3.10.2", "torch>=2",