Skip to content

Commit dbab892

Browse files
authored
feat: pass user_group for job submissions and retries (#92)
1 parent 00646aa commit dbab892

18 files changed

+141
-36
lines changed

examples/data/add_one.wasm

-1.62 MB
Binary file not shown.

examples/data/rus.wasm

-1.72 MB
Binary file not shown.

examples/wasm_examples.ipynb

+15-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,16 @@
3030
},
3131
{
3232
"cell_type": "code",
33-
"execution_count": 2,
33+
"execution_count": null,
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"qnx.login()"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": 3,
3443
"metadata": {},
3544
"outputs": [],
3645
"source": [
@@ -70,9 +79,10 @@
7079
"metadata": {},
7180
"outputs": [],
7281
"source": [
73-
"circuit = Circuit(0, 2)\n",
82+
"circuit = Circuit(1)\n",
7483
"# Very minimal WASM example\n",
75-
"circuit.add_wasm(\"add_one\", wfh, [1], [1], [0, 1])"
84+
"a = circuit.add_c_register(\"a\", 8)\n",
85+
"circuit.add_wasm_to_reg(\"add_one\", wfh, [a], [a])"
7686
]
7787
},
7888
{
@@ -84,7 +94,7 @@
8494
},
8595
{
8696
"cell_type": "code",
87-
"execution_count": 7,
97+
"execution_count": 6,
8898
"metadata": {},
8999
"outputs": [],
90100
"source": [
@@ -96,7 +106,7 @@
96106
},
97107
{
98108
"cell_type": "code",
99-
"execution_count": 8,
109+
"execution_count": 7,
100110
"metadata": {},
101111
"outputs": [],
102112
"source": [

integration/test_jobs.py

+59
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,62 @@ def test_results_not_available_error(
277277
assert isinstance(execute_results[0].download_result(), BackendResult)
278278

279279
assert isinstance(execute_results[0].download_backend_info(), BackendInfo)
280+
281+
282+
def test_submit_under_user_group(
283+
_authenticated_nexus_circuit_ref: CircuitRef,
284+
qa_project_name: str,
285+
qa_circuit_name: str,
286+
) -> None:
287+
"""Test that a user can submit jobs under a user_group that
288+
they belong to.
289+
290+
Requires that the test user is a member of a group called:
291+
'QA_IntegrationTestGroup',
292+
and not a member of a group called:
293+
'made_up_group'.
294+
"""
295+
296+
fake_group = "made_up_group"
297+
298+
my_proj = qnx.projects.get(name_like=qa_project_name)
299+
300+
with pytest.raises(qnx_exc.ResourceCreateFailed) as exc:
301+
qnx.start_compile_job(
302+
circuits=[_authenticated_nexus_circuit_ref],
303+
name=f"qnexus_integration_test_compile_job_{datetime.now()}",
304+
project=my_proj,
305+
backend_config=qnx.AerConfig(),
306+
user_group=fake_group,
307+
)
308+
assert exc.value == f"Not a member of any group with name: {fake_group}"
309+
310+
qnx.start_compile_job(
311+
circuits=[_authenticated_nexus_circuit_ref],
312+
name=f"qnexus_integration_test_compile_job_{datetime.now()}",
313+
project=my_proj,
314+
backend_config=qnx.AerConfig(),
315+
user_group="QA_IntegrationTestGroup",
316+
)
317+
318+
my_circ = qnx.circuits.get(name_like=qa_circuit_name, project=my_proj)
319+
320+
with pytest.raises(qnx_exc.ResourceCreateFailed):
321+
qnx.start_execute_job(
322+
circuits=[my_circ],
323+
name=f"qnexus_integration_test_execute_job_{datetime.now()}",
324+
project=my_proj,
325+
backend_config=qnx.AerConfig(),
326+
n_shots=[10],
327+
user_group="made_up_group",
328+
)
329+
assert exc.value == f"Not a member of any group with name: {fake_group}"
330+
331+
qnx.start_execute_job(
332+
circuits=[my_circ],
333+
name=f"qnexus_integration_test_execute_job_{datetime.now()}",
334+
project=my_proj,
335+
backend_config=qnx.AerConfig(),
336+
n_shots=[10],
337+
user_group="QA_IntegrationTestGroup",
338+
)

integration/test_user.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_user_get(_authenticated_nexus: None) -> None:
1010
my_user = qnx.users.get_self()
1111
assert isinstance(my_user, UserRef)
1212

13-
my_user_again = qnx.users._fetch( # pylint: disable=protected-access
13+
my_user_again = qnx.users._fetch_by_id( # pylint: disable=protected-access
1414
user_id=my_user.id
1515
)
1616
assert isinstance(my_user_again, UserRef)

integration/test_wasm_modules.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ def test_wasm_flow(
3838
wasm_ref_2 = qnx.wasm_modules.get(id=wasm_ref.id)
3939
assert wasm_ref == wasm_ref_2
4040

41-
circuit = Circuit(0, 2)
42-
circuit.add_wasm("add_one", wfh, [1], [1], [0, 1])
41+
circuit = Circuit(1)
42+
a = circuit.add_c_register("a", 8)
43+
circuit.add_wasm_to_reg("add_one", wfh, [a], [a])
4344
qa_wasm_circuit_name_fixture = (
4445
f"qnexus_integration_test_wasm_circuit_{datetime.now()}"
4546
)

qnexus/client/backend_snapshots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# pass
1313

1414

15-
# def _fetch():
15+
# def _fetch_by_id():
1616
# pass
1717

1818

qnexus/client/circuits.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def get(
130130
not match exactly one object.
131131
"""
132132
if id:
133-
return _fetch(circuit_id=id)
133+
return _fetch_by_id(circuit_id=id)
134134

135135
return get_all(
136136
name_like=name_like,
@@ -237,7 +237,7 @@ def update(
237237
)
238238

239239

240-
def _fetch(circuit_id: UUID | str) -> CircuitRef:
240+
def _fetch_by_id(circuit_id: UUID | str) -> CircuitRef:
241241
"""Utility method for fetching directly by a unique identifier."""
242242

243243
res = get_nexus_client().get(f"/api/circuits/v1beta/{circuit_id}")

qnexus/client/jobs/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def get( # pylint: disable=too-many-positional-arguments
206206
not match exactly one object.
207207
"""
208208
if id:
209-
return _fetch(job_id=id)
209+
return _fetch_by_id(job_id=id)
210210

211211
return get_all(
212212
name_like=name_like,
@@ -225,7 +225,7 @@ def get( # pylint: disable=too-many-positional-arguments
225225
).try_unique_match()
226226

227227

228-
def _fetch(job_id: UUID | str) -> JobRef:
228+
def _fetch_by_id(job_id: UUID | str) -> JobRef:
229229
"""Utility method for fetching directly by a unique identifier."""
230230
res = get_nexus_client().get(f"/api/jobs/v1beta/{job_id}")
231231

@@ -376,13 +376,17 @@ def retry_submission(
376376
job: JobRef,
377377
retry_status: list[StatusEnum] | None = None,
378378
remote_retry_strategy: RemoteRetryStrategy = RemoteRetryStrategy.DEFAULT,
379+
user_group: str | None = None,
379380
):
380381
"""Retry a job in Nexus according to status(es) or retry strategy.
381382
382383
By default, jobs with the ERROR status will be retried.
383384
"""
384385
body: dict[str, str | list[str]] = {"remote_retry_strategy": remote_retry_strategy}
385386

387+
if user_group is not None:
388+
body["user_group"] = user_group
389+
386390
if retry_status is not None:
387391
body["retry_status"] = [status.name for status in retry_status]
388392

@@ -418,6 +422,7 @@ def compile( # pylint: disable=redefined-builtin, too-many-positional-arguments
418422
properties: PropertiesDict | None = None,
419423
optimisation_level: int = 2,
420424
credential_name: str | None = None,
425+
user_group: str | None = None,
421426
hypertket_config: HyperTketConfig | None = None,
422427
timeout: float | None = 300.0,
423428
) -> DataframableList[CircuitRef]:
@@ -437,6 +442,7 @@ def compile( # pylint: disable=redefined-builtin, too-many-positional-arguments
437442
properties=properties,
438443
optimisation_level=optimisation_level,
439444
credential_name=credential_name,
445+
user_group=user_group,
440446
hypertket_config=hypertket_config,
441447
)
442448

@@ -465,6 +471,7 @@ def execute( # pylint: disable=too-many-locals, too-many-positional-arguments
465471
language: Language = Language.AUTO,
466472
seed: int | None = None,
467473
credential_name: str | None = None,
474+
user_group: str | None = None,
468475
timeout: float | None = 300.0,
469476
) -> list[BackendResult]:
470477
"""
@@ -489,6 +496,7 @@ def execute( # pylint: disable=too-many-locals, too-many-positional-arguments
489496
language=language,
490497
seed=seed,
491498
credential_name=credential_name,
499+
user_group=user_group,
492500
)
493501

494502
wait_for(job=execute_job_ref, timeout=timeout)

qnexus/client/jobs/_compile.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@merge_properties_from_context
26-
def start_compile_job( # pylint: disable=too-many-arguments, too-many-positional-arguments
26+
def start_compile_job( # pylint: disable=too-many-arguments, too-many-locals, too-many-positional-arguments
2727
circuits: Union[CircuitRef, list[CircuitRef]],
2828
backend_config: BackendConfig,
2929
name: str,
@@ -32,6 +32,7 @@ def start_compile_job( # pylint: disable=too-many-arguments, too-many-positiona
3232
properties: PropertiesDict | None = None,
3333
optimisation_level: int = 2,
3434
credential_name: str | None = None,
35+
user_group: str | None = None,
3536
hypertket_config: HyperTketConfig | None = None,
3637
) -> CompileJobRef:
3738
"""Submit a compile job to be run in Nexus."""
@@ -55,11 +56,10 @@ def start_compile_job( # pylint: disable=too-many-arguments, too-many-positiona
5556
"definition": {
5657
"job_definition_type": "compile_job_definition",
5758
"backend_config": backend_config.model_dump(),
58-
"hypertket_config": (
59-
hypertket_config.model_dump()
60-
if hypertket_config is not None
61-
else None
62-
),
59+
"user_group": user_group,
60+
"hypertket_config": hypertket_config.model_dump()
61+
if hypertket_config is not None
62+
else None,
6363
"optimisation_level": optimisation_level,
6464
"credential_name": credential_name,
6565
"items": [
@@ -245,14 +245,18 @@ def _fetch_compilation_passes(
245245
pass_input_circuit_id = pass_info["relationships"]["original_circuit"]["data"][
246246
"id"
247247
]
248-
pass_input_circuit = circuit_api._fetch( # pylint: disable=protected-access
249-
pass_input_circuit_id
248+
pass_input_circuit = (
249+
circuit_api._fetch_by_id( # pylint: disable=protected-access
250+
pass_input_circuit_id
251+
)
250252
)
251253
pass_output_circuit_id = pass_info["relationships"]["compiled_circuit"]["data"][
252254
"id"
253255
]
254-
pass_output_circuit = circuit_api._fetch( # pylint: disable=protected-access
255-
pass_output_circuit_id
256+
pass_output_circuit = (
257+
circuit_api._fetch_by_id( # pylint: disable=protected-access
258+
pass_output_circuit_id
259+
)
256260
)
257261

258262
pass_list.append(

qnexus/client/jobs/_execute.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def start_execute_job( # pylint: disable=too-many-arguments, too-many-locals, t
4040
seed: int | None = None,
4141
credential_name: str | None = None,
4242
wasm_module: WasmModuleRef | None = None,
43+
user_group: str | None = None,
4344
) -> ExecuteJobRef:
4445
"""
4546
Submit an execute job to be run in Nexus. Returns an ``ExecuteJobRef``
@@ -69,6 +70,7 @@ def start_execute_job( # pylint: disable=too-many-arguments, too-many-locals, t
6970
"definition": {
7071
"job_definition_type": "execute_job_definition",
7172
"backend_config": backend_config.model_dump(),
73+
"user_group": user_group,
7274
"valid_check": valid_check,
7375
"postprocess": postprocess,
7476
"noisy_simulator": noisy_simulator,
@@ -164,7 +166,7 @@ def _fetch_execution_result(
164166

165167
input_circuit_id = res_dict["data"]["relationships"]["circuit"]["data"]["id"]
166168

167-
input_circuit = circuit_api._fetch( # pylint: disable=protected-access
169+
input_circuit = circuit_api._fetch_by_id( # pylint: disable=protected-access
168170
input_circuit_id
169171
)
170172

qnexus/client/projects.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def get(
116116
not match exactly one object.
117117
"""
118118
if id:
119-
return _fetch(id)
119+
return _fetch_by_id(id)
120120

121121
return get_all(
122122
name_like=name_like,
@@ -154,7 +154,7 @@ def get_or_create(
154154
)
155155

156156

157-
def _fetch(project_id: UUID | str) -> ProjectRef:
157+
def _fetch_by_id(project_id: UUID | str) -> ProjectRef:
158158
"""Utility method for fetching directly by a unique identifier."""
159159
res = get_nexus_client().get(f"/api/projects/v1beta/{project_id}")
160160

qnexus/client/results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# pass
1212

1313

14-
# def _fetch():
14+
# def _fetch_by_id():
1515
# pass
1616

1717

qnexus/client/roles.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,19 @@ def assignments(resource_ref: BaseRef) -> DataframableList[RoleInfo]:
6767
role_infos.append(
6868
RoleInfo(
6969
assignment_type="user",
70-
assignee=user_client._fetch( # pylint: disable=protected-access
70+
assignee=user_client._fetch_by_id( # pylint: disable=protected-access
7171
user_id=user_role_assignment["user_id"]
7272
),
7373
role=roles_dict[user_role_assignment["role_id"]],
7474
)
7575
)
76-
7776
for team_role_assignment in res_assignments["team_role_assignments"]:
7877
role_infos.append(
7978
RoleInfo(
8079
assignment_type="team",
81-
assignee=team_client.get(name=team_role_assignment["team_id"]),
80+
assignee=team_client._fetch_by_id( # pylint: disable=protected-access
81+
team_id=team_role_assignment["team_id"]
82+
),
8283
role=roles_dict[team_role_assignment["role_id"]],
8384
)
8485
)

qnexus/client/teams.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def get(name: str) -> TeamRef:
3333
"""
3434
res = get_nexus_client().get("/api/v5/user/teams", params={"name": name})
3535

36-
if res.status_code == 404:
36+
if res.status_code == 404 or res.json() == []:
3737
raise qnx_exc.ZeroMatches
3838

3939
if res.status_code != 200:
@@ -54,6 +54,27 @@ def get(name: str) -> TeamRef:
5454
return teams_list[0]
5555

5656

57+
def _fetch_by_id(team_id: str) -> TeamRef: # pylint: disable=redefined-builtin
58+
"""
59+
Get a single team by id.
60+
"""
61+
res = get_nexus_client().get(f"/api/v5/user/teams/{team_id}")
62+
63+
if res.status_code == 404:
64+
raise qnx_exc.ZeroMatches
65+
66+
if res.status_code != 200:
67+
raise qnx_exc.ResourceFetchFailed(message=res.text, status_code=res.status_code)
68+
69+
team_dict = res.json()
70+
71+
return TeamRef(
72+
id=team_dict["id"],
73+
name=team_dict["team_name"],
74+
description=team_dict["description"],
75+
)
76+
77+
5778
def create(name: str, description: str | None = None) -> TeamRef:
5879
"""Create a team in Nexus."""
5980

0 commit comments

Comments
 (0)