|
11 | 11 |
|
12 | 12 | import pygsti |
13 | 13 | from pygsti.modelpacks import smq1Q_XYI as std |
| 14 | +from pygsti.protocols.protocol import SlurmSettings |
14 | 15 |
|
15 | 16 |
|
16 | 17 | @pytest.mark.skipif(MPI is None, reason="mpi4py could not be imported") |
@@ -98,3 +99,52 @@ def test_run_mpi_repeated_call_inmemory_data(self): |
98 | 99 | results2 = proto.run_mpi(data, num_ranks=2, mpiexec='auto', extra_mpi_args=extra) |
99 | 100 | assert "GateSetTomography" in results1.estimates |
100 | 101 | assert "GateSetTomography" in results2.estimates |
| 102 | + |
| 103 | + def test_run_mpi_with_persistent_dir(self, tmp_path): |
| 104 | + """run_mpi with persistent_dir leaves results on disk and returns them.""" |
| 105 | + exp_design = std.create_gst_experiment_design(4) |
| 106 | + mdl_datagen = std.target_model().depolarize(op_noise=0.1, spam_noise=0.01) |
| 107 | + ds = pygsti.data.simulate_data(mdl_datagen, exp_design, 1000, seed=1234) |
| 108 | + data = pygsti.protocols.ProtocolData(exp_design, ds) |
| 109 | + |
| 110 | + initial_model = std.target_model("full TP") |
| 111 | + proto = pygsti.protocols.GateSetTomography( |
| 112 | + initial_model, verbosity=0, |
| 113 | + optimizer={'maxiter': 10}, |
| 114 | + ) |
| 115 | + |
| 116 | + results = proto.run_mpi( |
| 117 | + data, num_ranks=2, mpiexec='auto', |
| 118 | + extra_mpi_args=self._extra_mpi_args(), |
| 119 | + persistent_dir=str(tmp_path), |
| 120 | + ) |
| 121 | + assert "GateSetTomography" in results.estimates |
| 122 | + # Directory persists and contains written data. |
| 123 | + assert tmp_path.exists() |
| 124 | + assert (tmp_path / 'edesign').is_dir() |
| 125 | + |
| 126 | + def test_stage_slurm_writes_valid_script(self, tmp_path): |
| 127 | + """stage_slurm produces a script with correct content (no subprocess launched).""" |
| 128 | + exp_design = std.create_gst_experiment_design(4) |
| 129 | + mdl_datagen = std.target_model().depolarize(op_noise=0.1, spam_noise=0.01) |
| 130 | + ds = pygsti.data.simulate_data(mdl_datagen, exp_design, 1000, seed=1234) |
| 131 | + data = pygsti.protocols.ProtocolData(exp_design, ds) |
| 132 | + |
| 133 | + initial_model = std.target_model("full TP") |
| 134 | + proto = pygsti.protocols.GateSetTomography(initial_model, verbosity=0) |
| 135 | + |
| 136 | + script_path = str(tmp_path / 'submit.sh') |
| 137 | + slurm = SlurmSettings(script_path, partition='debug', time='1:00:00', nodes=2) |
| 138 | + proto.stage_slurm( |
| 139 | + data, num_ranks=4, slurm=slurm, work_dir=str(tmp_path), |
| 140 | + blas_threads_per_rank=2, |
| 141 | + ) |
| 142 | + |
| 143 | + import ast |
| 144 | + content = (tmp_path / 'submit.sh').read_text() |
| 145 | + assert 'srun python' in content |
| 146 | + assert '#SBATCH --nodes=2' in content |
| 147 | + assert '#SBATCH --partition=debug' in content |
| 148 | + runner_path = tmp_path / 'mpi_runner.py' |
| 149 | + assert runner_path.exists() |
| 150 | + ast.parse(runner_path.read_text()) # syntactically valid Python |
0 commit comments