Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ def test_integrate_double_nvt(
assert not torch.isnan(final_state.energy).any()


def test_integrate_double_nvt_multiple_temperatures(
ar_double_sim_state: SimState, lj_model: LennardJonesModel
) -> None:
"""Test NVT integration with LJ potential."""
_ = ts.integrate(
system=ar_double_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_langevin,
n_steps=10,
temperature=[100.0, 200.0], # K
timestep=0.001, # ps
init_kwargs=dict(seed=481516),
)


def test_integrate_double_nvt_with_reporter(
ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path
) -> None:
Expand Down
89 changes: 77 additions & 12 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,67 @@ def _configure_batches_iterator(
return batches


def _normalize_temperature_tensor(
temperature: float | list | torch.Tensor, n_steps: int, initial_state: SimState
) -> torch.Tensor:
"""Turn the temperature into a tensor of shape (n_steps,) or (n_steps, n_systems).

Args:
temperature (float | int | list | torch.Tensor): Temperature input
n_steps (int): Number of integration steps
initial_state (SimState): Initial simulation state for dtype and device
Returns:
torch.Tensor: Normalized temperature tensor
"""
# ---- Step 1: Convert to tensor ----
if isinstance(temperature, (float, int)):
return torch.full(
(n_steps,),
float(temperature),
dtype=initial_state.dtype,
device=initial_state.device,
)

# Convert list or tensor input to tensor
if isinstance(temperature, list):
temps = torch.tensor(
temperature, dtype=initial_state.dtype, device=initial_state.device
)
elif isinstance(temperature, torch.Tensor):
temps = temperature.to(dtype=initial_state.dtype, device=initial_state.device)
else:
raise TypeError(
f"Invalid temperature type: {type(temperature).__name__}. "
"Must be float, int, list, or torch.Tensor."
)

# ---- Step 2: Determine how to broadcast ----
temps = torch.atleast_1d(temps)
if temps.ndim > 2:
raise ValueError(f"Temperature tensor must be 1D or 2D, got shape {temps.shape}.")

if temps.shape[0] == 1:
# A single value in a 1-element list/tensor
return temps.repeat(n_steps)

# This assumes that in case n_systems == n_steps, the user wants to apply
# different temperatures per system, not per step.
if temps.shape[0] == initial_state.n_systems:
# Interpret as single-step multi-system temperatures → broadcast over steps
return temps.unsqueeze(0).expand(n_steps, -1) # (n_steps, n_systems)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I honestly might just throw a well-documented error here and demand the user provide a 2D tensor if n_systems == n_steps. It's a rare edge case and the consequences of incorrectly guessing the default is potentially many hours of wasted debugging.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a warning, let me know if you want to be more strict and raise an error

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me


if temps.shape[0] == n_steps:
return temps # already good: (n_steps,) or (n_steps, n_systems)

raise ValueError(
f"Temperature length ({temps.shape[0]}) must be either:\n"
f" - n_steps ({n_steps}), or\n"
f" - n_systems ({initial_state.n_systems}), or\n"
f" - 1 (scalar),\n"
f"but got {temps.shape[0]}."
)


def integrate[T: SimState]( # noqa: C901
system: StateLike,
model: ModelInterface,
Expand All @@ -123,7 +184,11 @@ def integrate[T: SimState]( # noqa: C901
(init_func, step_func) functions.
n_steps (int): Number of integration steps
temperature (float | ArrayLike): Temperature or array of temperatures for each
step
step or system:
Float: used for all steps and systems
1D array of length n_steps: used for each step
1D array of length n_systems: used for each system
2D array of shape (n_steps, n_systems): used for each step and system.
timestep (float): Integration time step
trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for
tracking trajectory. If a dict, will be passed to the TrajectoryReporter
Expand All @@ -140,18 +205,11 @@ def integrate[T: SimState]( # noqa: C901
T: Final state after integration
"""
unit_system = UnitSystem.metal
# create a list of temperatures
temps = (
[temperature] * n_steps
if isinstance(temperature, (float, int))
else list(temperature)
)
if len(temps) != n_steps:
raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}")

initial_state: SimState = ts.initialize_state(system, model.device, model.dtype)
dtype, device = initial_state.dtype, initial_state.device
kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature
kTs = _normalize_temperature_tensor(temperature, n_steps, initial_state)
kTs = kTs * unit_system.temperature
dt = torch.tensor(timestep * unit_system.time, dtype=dtype, device=device)

# Handle both string names and direct function tuples
Expand Down Expand Up @@ -192,7 +250,10 @@ def integrate[T: SimState]( # noqa: C901
# Handle both BinningAutoBatcher and list of tuples
for state, system_indices in batch_iterator:
# Pass correct parameters based on integrator type
state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **init_kwargs or {})
batch_kT = kTs[:, system_indices] if (system_indices and kTs.shape == 2) else kTs
state = init_func(
state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {}
)

# set up trajectory reporters
if autobatcher and trajectory_reporter is not None and og_filenames is not None:
Expand All @@ -204,7 +265,11 @@ def integrate[T: SimState]( # noqa: C901
# run the simulation
for step in range(1, n_steps + 1):
state = step_func(
state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs
state=state,
model=model,
dt=dt,
kT=batch_kT[step - 1],
**integrator_kwargs,
)

if trajectory_reporter:
Expand Down