diff --git a/src/_nebari/provider/cloud/google_cloud.py b/src/_nebari/provider/cloud/google_cloud.py index 81f86d8b13..7c87fed70e 100644 --- a/src/_nebari/provider/cloud/google_cloud.py +++ b/src/_nebari/provider/cloud/google_cloud.py @@ -51,6 +51,22 @@ def regions() -> Set[str]: return {region.name for region in response} +@functools.lru_cache() +def instances(region: str) -> Set[str]: + """Return a set of available compute instances in a region.""" + credentials, project_id = load_credentials() + zones_client = compute_v1.services.region_zones.RegionZonesClient( + credentials=credentials + ) + instances_client = compute_v1.InstancesClient(credentials=credentials) + + return { + instance.machine_type.split("/")[-1] + for zone in zones_client.list(project=project_id, region=region) + for instance in instances_client.list(project=project_id, zone=zone.name) + } + + @functools.lru_cache() def kubernetes_versions(region: str) -> List[str]: """Return list of available kubernetes supported by cloud provider. Sorted from oldest to latest.""" diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 7b4c1aa237..243abd1608 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -325,6 +325,21 @@ def _check_input(cls, data: Any) -> Any: raise ValueError( f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." ) + + # check if instances are valid + available_instances = google_cloud.instances(data["region"]) + if "node_groups" in data: + for _, node_group in data["node_groups"].items(): + instance = ( + node_group["instance"] + if hasattr(node_group, "__getitem__") + else node_group.instance + ) + if instance not in available_instances: + raise ValueError( + f"Google Cloud Platform instance {instance} not one of available instance types={available_instances}" + ) + return data diff --git a/tests/tests_unit/conftest.py b/tests/tests_unit/conftest.py index 19ab7702a5..54528cbd23 100644 --- a/tests/tests_unit/conftest.py +++ b/tests/tests_unit/conftest.py @@ -85,6 +85,11 @@ def _mock_return_value(return_value): "us-central1", "us-east1", ], + "_nebari.provider.cloud.google_cloud.instances": [ + "e2-standard-4", + "e2-standard-8", + "e2-highmem-4", + ], } for attribute_path, return_value in MOCK_VALUES.items():