Remove last_permutation from SwapMCState initialization #379

@curtischong

When doing an MC swap, we currently have to pass in `last_permutation`. But I think `SwapMCState` could just work out an initial `last_permutation` automatically (so there's no need to pass it in when the state is first created?). Is there a strong reason why we need to pass it in?

```python
from dataclasses import dataclass

import torch

from torch_sim.state import SimState


@dataclass(kw_only=True)
class SwapMCState(SimState):
    """State for Monte Carlo simulations with swap moves.

    This class extends the SimState to include properties specific to Monte Carlo
    simulations, such as the system energy and records of permutations applied
    during the simulation.

    Attributes:
        energy (torch.Tensor): Energy of the system with shape [batch_size]
        last_permutation (torch.Tensor): Last permutation applied to the system,
            with shape [n_atoms], tracking the moves made for analysis or reversal
    """

    energy: torch.Tensor
    last_permutation: torch.Tensor

    _atom_attributes = SimState._atom_attributes | {"last_permutation"}  # noqa: SLF001
    _system_attributes = SimState._system_attributes | {"energy"}  # noqa: SLF001
```
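One reason an initial value could be derived automatically: before any swap has been made, the natural value for `last_permutation` is the identity permutation, which leaves every atom where it is. The minimal sketch below illustrates this with plain Python lists instead of tensors. The helper names (`identity_permutation`, `apply_permutation`, `invert_permutation`) are hypothetical, and the indexing convention (new slot `i` takes the old value at `perm[i]`) is an assumption; torch-sim may use the opposite convention, but the identity is a no-op under either one.

```python
def identity_permutation(n_atoms: int) -> list[int]:
    """Hypothetical helper: maps every atom index to itself (no swap applied yet)."""
    return list(range(n_atoms))


def apply_permutation(values: list, perm: list[int]) -> list:
    """Hypothetical helper: new slot i holds the old value at perm[i] (assumed convention)."""
    return [values[j] for j in perm]


def invert_permutation(perm: list[int]) -> list[int]:
    """Hypothetical helper: returns the permutation that undoes `perm`."""
    inverse = [0] * len(perm)
    for i, j in enumerate(perm):
        inverse[j] = i
    return inverse


species = ["Cu", "Cu", "Au", "Au"]

# The identity moves nothing, so it is a safe "no swap yet" starting value.
unchanged = apply_permutation(species, identity_permutation(len(species)))
print(unchanged)  # ['Cu', 'Cu', 'Au', 'Au']

# A single swap of atoms 1 and 2, followed by its reversal.
swap = [0, 2, 1, 3]
swapped = apply_permutation(species, swap)
print(swapped)  # ['Cu', 'Au', 'Cu', 'Au']
print(apply_permutation(swapped, invert_permutation(swap)))  # ['Cu', 'Cu', 'Au', 'Au']
```

In tensor form the identity would presumably be `torch.arange(n_atoms)`, and a permutation tensor can be inverted with `torch.argsort(perm)`.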

Maybe a better approach is to type it as `torch.Tensor | None`, so people don't need to pass it in when they create the swap state?

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
