Skip to content

Commit 80cd460

Browse files
authored
[ML][Pipelines] Singularity: bypass operations to Singularity compute (Azure#28576)
* test: add test for Singularity compute in pipeline job
* skip the get-asset-ID/compute operations when the compute is Singularity
* test: add recording for newly added test
* test: refine test
1 parent 5225120 commit 80cd460

File tree

7 files changed

+902
-1
lines changed

7 files changed

+902
-1
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_arm_id_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
PROVIDER_RESOURCE_ID_WITH_VERSION,
2020
REGISTRY_URI_REGEX_FORMAT,
2121
REGISTRY_VERSION_PATTERN,
22+
SINGULARITY_ID_FORMAT,
2223
)
2324
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
2425

@@ -325,6 +326,12 @@ def is_registry_id_for_resource(name: Any) -> bool:
325326
return False
326327

327328

329+
def is_singularity_id_for_resource(name: Any) -> bool:
    """Report whether ``name`` looks like a Singularity compute resource ID.

    :param name: Candidate value; anything that is not a string is rejected.
    :return: True when ``name`` is a string that ``SINGULARITY_ID_FORMAT``
        matches from its start (case-insensitively), otherwise False.
    """
    if not isinstance(name, str):
        return False
    return re.match(SINGULARITY_ID_FORMAT, name, re.IGNORECASE) is not None
333+
334+
328335
def get_arm_id_with_version(
329336
operation_scope: OperationScope,
330337
provider_name: str,

sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
PROVIDER_RESOURCE_ID_WITH_VERSION = (
3636
"/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/{}/{}/versions/{}"
3737
)
38+
# Regular expression (unlike the neighbouring ``str.format`` templates) for the ARM ID of a
# Singularity virtual cluster. It is applied with ``re.match`` and ``re.IGNORECASE``.
# The dots in the provider namespace are escaped so they match only a literal ".",
# not an arbitrary character; the ``.*`` wildcards are intentionally left permissive.
SINGULARITY_ID_FORMAT = (
    r"/subscriptions/.*/resourceGroups/.*/providers/Microsoft\.MachineLearningServices/virtualclusters/.*"
)
3841
ASSET_ID_FORMAT = "azureml://locations/{}/workspaces/{}/{}/{}/versions/{}"
3942
VERSIONED_RESOURCE_NAME = "{}:{}"
4043
LABELLED_RESOURCE_NAME = "{}@{}"

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,12 @@
105105
from ._job_ops_helper import get_git_properties, get_job_output_uris_from_dataplane, stream_logs_until_completion
106106
from ._local_job_invoker import is_local_run, start_run_if_local
107107
from ._model_dataplane_operations import ModelDataplaneOperations
108-
from ._operation_orchestrator import OperationOrchestrator, is_ARM_id_for_resource, is_registry_id_for_resource
108+
from ._operation_orchestrator import (
109+
OperationOrchestrator,
110+
is_ARM_id_for_resource,
111+
is_registry_id_for_resource,
112+
is_singularity_id_for_resource,
113+
)
109114
from ._run_operations import RunOperations
110115

111116
try:
@@ -331,6 +336,9 @@ def _try_get_compute_arm_id(self, compute: Union[Compute, str]):
331336
return compute
332337

333338
if compute is not None:
339+
if is_singularity_id_for_resource(compute):
340+
# Singularity compute: return the ID unchanged and skip the compute lookup
341+
return compute
334342
if is_ARM_id_for_resource(compute, resource_type=AzureMLResourceType.COMPUTE):
335343
# compute is not a sub-workspace resource
336344
compute_name = compute.split("/")[-1]

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_operation_orchestrator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
get_arm_id_with_version,
1919
is_ARM_id_for_resource,
2020
is_registry_id_for_resource,
21+
is_singularity_id_for_resource,
2122
parse_name_label,
2223
parse_prefixed_name_version,
2324
)
@@ -124,6 +125,7 @@ def get_asset_arm_id(
124125
asset is None
125126
or is_ARM_id_for_resource(asset, azureml_type, sub_workspace_resource)
126127
or is_registry_id_for_resource(asset)
128+
or is_singularity_id_for_resource(asset)
127129
):
128130
return asset
129131
if isinstance(asset, str):

sdk/ml/azure-ai-ml/tests/pipeline_job/e2etests/test_pipeline_job.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,26 @@ def test_pipeline_node_with_default_component(self, client: MLClient, randstr: C
13441344
== "microsoftsamples_command_component_basic@default"
13451345
)
13461346

1347+
def test_pipeline_job_with_singularity_compute(self, client: MLClient, randstr: Callable[[str], str]):
    """Check that a Singularity virtual cluster ID survives validation and job creation.

    Loads the Singularity pipeline YAML, sets the pipeline default compute and the
    ``hello_job`` node compute to the same virtual cluster ID, validates locally,
    then runs the job through ``assert_job_cancel`` and checks the ID is unchanged.
    """
    job: PipelineJob = load_job(
        "./tests/test_configs/pipeline_jobs/helloworld_pipeline_job_with_singularity_compute.yml",
        params_override=[{"name": randstr("job_name")}],
    )

    # Build the virtual cluster ID inside the test client's subscription and resource group.
    vc_id = (
        f"/subscriptions/{client.subscription_id}/resourceGroups/{client.resource_group_name}/"
        f"providers/Microsoft.MachineLearningServices/virtualclusters/SingularityTestVC"
    )
    job.settings.default_compute = vc_id
    job.jobs["hello_job"].compute = vc_id

    validation_result = job._customized_validate()
    assert validation_result.passed is True

    created: PipelineJob = assert_job_cancel(job, client)
    assert created.settings.default_compute == vc_id
    assert created.jobs["hello_job"].compute == vc_id
1366+
13471367

13481368
@pytest.mark.usefixtures("enable_pipeline_private_preview_features")
13491369
@pytest.mark.e2etest

0 commit comments

Comments
 (0)