Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions lib/python/picongpu/picmi/diagnostics/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
License: GPLv3+
"""

from typing import Dict, Optional
from typing import Dict, Optional, Union

import typeguard

Expand All @@ -28,7 +28,7 @@ class Checkpoint:

Parameters
----------
period: TimeStepSpec, optional
period: int or TimeStepSpec, optional
Specify on which time steps to create checkpoints.
Unit: steps (simulation time steps). Required if timePeriod is not provided.

Expand Down Expand Up @@ -68,9 +68,7 @@ class Checkpoint:
"""

def check(self, *args, **kwargs):
if self.period is None and self.timePeriod is None:
raise ValueError("At least one of period or timePeriod must be provided")
if self.timePeriod is not None and self.timePeriod < 0:
if self.timePeriod is not None and (not isinstance(self.timePeriod, int) or self.timePeriod < 0):
raise ValueError("timePeriod must be a non-negative integer")
if self.restartStep is not None and self.restartStep < 0:
raise ValueError("restartStep must be non-negative")
Expand All @@ -81,9 +79,9 @@ def check(self, *args, **kwargs):

def __init__(
self,
period: Optional[TimeStepSpec] = None,
period: Optional[Union[int, TimeStepSpec]] = None,
timePeriod: Optional[int] = None,
directory: Optional[str] = None,
directory: Optional[str] = "checkpoints",
file: Optional[str] = None,
restart: Optional[bool] = None,
tryRestart: Optional[bool] = None,
Expand All @@ -94,8 +92,23 @@ def __init__(
restartLoop: Optional[int] = None,
openPMD: Optional[Dict] = None,
):
self.period = period
if period is not None and not isinstance(period, (int, TimeStepSpec)):
raise TypeError("period must be an integer or TimeStepSpec")
if isinstance(period, int):
if period < 0:
raise ValueError("period must be non-negative")
if period == 0:
self.period = TimeStepSpec()("steps")
else:
self.period = TimeStepSpec(slice(None, None, period))("steps")
else:
self.period = period if period is not None else TimeStepSpec()("steps")

self.timePeriod = timePeriod
if (self.timePeriod is None or self.timePeriod <= 0) and (self.period is None or not self.period.specs):
raise ValueError(
"At least one of period or timePeriod must be provided and active (period with steps or timePeriod > 0)"
)
self.directory = directory
self.file = file
self.restart = restart
Expand All @@ -106,3 +119,4 @@ def __init__(
self.restartChunkSize = restartChunkSize
self.restartLoop = restartLoop
self.openPMD = openPMD
self.check()
54 changes: 39 additions & 15 deletions lib/python/picongpu/pypicongpu/output/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,18 @@ class Checkpoint(Plugin):
_name = "checkpoint"

def __init__(self):
"do nothing"
self.period = None
self.timePeriod = None
self.directory = None
self.file = None
self.restart = None
self.tryRestart = None
self.restartStep = None
self.restartDirectory = None
self.restartFile = None
self.restartChunkSize = None
self.restartLoop = None
self.openPMD = None

def check(self):
if self.period is None and self.timePeriod is None:
Expand All @@ -49,18 +60,31 @@ def check(self):
def _get_serialized(self) -> typing.Dict:
"""Return the serialized representation of the object."""
self.check()
serialized = {
"period": self.period.get_rendering_context() if self.period is not None else None,
"timePeriod": self.timePeriod,
"directory": self.directory,
"file": self.file,
"restart": self.restart,
"tryRestart": self.tryRestart,
"restartStep": self.restartStep,
"restartDirectory": self.restartDirectory,
"restartFile": self.restartFile,
"restartChunkSize": self.restartChunkSize,
"restartLoop": self.restartLoop,
"openPMD": self.openPMD,
}
serialized = {}

if self.period is not None:
serialized["period"] = self.period.get_rendering_context()
if self.timePeriod is not None:
serialized["timePeriod"] = self.timePeriod
if self.directory is not None:
serialized["directory"] = self.directory
if self.file is not None:
serialized["file"] = self.file
if self.restart is not None:
serialized["restart"] = self.restart
if self.tryRestart is not None:
serialized["tryRestart"] = self.tryRestart
if self.restartStep is not None:
serialized["restartStep"] = self.restartStep
if self.restartDirectory is not None:
serialized["restartDirectory"] = self.restartDirectory
if self.restartFile is not None:
serialized["restartFile"] = self.restartFile
if self.restartChunkSize is not None:
serialized["restartChunkSize"] = self.restartChunkSize
if self.restartLoop is not None:
serialized["restartLoop"] = self.restartLoop
if self.openPMD is not None:
serialized["openPMD"] = self.openPMD

return serialized
3 changes: 2 additions & 1 deletion test/python/picongpu/quick/picmi/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
This file is part of PIConGPU.
Copyright 2025 PIConGPU contributors
Authors: Julian Lenz
Authors: Julian Lenz, Masoud Afshari
License: GPLv3+
"""

# flake8: noqa
from .timestepspec import * # pyflakes.ignore
from .checkpoint import * # pyflakes.ignore
118 changes: 118 additions & 0 deletions test/python/picongpu/quick/picmi/diagnostics/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
This file is part of PIConGPU.
Copyright 2025 PIConGPU contributors
Authors: Masoud Afshari
License: GPLv3+
"""

from picongpu.picmi.diagnostics import Checkpoint, TimeStepSpec
from picongpu.pypicongpu.output.checkpoint import Checkpoint as PyPIConGPUCheckpoint
from picongpu.pypicongpu.output.timestepspec import TimeStepSpec as PyPIConGPUTimeStepSpec
import unittest
import typeguard

TESTCASES_VALID = [
(
{"period": 10, "timePeriod": None, "directory": "checkpoints"},
{"period": {"specs": [{"start": 0, "stop": -1, "step": 10}]}, "timePeriod": None, "directory": "checkpoints"},
),
(
{
"period": TimeStepSpec(5, 10)("steps"),
"timePeriod": 10,
"restartStep": 100,
"restartDirectory": "backups",
"restartFile": "backup",
"restartChunkSize": 1000,
"restartLoop": 2,
"openPMD": {"ext": "h5"},
},
{
"period": {"specs": [{"start": 5, "stop": 5, "step": 1}, {"start": 10, "stop": 10, "step": 1}]},
"timePeriod": 10,
"restartStep": 100,
"restartDirectory": "backups",
"restartFile": "backup",
"restartChunkSize": 1000,
"restartLoop": 2,
"openPMD": {"ext": "h5"},
},
),
]

logic_invalid_cases = [
({"period": None, "timePeriod": None}, "At least one of period or timePeriod must be provided and active"),
({"period": None, "timePeriod": 0}, "At least one of period or timePeriod must be provided and active"),
({"period": 0, "timePeriod": 0}, "At least one of period or timePeriod must be provided and active"),
(
{"period": TimeStepSpec()("steps"), "timePeriod": 0},
"At least one of period or timePeriod must be provided and active",
),
({"period": 10, "timePeriod": -5}, "timePeriod must be a non-negative"),
({"period": 10, "restartStep": -1}, "restartStep must be non-negative"),
({"period": 10, "restartChunkSize": 0}, "restartChunkSize must be positive"),
({"period": 10, "restartLoop": -1}, "restartLoop must be non-negative"),
]

type_invalid_cases = [
({"period": "invalid", "timePeriod": None}, 'argument "period".*did not match any element'),
]

TESTCASES_INVALID_GET_AS = [
({"period": TimeStepSpec([slice(None, None, -10)]), "timePeriod": None}, "Step size must be >= 1"),
({"period": 10, "timePeriod": None}, -0.5, 200, "time_step_size must be positive", True),
({"period": 10, "timePeriod": None}, 0.5, 0, "num_steps must be positive", True),
]


class PICMI_TestCheckpoint(unittest.TestCase):
def test_checkpoint(self):
"""Test Checkpoint instantiation, validation, and serialization."""
for params, expected_serialized in TESTCASES_VALID:
with self.subTest(params=params):
checkpoint = Checkpoint(**params)
for key, value in params.items():
if key == "period" and isinstance(value, int):
expected = TimeStepSpec(slice(None, None, value))("steps")
self.assertEqual(checkpoint.period.specs, expected.specs)
else:
self.assertEqual(getattr(checkpoint, key), value)
pypicongpu_checkpoint = checkpoint.get_as_pypicongpu(0.5, 200)
self.assertIsInstance(pypicongpu_checkpoint, PyPIConGPUCheckpoint)
self.assertIsInstance(pypicongpu_checkpoint.period, PyPIConGPUTimeStepSpec)
serialized_data = pypicongpu_checkpoint._get_serialized()
serialized = {"typeID": {"checkpoint": True}, "data": serialized_data}
self.assertEqual(serialized["typeID"], {"checkpoint": True})
for key, value in expected_serialized.items():
if key == "period":
self.assertEqual(serialized_data["period"]["specs"], value["specs"])
elif key in serialized_data:
self.assertEqual(serialized_data[key], value)
else:
self.assertIsNone(value)

for params, expected_error in logic_invalid_cases:
with self.subTest(params=params, expected_error=expected_error):
with self.assertRaisesRegex(ValueError, expected_error):
Checkpoint(**params)

for params, expected_error in type_invalid_cases:
with self.subTest(params=params, expected_error=expected_error):
with self.assertRaisesRegex(typeguard.TypeCheckError, expected_error):
Checkpoint(**params)

def test_checkpoint_invalid_cases(self):
"""Test invalid TimeStepSpec and simulation parameters."""
for params, *args in TESTCASES_INVALID_GET_AS:
with self.subTest(params=params, args=args):
checkpoint = Checkpoint(**params)
time_step_size, num_steps = args if len(args) == 2 else (0.5, 200)
expected_error, *skip = args[-1] if len(args) == 2 else "Step size must be >= 1"
if skip and skip[0]:
continue
with self.assertRaisesRegex(ValueError, expected_error):
checkpoint.get_as_pypicongpu({}, time_step_size, num_steps)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions test/python/picongpu/quick/pypicongpu/output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .auto import * # pyflakes.ignore
from .phase_space import * # pyflakes.ignore
from .timestepspec import * # pyflakes.ignore
from .checkpoint import * # pyflakes.ignore
Loading