-
Notifications
You must be signed in to change notification settings - Fork 275
Expand file tree
/
Copy pathtest_create_kubernetes_endpoint.py
More file actions
94 lines (78 loc) · 3.41 KB
/
test_create_kubernetes_endpoint.py
File metadata and controls
94 lines (78 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Tests for create_kubernetes_endpoint."""
from unittest.mock import Mock, patch
import pytest
from pathlib import Path
from llmops.common.deployment.kubernetes_endpoint import (
create_kubernetes_endpoint,
)
# Directory that contains this test module.
THIS_PATH = Path(__file__).parent
# Sibling "resources" directory; its string form is passed to
# create_kubernetes_endpoint by the tests in this module.
RESOURCE_PATH = THIS_PATH.joinpath("resources")
@pytest.fixture(scope="module", autouse=True)
def _set_required_env_vars():
    """Set required environment variables for the tests in this module.

    The variables are set once before the first test of the module runs
    and are restored to their previous state when the module's tests
    have finished, so they do not leak into other test modules that run
    in the same session.
    """
    # The built-in ``monkeypatch`` fixture is function-scoped and cannot
    # be requested from a module-scoped fixture, so a MonkeyPatch object
    # is managed explicitly. The context manager calls ``undo()`` on
    # exit, i.e. during fixture teardown after the ``yield``.
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setenv("SUBSCRIPTION_ID", "TEST_SUBSCRIPTION_ID")
        monkeypatch.setenv("RESOURCE_GROUP_NAME", "TEST_RESOURCE_GROUP_NAME")
        monkeypatch.setenv("WORKSPACE_NAME", "TEST_WORKSPACE_NAME")
        yield
def test_create_kubernetes_endpoint_when_not_exists():
    """Check that an endpoint is created when no endpoints are listed."""
    environment = "dev"
    expected_name = "k8s-test-endpoint"
    expected_description = "k8s-test-endpoint-description"
    expected_compute = "k8s-compute"

    patch_target = "llmops.common.deployment.kubernetes_endpoint.MLClient"
    with patch(patch_target) as ml_client_cls:
        # Calling the patched MLClient returns this stand-in, whose
        # endpoint listing is empty.
        fake_client = Mock()
        fake_client.online_endpoints.list.return_value = []
        ml_client_cls.return_value = fake_client

        # Exercise the function under test.
        create_kubernetes_endpoint(environment, str(RESOURCE_PATH))

        # Exactly one create-or-update request must have been issued.
        begin_create = fake_client.online_endpoints.begin_create_or_update
        assert begin_create.call_count == 1

        # Each recorded call unpacks to (positional args, keyword args);
        # the endpoint is read from the "endpoint" keyword argument of
        # the first (and only) call.
        _, call_kwargs = begin_create.call_args_list[0]
        endpoint = call_kwargs["endpoint"]

        assert endpoint.name == expected_name
        assert endpoint.description == expected_description
        assert endpoint.compute == expected_compute
        assert endpoint.auth_mode == "key"
def test_create_kubernetes_endpoint_when_exists():
    """Check that no endpoint is created when one is already listed."""
    environment = "dev"
    existing_name = "k8s-test-endpoint"

    patch_target = "llmops.common.deployment.kubernetes_endpoint.MLClient"
    with patch(patch_target) as ml_client_cls:
        # Calling the patched MLClient returns this stand-in, whose
        # endpoint listing holds a single endpoint named existing_name.
        existing_endpoint = Mock()
        existing_endpoint.name = existing_name
        fake_client = Mock()
        fake_client.online_endpoints.list.return_value = [existing_endpoint]
        ml_client_cls.return_value = fake_client

        # Exercise the function under test.
        create_kubernetes_endpoint(environment, str(RESOURCE_PATH))

        # Endpoint should not be created if it already exists as it would
        # set the traffic to zero for existing deployments and there are
        # no properties we need to update on the endpoint, so
        # begin_create_or_update must not have been called at all.
        begin_create = fake_client.online_endpoints.begin_create_or_update
        assert begin_create.call_count == 0