Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions azext_iot/iothub/providers/message_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def __init__(
rg: Optional[str] = None,
):
super(MessageEndpoint, self).__init__(cmd, hub_name, rg, dataplane=False)
self.support_cosmos = hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections")
# Temporary flag to check for which cosmos property to look for.
self.support_cosmos = 0
if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections"):
self.support_cosmos = 1
if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_containers"):
self.support_cosmos = 2
self.cli = EmbeddedCLI(cli_ctx=self.cmd.cli_ctx)

def create(
Expand Down Expand Up @@ -179,16 +184,22 @@ def create(
del new_endpoint["connectionString"]
new_endpoint.update({
"databaseName": database_name,
"collectionName": container_name,
"primaryKey": primary_key,
"secondaryKey": secondary_key,
"partitionKeyName": partition_key_name,
"partitionKeyTemplate": partition_key_template,
})
# TODO @vilit - why is this None if empty
if endpoints.cosmos_db_sql_collections is None:
endpoints.cosmos_db_sql_collections = []
endpoints.cosmos_db_sql_collections.append(new_endpoint)
# TODO @vilit - None checks for when the service breaks things
if self.support_cosmos == 2:
new_endpoint["containerName"] = container_name
if endpoints.cosmos_db_sql_containers is None:
endpoints.cosmos_db_sql_containers = []
endpoints.cosmos_db_sql_containers.append(new_endpoint)
if self.support_cosmos == 1:
new_endpoint["collectionName"] = container_name
if endpoints.cosmos_db_sql_collections is None:
endpoints.cosmos_db_sql_collections = []
endpoints.cosmos_db_sql_collections.append(new_endpoint)
elif endpoint_type.lower() == EndpointType.AzureStorageContainer.value:
if fetch_connection_string:
# try to get connection string
Expand Down Expand Up @@ -369,7 +380,9 @@ def _show_by_type(self, endpoint_name: str, endpoint_type: Optional[str] = None)
endpoint_list.extend(endpoints.service_bus_topics)
if endpoint_type is None or endpoint_type.lower() == EndpointType.AzureStorageContainer.value:
endpoint_list.extend(endpoints.storage_containers)
if self.support_cosmos and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value):
if self.support_cosmos == 2 and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value):
endpoint_list.extend(endpoints.cosmos_db_sql_containers)
if self.support_cosmos == 1 and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value):
endpoint_list.extend(endpoints.cosmos_db_sql_collections)

for endpoint in endpoint_list:
Expand Down Expand Up @@ -397,7 +410,9 @@ def list(self, endpoint_type: Optional[str] = None):
return endpoints.service_bus_queues
elif EndpointType.ServiceBusTopic.value == endpoint_type:
return endpoints.service_bus_topics
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos:
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 2:
return endpoints.cosmos_db_sql_containers
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 1:
return endpoints.cosmos_db_sql_collections
elif EndpointType.CosmosDBContainer.value == endpoint_type:
raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS)
Expand All @@ -413,7 +428,7 @@ def delete(
endpoints = self.hub_resource.properties.routing.endpoints
if endpoint_type:
endpoint_type = endpoint_type.lower()
if EndpointType.CosmosDBContainer.value == endpoint_type and not self.support_cosmos:
if EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 0:
raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS)

if self.hub_resource.properties.routing.enrichments or self.hub_resource.properties.routing.routes:
Expand All @@ -433,7 +448,9 @@ def delete(
endpoint_names.extend([e.name for e in endpoints.service_bus_queues])
if not endpoint_type or endpoint_type == EndpointType.ServiceBusTopic.value:
endpoint_names.extend([e.name for e in endpoints.service_bus_topics])
if self.support_cosmos and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value:
if self.support_cosmos == 2 and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value:
endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_containers])
if self.support_cosmos == 1 and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value:
endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections])
if not endpoint_type or endpoint_type == EndpointType.AzureStorageContainer.value:
endpoint_names.extend([e.name for e in endpoints.storage_containers])
Expand Down Expand Up @@ -481,7 +498,12 @@ def delete(
endpoints.service_bus_queues = [e for e in endpoints.service_bus_queues if e.name.lower() != endpoint_name]
if not endpoint_type or EndpointType.ServiceBusTopic.value == endpoint_type:
endpoints.service_bus_topics = [e for e in endpoints.service_bus_topics if e.name.lower() != endpoint_name]
if self.support_cosmos and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type:
if self.support_cosmos == 2 and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type:
cosmos_db_endpoints = endpoints.cosmos_db_sql_containers if endpoints.cosmos_db_sql_containers else []
endpoints.cosmos_db_sql_containers = [
e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name
]
if self.support_cosmos == 1 and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type:
cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else []
endpoints.cosmos_db_sql_collections = [
e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name
Expand All @@ -496,7 +518,9 @@ def delete(
endpoints.service_bus_queues = []
elif EndpointType.ServiceBusTopic.value == endpoint_type:
endpoints.service_bus_topics = []
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos:
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 2:
endpoints.cosmos_db_sql_containers = []
elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == 1:
endpoints.cosmos_db_sql_collections = []
elif EndpointType.AzureStorageContainer.value == endpoint_type:
endpoints.storage_containers = []
Expand All @@ -505,7 +529,9 @@ def delete(
endpoints.event_hubs = []
endpoints.service_bus_queues = []
endpoints.service_bus_topics = []
if self.support_cosmos:
if self.support_cosmos == 2:
endpoints.cosmos_db_sql_containers = []
if self.support_cosmos == 1:
endpoints.cosmos_db_sql_collections = []
endpoints.storage_containers = []

Expand Down
1 change: 1 addition & 0 deletions azext_iot/tests/iothub/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def _cosmos_db_provisioner():
collection_name = generate_hub_depenency_id()
partition_key_path = "/test"
location = "eastus"
print(f"--locations regionName={location}")
cosmos_obj = cli.invoke(
"cosmosdb create --name {} --resource-group {} --locations regionName={} failoverPriority=0".format(
account_name, RG, location
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ def create_mock_endpoint():
hub_mock.properties.routing.endpoints.service_bus_queues = [create_mock_endpoint()]
hub_mock.properties.routing.endpoints.service_bus_topics = [create_mock_endpoint()]
hub_mock.properties.routing.endpoints.storage_containers = [create_mock_endpoint()]
hub_mock.properties.routing.endpoints.cosmos_db_sql_containers = [create_mock_endpoint()]

def initialize_mock_client(self, *args):
self.client = mocker.MagicMock()
self.client.begin_create_or_update.return_value = generic_response
return hub_mock

find_resource.side_effect = initialize_mock_client

yield find_resource


@pytest.fixture()
def fixture_update_endpoint_backwards_comp_ops(mocker):
# Parse connection string
mocker.patch(parse_cosmos_db_cstring_path, return_value={
"AccountKey": "get_cosmos_db_account_key",
"AccountEndpoint": "get_cosmos_db_account_endpoint"
})

# Hub Resource
find_resource = mocker.patch(path_find_resource, autospec=True)

def create_mock_endpoint():
endpoint = mocker.Mock()
endpoint.name = endpoint_name
return endpoint

hub_mock = mocker.MagicMock()
del hub_mock.properties.routing.endpoints.cosmos_db_sql_containers
hub_mock.properties.routing.endpoints.cosmos_db_sql_collections = [create_mock_endpoint()]

def initialize_mock_client(self, *args):
Expand Down Expand Up @@ -721,7 +751,7 @@ def test_message_endpoint_update_cosmos_db_sql_container(self, mocker, fixture_c
assert req.get("resource_group_name") == resource_group
hub_resource = fixture_find_resource.call_args[0][0].client.begin_create_or_update.call_args[0][2]
# TODO: @vilit fix once service fixes their naming
endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_collections
endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_containers
assert len(endpoints) == 1
endpoint = endpoints[0]

Expand Down Expand Up @@ -800,7 +830,7 @@ def test_message_endpoint_update_cosmos_db_sql_container(self, mocker, fixture_c
else:
assert isinstance(endpoint.authentication_type, mock)

def test_message_endpoint_update_cosmos_db_sql_collections_error(self, fixture_cmd, fixture_update_endpoint_ops):
def test_message_endpoint_update_cosmos_db_sql_container_error(self, fixture_cmd, fixture_update_endpoint_ops):
# Cannot do both types of Authentication
with pytest.raises(MutuallyExclusiveArgumentError) as e:
subject.message_endpoint_update_cosmos_db_container(
Expand Down Expand Up @@ -848,3 +878,182 @@ def test_message_endpoint_update_cosmos_db_sql_collections_error(self, fixture_c
hub_name=hub_name,
endpoint_name=generate_names(),
)

@pytest.mark.parametrize(
"req",
[
{},
{
"endpoint_resource_group": generate_names(),
"endpoint_subscription_id": generate_names(),
"database_name": generate_names(),
"connection_string": generate_names(),
"primary_key": None,
"secondary_key": None,
"endpoint_uri": generate_names(),
"partition_key_name": None,
"partition_key_template": None,
"identity": None,
"resource_group_name": None,
},
{
"endpoint_resource_group": None,
"endpoint_subscription_id": None,
"database_name": None,
"connection_string": None,
"primary_key": None,
"secondary_key": None,
"endpoint_uri": generate_names(),
"partition_key_name": generate_names(),
"partition_key_template": generate_names(),
"identity": generate_names(),
"resource_group_name": generate_names(),
},
{
"endpoint_resource_group": None,
"endpoint_subscription_id": None,
"database_name": None,
"connection_string": None,
"primary_key": None,
"secondary_key": None,
"endpoint_uri": None,
"partition_key_name": None,
"partition_key_template": None,
"identity": "[system]",
"resource_group_name": None,
},
{
"endpoint_resource_group": None,
"endpoint_subscription_id": None,
"database_name": None,
"connection_string": generate_names(),
"primary_key": None,
"secondary_key": generate_names(),
"endpoint_uri": None,
"partition_key_name": None,
"partition_key_template": generate_names(),
"identity": None,
"resource_group_name": None,
},
{
"endpoint_resource_group": generate_names(),
"endpoint_subscription_id": None,
"database_name": None,
"connection_string": generate_names(),
"primary_key": generate_names(),
"secondary_key": generate_names(),
"endpoint_uri": None,
"partition_key_name": generate_names(),
"partition_key_template": None,
"identity": None,
"resource_group_name": None,
},
{
"endpoint_resource_group": None,
"endpoint_subscription_id": None,
"database_name": generate_names(),
"connection_string": None,
"primary_key": None,
"secondary_key": None,
"endpoint_uri": None,
"partition_key_name": None,
"partition_key_template": None,
"identity": None,
"resource_group_name": None,
},
]
)
def test_message_endpoint_update_cosmos_db_sql_collections(
self, mocker, fixture_cmd, fixture_update_endpoint_backwards_comp_ops, req
):
result = subject.message_endpoint_update_cosmos_db_container(
cmd=fixture_cmd,
hub_name=hub_name,
endpoint_name=endpoint_name,
**req
)
fixture_find_resource = fixture_update_endpoint_backwards_comp_ops

assert result == generic_response
resource_group = fixture_find_resource.call_args[0][2]
assert req.get("resource_group_name") == resource_group
hub_resource = fixture_find_resource.call_args[0][0].client.begin_create_or_update.call_args[0][2]
# TODO: @vilit fix once service fixes their naming
endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_collections
assert len(endpoints) == 1
endpoint = endpoints[0]

assert endpoint.name == endpoint_name
mock = mocker.Mock

# if a prop is not set, it will be a Mock object
# Props that will always be set if present
if req.get("endpoint_resource_group"):
assert endpoint.resource_group == req.get("endpoint_resource_group")
else:
assert isinstance(endpoint.resource_group, mock)

if req.get("endpoint_subscription_id"):
assert endpoint.subscription_id == req.get("endpoint_subscription_id")
else:
assert isinstance(endpoint.subscription_id, mock)

if req.get("database_name"):
assert endpoint.database_name == req.get("database_name").lower()
else:
assert isinstance(endpoint.database_name, mock)

if req.get("partition_key_name"):
partition_key_name = req.get("partition_key_name")
if partition_key_name == "":
assert endpoint.partition_key_name is None
else:
endpoint.partition_key_name == partition_key_name
else:
assert isinstance(endpoint.partition_key_name, mock)

if req.get("partition_key_template"):
partition_key_template = req.get("partition_key_template")
if partition_key_template == "":
assert endpoint.partition_key_template is None
else:
endpoint.partition_key_template == partition_key_template
else:
assert isinstance(endpoint.partition_key_template, mock)

# Connection strings dont exist
assert isinstance(endpoint.connection_string, mock)

# Authentication props
if req.get("identity"):
assert endpoint.authentication_type == AuthenticationType.IdentityBased.value
assert endpoint.primary_key is None
assert endpoint.secondary_key is None
identity = req.get("identity")
if identity == "[system]":
assert endpoint.identity is None
else:
assert isinstance(endpoint.identity, ManagedIdentity)
assert endpoint.identity.user_assigned_identity == identity
elif any([req.get("connection_string"), req.get("primary_key"), req.get("secondary_key")]):
assert endpoint.authentication_type == AuthenticationType.KeyBased.value
assert endpoint.identity is None
assert endpoint.entity_path is None
connection_string = req.get("connection_string")
primary_key = req.get("primary_key")
secondary_key = req.get("secondary_key")
endpoint_uri = req.get("endpoint_uri")

if primary_key:
assert endpoint.primary_key == primary_key
if secondary_key:
assert endpoint.secondary_key == secondary_key
if not primary_key and not secondary_key and connection_string:
assert endpoint.primary_key == endpoint.secondary_key == "get_cosmos_db_account_key"

if endpoint_uri:
assert endpoint.endpoint_uri == endpoint_uri
elif connection_string:
assert endpoint.endpoint_uri == "get_cosmos_db_account_endpoint"
else:
assert isinstance(endpoint.authentication_type, mock)