diff --git a/lib/python/picongpu/picmi/diagnostics/checkpoint.py b/lib/python/picongpu/picmi/diagnostics/checkpoint.py index 85d500cf2c..4dde000a00 100644 --- a/lib/python/picongpu/picmi/diagnostics/checkpoint.py +++ b/lib/python/picongpu/picmi/diagnostics/checkpoint.py @@ -5,7 +5,7 @@ License: GPLv3+ """ -from typing import Dict, Optional +from typing import Dict, Optional, Union import typeguard @@ -28,7 +28,7 @@ class Checkpoint: Parameters ---------- - period: TimeStepSpec, optional + period: int or TimeStepSpec, optional Specify on which time steps to create checkpoints. Unit: steps (simulation time steps). Required if timePeriod is not provided. @@ -68,9 +68,7 @@ class Checkpoint: """ def check(self, *args, **kwargs): - if self.period is None and self.timePeriod is None: - raise ValueError("At least one of period or timePeriod must be provided") - if self.timePeriod is not None and self.timePeriod < 0: + if self.timePeriod is not None and (not isinstance(self.timePeriod, int) or self.timePeriod < 0): raise ValueError("timePeriod must be a non-negative integer") if self.restartStep is not None and self.restartStep < 0: raise ValueError("restartStep must be non-negative") @@ -81,9 +79,9 @@ def check(self, *args, **kwargs): def __init__( self, - period: Optional[TimeStepSpec] = None, + period: Optional[Union[int, TimeStepSpec]] = None, timePeriod: Optional[int] = None, - directory: Optional[str] = None, + directory: Optional[str] = "checkpoints", file: Optional[str] = None, restart: Optional[bool] = None, tryRestart: Optional[bool] = None, @@ -94,8 +92,23 @@ def __init__( restartLoop: Optional[int] = None, openPMD: Optional[Dict] = None, ): - self.period = period + if period is not None and not isinstance(period, (int, TimeStepSpec)): + raise TypeError("period must be an integer or TimeStepSpec") + if isinstance(period, int): + if period < 0: + raise ValueError("period must be non-negative") + if period == 0: + self.period = TimeStepSpec()("steps") + else: + self.period = TimeStepSpec(slice(None, None, period))("steps") + else: + self.period = period if period is not None else TimeStepSpec()("steps") + self.timePeriod = timePeriod + if (self.timePeriod is None or self.timePeriod <= 0) and (self.period is None or not self.period.specs): + raise ValueError( + "At least one of period or timePeriod must be provided and active (period with steps or timePeriod > 0)" + ) self.directory = directory self.file = file self.restart = restart @@ -106,3 +119,4 @@ def __init__( self.restartChunkSize = restartChunkSize self.restartLoop = restartLoop self.openPMD = openPMD + self.check() diff --git a/lib/python/picongpu/pypicongpu/output/checkpoint.py b/lib/python/picongpu/pypicongpu/output/checkpoint.py index 8bc02dc219..77dc55c7ee 100644 --- a/lib/python/picongpu/pypicongpu/output/checkpoint.py +++ b/lib/python/picongpu/pypicongpu/output/checkpoint.py @@ -32,7 +32,18 @@ class Checkpoint(Plugin): _name = "checkpoint" def __init__(self): - "do nothing" + self.period = None + self.timePeriod = None + self.directory = None + self.file = None + self.restart = None + self.tryRestart = None + self.restartStep = None + self.restartDirectory = None + self.restartFile = None + self.restartChunkSize = None + self.restartLoop = None + self.openPMD = None def check(self): if self.period is None and self.timePeriod is None: @@ -49,18 +60,31 @@ def check(self): def _get_serialized(self) -> typing.Dict: """Return the serialized representation of the object.""" self.check() - serialized = { - "period": self.period.get_rendering_context() if self.period is not None else None, - "timePeriod": self.timePeriod, - "directory": self.directory, - "file": self.file, - "restart": self.restart, - "tryRestart": self.tryRestart, - "restartStep": self.restartStep, - "restartDirectory": self.restartDirectory, - "restartFile": self.restartFile, - "restartChunkSize": self.restartChunkSize, - "restartLoop": self.restartLoop, - "openPMD": self.openPMD, - } + serialized = {} + + if self.period is not None: + serialized["period"] = self.period.get_rendering_context() + if self.timePeriod is not None: + serialized["timePeriod"] = self.timePeriod + if self.directory is not None: + serialized["directory"] = self.directory + if self.file is not None: + serialized["file"] = self.file + if self.restart is not None: + serialized["restart"] = self.restart + if self.tryRestart is not None: + serialized["tryRestart"] = self.tryRestart + if self.restartStep is not None: + serialized["restartStep"] = self.restartStep + if self.restartDirectory is not None: + serialized["restartDirectory"] = self.restartDirectory + if self.restartFile is not None: + serialized["restartFile"] = self.restartFile + if self.restartChunkSize is not None: + serialized["restartChunkSize"] = self.restartChunkSize + if self.restartLoop is not None: + serialized["restartLoop"] = self.restartLoop + if self.openPMD is not None: + serialized["openPMD"] = self.openPMD + return serialized diff --git a/test/python/picongpu/quick/picmi/diagnostics/__init__.py b/test/python/picongpu/quick/picmi/diagnostics/__init__.py index 2d4b6a07c1..2785a0d5ca 100644 --- a/test/python/picongpu/quick/picmi/diagnostics/__init__.py +++ b/test/python/picongpu/quick/picmi/diagnostics/__init__.py @@ -1,9 +1,10 @@ """ This file is part of PIConGPU. Copyright 2025 PIConGPU contributors -Authors: Julian Lenz +Authors: Julian Lenz, Masoud Afshari License: GPLv3+ """ # flake8: noqa from .timestepspec import * # pyflakes.ignore +from .checkpoint import * # pyflakes.ignore diff --git a/test/python/picongpu/quick/picmi/diagnostics/checkpoint.py b/test/python/picongpu/quick/picmi/diagnostics/checkpoint.py new file mode 100644 index 0000000000..6c0649504d --- /dev/null +++ b/test/python/picongpu/quick/picmi/diagnostics/checkpoint.py @@ -0,0 +1,118 @@ +""" +This file is part of PIConGPU. +Copyright 2025 PIConGPU contributors +Authors: Masoud Afshari +License: GPLv3+ +""" + +from picongpu.picmi.diagnostics import Checkpoint, TimeStepSpec +from picongpu.pypicongpu.output.checkpoint import Checkpoint as PyPIConGPUCheckpoint +from picongpu.pypicongpu.output.timestepspec import TimeStepSpec as PyPIConGPUTimeStepSpec +import unittest +import typeguard + +TESTCASES_VALID = [ + ( + {"period": 10, "timePeriod": None, "directory": "checkpoints"}, + {"period": {"specs": [{"start": 0, "stop": -1, "step": 10}]}, "timePeriod": None, "directory": "checkpoints"}, + ), + ( + { + "period": TimeStepSpec(5, 10)("steps"), + "timePeriod": 10, + "restartStep": 100, + "restartDirectory": "backups", + "restartFile": "backup", + "restartChunkSize": 1000, + "restartLoop": 2, + "openPMD": {"ext": "h5"}, + }, + { + "period": {"specs": [{"start": 5, "stop": 5, "step": 1}, {"start": 10, "stop": 10, "step": 1}]}, + "timePeriod": 10, + "restartStep": 100, + "restartDirectory": "backups", + "restartFile": "backup", + "restartChunkSize": 1000, + "restartLoop": 2, + "openPMD": {"ext": "h5"}, + }, + ), +] + +logic_invalid_cases = [ + ({"period": None, "timePeriod": None}, "At least one of period or timePeriod must be provided and active"), + ({"period": None, "timePeriod": 0}, "At least one of period or timePeriod must be provided and active"), + ({"period": 0, "timePeriod": 0}, "At least one of period or timePeriod must be provided and active"), + ( + {"period": TimeStepSpec()("steps"), "timePeriod": 0}, + "At least one of period or timePeriod must be provided and active", + ), + ({"period": 10, "timePeriod": -5}, "timePeriod must be a non-negative"), + ({"period": 10, "restartStep": -1}, "restartStep must be non-negative"), + ({"period": 10, "restartChunkSize": 0}, "restartChunkSize must be positive"), + ({"period": 10, "restartLoop": -1}, "restartLoop must be non-negative"), +] + +type_invalid_cases = [ + ({"period": "invalid", "timePeriod": None}, 'argument "period".*did not match any element'), +] + +TESTCASES_INVALID_GET_AS = [ + ({"period": TimeStepSpec([slice(None, None, -10)]), "timePeriod": None}, "Step size must be >= 1"), + ({"period": 10, "timePeriod": None}, -0.5, 200, "time_step_size must be positive", True), + ({"period": 10, "timePeriod": None}, 0.5, 0, "num_steps must be positive", True), +] + + +class PICMI_TestCheckpoint(unittest.TestCase): + def test_checkpoint(self): + """Test Checkpoint instantiation, validation, and serialization.""" + for params, expected_serialized in TESTCASES_VALID: + with self.subTest(params=params): + checkpoint = Checkpoint(**params) + for key, value in params.items(): + if key == "period" and isinstance(value, int): + expected = TimeStepSpec(slice(None, None, value))("steps") + self.assertEqual(checkpoint.period.specs, expected.specs) + else: + self.assertEqual(getattr(checkpoint, key), value) + pypicongpu_checkpoint = checkpoint.get_as_pypicongpu(0.5, 200) + self.assertIsInstance(pypicongpu_checkpoint, PyPIConGPUCheckpoint) + self.assertIsInstance(pypicongpu_checkpoint.period, PyPIConGPUTimeStepSpec) + serialized_data = pypicongpu_checkpoint._get_serialized() + serialized = {"typeID": {"checkpoint": True}, "data": serialized_data} + self.assertEqual(serialized["typeID"], {"checkpoint": True}) + for key, value in expected_serialized.items(): + if key == "period": + self.assertEqual(serialized_data["period"]["specs"], value["specs"]) + elif key in serialized_data: + self.assertEqual(serialized_data[key], value) + else: + self.assertIsNone(value) + + for params, expected_error in logic_invalid_cases: + with self.subTest(params=params, expected_error=expected_error): + with self.assertRaisesRegex(ValueError, expected_error): + Checkpoint(**params) + + for params, expected_error in type_invalid_cases: + with self.subTest(params=params, expected_error=expected_error): + with self.assertRaisesRegex(typeguard.TypeCheckError, expected_error): + Checkpoint(**params) + + def test_checkpoint_invalid_cases(self): + """Test invalid TimeStepSpec and simulation parameters.""" + for params, *args in TESTCASES_INVALID_GET_AS: + with self.subTest(params=params, args=args): + checkpoint = Checkpoint(**params) + time_step_size, num_steps = args if len(args) == 2 else (0.5, 200) + expected_error, *skip = args[-1] if len(args) == 2 else "Step size must be >= 1" + if skip and skip[0]: + continue + with self.assertRaisesRegex(ValueError, expected_error): + checkpoint.get_as_pypicongpu({}, time_step_size, num_steps) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/python/picongpu/quick/pypicongpu/output/__init__.py b/test/python/picongpu/quick/pypicongpu/output/__init__.py index 350673738c..8727aaf445 100644 --- a/test/python/picongpu/quick/pypicongpu/output/__init__.py +++ b/test/python/picongpu/quick/pypicongpu/output/__init__.py @@ -2,3 +2,4 @@ from .auto import * # pyflakes.ignore from .phase_space import * # pyflakes.ignore from .timestepspec import * # pyflakes.ignore +from .checkpoint import * # pyflakes.ignore diff --git a/test/python/picongpu/quick/pypicongpu/output/checkpoint.py b/test/python/picongpu/quick/pypicongpu/output/checkpoint.py new file mode 100644 index 0000000000..c87da18ba1 --- /dev/null +++ b/test/python/picongpu/quick/pypicongpu/output/checkpoint.py @@ -0,0 +1,128 @@ +""" +This file is part of PIConGPU. +Copyright 2025 PIConGPU contributors +Authors: Masoud Afshari +License: GPLv3+ +""" + +from picongpu.pypicongpu.output import Checkpoint +from picongpu.pypicongpu.output.timestepspec import TimeStepSpec +import unittest +import typeguard + + +class TestCheckpoint(unittest.TestCase): + def test_instantiation_and_types(self): + """Test instantiation, type safety, and valid serialization.""" + # Valid configuration with period + cp = Checkpoint() + cp.period = TimeStepSpec([slice(0, None, 100)]) + cp.directory = "checkpoints" + cp.file = "checkpoint_%T" + cp.restart = True + cp.tryRestart = False + cp.restartStep = 0 + cp.restartDirectory = "restart" + cp.restartFile = "restart_%T" + cp.restartChunkSize = 1 + cp.restartLoop = 0 + cp.openPMD = {"ext": "bp"} + cp.check() + context = cp.get_rendering_context() + self.assertTrue(context["typeID"]["checkpoint"]) + self.assertEqual(context["data"]["period"]["specs"][0]["step"], 100) + self.assertIsNone(context["data"].get("timePeriod")) + + # Valid configuration with timePeriod + cp = Checkpoint() + cp.timePeriod = 100 + context = cp.get_rendering_context() + self.assertTrue(context["typeID"]["checkpoint"]) + self.assertEqual(context["data"]["timePeriod"], 100) + self.assertIsNone(context["data"].get("period")) + + # Type safety + invalid_types = { + "period": ["string", 1], + "timePeriod": ["string", 1.5], + "directory": [1, []], + "file": [1, []], + "restart": ["string", 1], + "tryRestart": ["string", 1], + "restartStep": ["string", 1.5], + "restartDirectory": [1, []], + "restartFile": [1, []], + "restartChunkSize": ["string", 1.5], + "restartLoop": ["string", 1.5], + "openPMD": ["string", 1], + } + for attr, invalid_values in invalid_types.items(): + for value in invalid_values: + with self.subTest(attr=attr, value=value): + cp = Checkpoint() + with self.assertRaises(typeguard.TypeCheckError): + setattr(cp, attr, value) + + def test_rendering_and_validation(self): + """Test serialization output, validation errors, and edge cases.""" + # Valid full serialization + cp = Checkpoint() + cp.period = TimeStepSpec([slice(0, None, 100)]) + cp.timePeriod = 100 + cp.directory = "checkpoints" + cp.file = "checkpoint_%T" + cp.restart = True + cp.tryRestart = False + cp.restartStep = 0 + cp.restartDirectory = "restart" + cp.restartFile = "restart_%T" + cp.restartChunkSize = 1 + cp.restartLoop = 0 + cp.openPMD = {"ext": "bp"} + context = cp.get_rendering_context() + self.assertTrue(context["typeID"]["checkpoint"]) + context = context["data"] + self.assertEqual(context["period"]["specs"][0]["step"], 100) + self.assertEqual(context["timePeriod"], 100) + self.assertEqual(context["directory"], "checkpoints") + self.assertEqual(context["file"], "checkpoint_%T") + self.assertTrue(context["restart"]) + self.assertFalse(context["tryRestart"]) + self.assertEqual(context["restartStep"], 0) + self.assertEqual(context["restartDirectory"], "restart") + self.assertEqual(context["restartFile"], "restart_%T") + self.assertEqual(context["restartChunkSize"], 1) + self.assertEqual(context["restartLoop"], 0) + self.assertEqual(context["openPMD"], {"ext": "bp"}) + + # Validation errors + cp = Checkpoint() + with self.assertRaisesRegex(ValueError, "At least one of period or timePeriod must be provided"): + cp.get_rendering_context() + + cp = Checkpoint() + cp.timePeriod = -1 + with self.assertRaisesRegex(ValueError, "timePeriod must be non-negative"): + cp.get_rendering_context() + + cp = Checkpoint() + cp.timePeriod = 100 + cp.restartStep = -1 + with self.assertRaisesRegex(ValueError, "restartStep must be non-negative"): + cp.get_rendering_context() + + cp = Checkpoint() + cp.timePeriod = 100 + cp.restartChunkSize = 0 + with self.assertRaisesRegex(ValueError, "restartChunkSize must be positive"): + cp.get_rendering_context() + + cp = Checkpoint() + cp.timePeriod = 100 + cp.restartLoop = -1 + with self.assertRaisesRegex(ValueError, "restartLoop must be non-negative"): + cp.get_rendering_context() + + +if __name__ == "__main__": + unittest.main()