Skip to content

Commit bcad32c

Browse files
committed
fix: align datacenter enum with valid real-world targets
1 parent 838018d commit bcad32c

12 files changed

Lines changed: 534 additions & 1194 deletions

File tree

src/runpod_flash/core/resources/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
)
1919
from .serverless_cpu import CpuServerlessEndpoint
2020
from .template import PodTemplate
21-
from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS
21+
from .network_volume import NetworkVolume
22+
from .datacenter import DataCenter, CPU_DATACENTERS
2223
from .load_balancer_sls_resource import (
2324
CpuLoadBalancerSlsResource,
2425
LoadBalancerSlsResource,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from enum import Enum
2+
3+
4+
class DataCenter(str, Enum):
5+
"""Enum representing available RunPod data centers.
6+
7+
NOTE: these are only datacenters with storage support, and s3 API support.
8+
see https://linear.app/runpod/issue/AE-3084/prevent-no-primary-mount-point-error-when-workers-are-initializing
9+
"""
10+
11+
# north america
12+
US_CA_2 = "US-CA-2"
13+
US_IL_1 = "US-IL-1"
14+
US_KS_2 = "US-KS-2"
15+
US_MO_1 = "US-MO-1"
16+
US_MO_2 = "US-MO-2"
17+
US_NC_2 = "US-NC-2"
18+
US_NE_1 = "US-NE-1"
19+
US_WA_1 = "US-WA-1"
20+
21+
# europe
22+
EU_CZ_1 = "EU-CZ-1"
23+
EU_RO_1 = "EU-RO-1"
24+
EUR_NO_1 = "EUR-NO-1"
25+
26+
@classmethod
27+
def from_string(cls, value: str) -> "DataCenter":
28+
"""Parse a datacenter ID string into a DataCenter enum.
29+
30+
Accepts the canonical form (e.g. "EU-RO-1") as well as common
31+
variations like lowercase or underscore-separated.
32+
"""
33+
normalized = value.strip().upper().replace("_", "-")
34+
try:
35+
return cls(normalized)
36+
except ValueError:
37+
valid = ", ".join(dc.value for dc in cls)
38+
raise ValueError(
39+
f"Unknown datacenter '{value}'. Valid datacenters: {valid}"
40+
)
41+
42+
@classmethod
43+
def all(cls) -> list["DataCenter"]:
44+
"""Return all datacenters."""
45+
return list(cls)
46+
47+
48+
# data centers that support CPU serverless endpoints
49+
CPU_DATACENTERS: frozenset[DataCenter] = frozenset(
50+
{
51+
DataCenter.EU_RO_1,
52+
}
53+
)

src/runpod_flash/core/resources/network_volume.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import hashlib
22
import logging
3-
from enum import Enum
43
from typing import Optional, Dict, Any
54

65
from pydantic import (
@@ -10,6 +9,7 @@
109
field_serializer,
1110
model_validator,
1211
)
12+
from .datacenter import DataCenter
1313

1414
from ..api.runpod import RunpodRestClient
1515
from ..urls import RUNPOD_CONSOLE_URL
@@ -19,53 +19,6 @@
1919
log = logging.getLogger(__name__)
2020

2121

22-
class DataCenter(str, Enum):
23-
"""Enum representing available RunPod data centers."""
24-
25-
# north america
26-
US_CA_2 = "US-CA-2"
27-
US_GA_2 = "US-GA-2"
28-
US_IL_1 = "US-IL-1"
29-
US_KS_2 = "US-KS-2"
30-
US_MD_1 = "US-MD-1"
31-
US_MO_1 = "US-MO-1"
32-
US_MO_2 = "US-MO-2"
33-
US_NC_1 = "US-NC-1"
34-
US_NC_2 = "US-NC-2"
35-
US_NE_1 = "US-NE-1"
36-
US_WA_1 = "US-WA-1"
37-
38-
# europe
39-
EU_CZ_1 = "EU-CZ-1"
40-
EU_RO_1 = "EU-RO-1"
41-
EUR_IS_1 = "EUR-IS-1"
42-
EUR_NO_1 = "EUR-NO-1"
43-
44-
@classmethod
45-
def from_string(cls, value: str) -> "DataCenter":
46-
"""Parse a datacenter ID string into a DataCenter enum.
47-
48-
Accepts the canonical form (e.g. "EU-RO-1") as well as common
49-
variations like lowercase or underscore-separated.
50-
"""
51-
normalized = value.strip().upper().replace("_", "-")
52-
try:
53-
return cls(normalized)
54-
except ValueError:
55-
valid = ", ".join(dc.value for dc in cls)
56-
raise ValueError(
57-
f"Unknown datacenter '{value}'. Valid datacenters: {valid}"
58-
)
59-
60-
61-
# data centers that support CPU serverless endpoints
62-
CPU_DATACENTERS: frozenset[DataCenter] = frozenset(
63-
{
64-
DataCenter.EU_RO_1,
65-
}
66-
)
67-
68-
6922
class NetworkVolume(DeployableResource):
7023
"""
7124
NetworkVolume resource for creating and managing Runpod network volumes.

src/runpod_flash/core/resources/serverless.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
)
3333
from .cpu import CpuInstanceType
3434
from .gpu import GpuGroup, GpuType
35-
from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS
35+
from .network_volume import NetworkVolume
36+
from .datacenter import DataCenter, CPU_DATACENTERS
3637
from .request_logs import QBRequestLogBatch, QBRequestLogFetcher, QBRequestLogPhase
3738
from .worker_availability_diagnostic import WorkerAvailabilityDiagnostic
3839
from .template import KeyValuePair, PodTemplate

src/runpod_flash/endpoint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from .core.resources.constants import DEFAULT_WORKERS_MAX, DEFAULT_WORKERS_MIN
88
from .core.resources.cpu import CpuInstanceType
99
from .core.resources.gpu import GpuGroup, GpuType
10-
from .core.resources.network_volume import DataCenter, NetworkVolume
10+
from .core.resources.network_volume import NetworkVolume
11+
from .core.resources.datacenter import DataCenter
1112
from .core.resources.serverless import CudaVersion, ServerlessScalerType
1213
from .core.resources.template import PodTemplate
1314

@@ -413,6 +414,11 @@ def __init__(
413414
if not self._is_cpu and self._gpu is None and not self.is_client:
414415
self._gpu = [GpuGroup.ANY]
415416

417+
# if not in pure client mode, make sure default datacenters are set
418+
# not CPU though, that gets pinned to specific datacenters
419+
if not self._is_cpu and not self.is_client and not self.datacenter:
420+
self.datacenter = DataCenter.all()
421+
416422
# lb routes registered via .get()/.post()/etc (decorator mode only)
417423
self._routes: List[Dict[str, Any]] = []
418424

tests/unit/cli/commands/build_utils/test_manifest.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,7 @@ def test_extract_deployment_config_includes_network_volume():
513513

514514
resource_py = project_dir / "resource.py"
515515
resource_py.write_text(
516-
"from runpod_flash import NetworkVolume\n"
517-
"from runpod_flash.core.resources.network_volume import DataCenter\n"
516+
"from runpod_flash import NetworkVolume, DataCenter\n"
518517
"\n"
519518
"class gpu_config:\n"
520519
' imageName = "test-image"\n'
@@ -631,8 +630,7 @@ def test_extract_deployment_config_includes_network_volumes():
631630

632631
resource_py = project_dir / "resource.py"
633632
resource_py.write_text(
634-
"from runpod_flash import NetworkVolume\n"
635-
"from runpod_flash.core.resources.network_volume import DataCenter\n"
633+
"from runpod_flash import NetworkVolume, DataCenter\n"
636634
"\n"
637635
"class gpu_config:\n"
638636
' imageName = "test-image"\n'
@@ -645,7 +643,7 @@ def test_extract_deployment_config_includes_network_volumes():
645643
" NetworkVolume(\n"
646644
' name="vol-us",\n'
647645
" size=200,\n"
648-
" dataCenterId=DataCenter.US_GA_2,\n"
646+
" dataCenterId=DataCenter.US_CA_2,\n"
649647
" ),\n"
650648
" ]\n"
651649
)
@@ -675,7 +673,7 @@ def test_extract_deployment_config_includes_network_volumes():
675673
assert config["networkVolumes"][0]["dataCenterId"] == "EU-RO-1"
676674
assert config["networkVolumes"][1]["name"] == "vol-us"
677675
assert config["networkVolumes"][1]["size"] == 200
678-
assert config["networkVolumes"][1]["dataCenterId"] == "US-GA-2"
676+
assert config["networkVolumes"][1]["dataCenterId"] == "US-CA-2"
679677
assert "networkVolume" not in config
680678

681679

tests/unit/resources/test_network_volume.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import pytest
77
from pydantic import ValidationError
88

9-
from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter
9+
from runpod_flash.core.resources.network_volume import NetworkVolume
10+
from runpod_flash.core.resources.datacenter import DataCenter
1011

1112

1213
class TestNetworkVolumeIdempotent:

tests/unit/resources/test_serverless.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from runpod_flash.core.resources.serverless_cpu import CpuServerlessEndpoint
2323
from runpod_flash.core.resources.gpu import GpuGroup
2424
from runpod_flash.core.resources.cpu import CpuInstanceType
25-
from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter
25+
from runpod_flash.core.resources.network_volume import NetworkVolume
26+
from runpod_flash.core.resources.datacenter import DataCenter
2627
from runpod_flash.core.resources.request_logs import (
2728
QBRequestLogBatch,
2829
QBRequestLogPhase,
@@ -220,12 +221,12 @@ class TestMultiVolumeDeployPath:
220221
@pytest.mark.asyncio
221222
async def test_multi_volume_deploys_all_and_collects_ids(self):
222223
vol_a = NetworkVolume(name="vol-a", size=50, dataCenterId=DataCenter.EU_RO_1)
223-
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_GA_2)
224+
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_CA_2)
224225

225226
serverless = ServerlessResource(
226227
name="test",
227228
networkVolumes=[vol_a, vol_b],
228-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
229+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
229230
)
230231

231232
async def fake_deploy(self_vol):
@@ -242,12 +243,12 @@ async def fake_deploy(self_vol):
242243
async def test_multi_volume_skips_already_created(self):
243244
vol_a = NetworkVolume(name="vol-a", size=50, dataCenterId=DataCenter.EU_RO_1)
244245
vol_a.id = "vol-aaa"
245-
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_GA_2)
246+
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_CA_2)
246247

247248
serverless = ServerlessResource(
248249
name="test",
249250
networkVolumes=[vol_a, vol_b],
250-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
251+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
251252
)
252253

253254
deploy_calls = []
@@ -330,7 +331,7 @@ def test_single_volume_payload_uses_singular_field(self):
330331
def test_multi_volume_drift_detection(self):
331332
"""Changing networkVolumes changes the config hash."""
332333
vol_a = NetworkVolume(name="vol-a", size=50, dataCenterId=DataCenter.EU_RO_1)
333-
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_GA_2)
334+
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_CA_2)
334335

335336
s1 = ServerlessResource(
336337
name="test",
@@ -340,7 +341,7 @@ def test_multi_volume_drift_detection(self):
340341
s2 = ServerlessResource(
341342
name="test",
342343
networkVolumes=[vol_a, vol_b],
343-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
344+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
344345
)
345346

346347
assert s1.config_hash != s2.config_hash
@@ -470,9 +471,9 @@ def test_datacenter_multiple_values(self):
470471
"""Test datacenter accepts a list of DataCenter values."""
471472
serverless = ServerlessResource(
472473
name="test",
473-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
474+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
474475
)
475-
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_2]
476+
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_CA_2]
476477

477478
def test_datacenter_string_value(self):
478479
"""Test datacenter accepts string values."""
@@ -481,8 +482,8 @@ def test_datacenter_string_value(self):
481482

482483
def test_datacenter_string_list(self):
483484
"""Test datacenter accepts list of strings."""
484-
serverless = ServerlessResource(name="test", datacenter=["EU-RO-1", "US-GA-2"])
485-
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_2]
485+
serverless = ServerlessResource(name="test", datacenter=["EU-RO-1", "US-CA-2"])
486+
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_CA_2]
486487

487488
def test_datacenter_invalid_string_raises(self):
488489
"""Test that an invalid datacenter string raises ValueError."""
@@ -498,9 +499,9 @@ def test_locations_synced_from_multi_datacenter(self):
498499
"""Test locations field gets synced from multiple datacenters."""
499500
serverless = ServerlessResource(
500501
name="test",
501-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
502+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
502503
)
503-
assert serverless.locations == "EU-RO-1,US-GA-2"
504+
assert serverless.locations == "EU-RO-1,US-CA-2"
504505

505506
def test_no_datacenter_no_locations(self):
506507
"""Test that no datacenter means no locations restriction."""
@@ -509,9 +510,9 @@ def test_no_datacenter_no_locations(self):
509510

510511
def test_explicit_locations_not_overridden(self):
511512
"""Test explicit locations field is not overridden."""
512-
serverless = ServerlessResource(name="test", locations="US-GA-2")
513+
serverless = ServerlessResource(name="test", locations="US-CA-2")
513514

514-
assert serverless.locations == "US-GA-2"
515+
assert serverless.locations == "US-CA-2"
515516

516517
def test_datacenter_validation_matching_datacenters(self):
517518
"""Test that matching datacenters between endpoint and volume work."""
@@ -525,7 +526,7 @@ def test_datacenter_validation_matching_datacenters(self):
525526

526527
def test_datacenter_validation_volume_not_in_dc_list(self):
527528
"""Test that a volume DC not in endpoint's DC list raises an error."""
528-
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_GA_2)
529+
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_CA_2)
529530
with pytest.raises(
530531
ValueError,
531532
match="Network volume datacenter.*is not in the endpoint's datacenter list",
@@ -536,9 +537,9 @@ def test_datacenter_validation_volume_not_in_dc_list(self):
536537

537538
def test_volume_dc_allowed_when_no_datacenter_set(self):
538539
"""Test that any volume DC is allowed when no datacenter restriction is set."""
539-
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_GA_2)
540+
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_CA_2)
540541
serverless = ServerlessResource(name="test", networkVolume=volume)
541-
assert serverless.networkVolume.dataCenterId == DataCenter.US_GA_2
542+
assert serverless.networkVolume.dataCenterId == DataCenter.US_CA_2
542543

543544
def test_no_flashboot_keeps_name(self):
544545
"""Test flashboot=False keeps original name."""
@@ -629,7 +630,7 @@ def test_single_volume_compat(self):
629630
def test_multiple_volumes_via_list(self):
630631
"""Test networkVolumes accepts multiple volumes."""
631632
v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1)
632-
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_GA_2)
633+
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_CA_2)
633634
s = ServerlessResource(name="test", networkVolumes=[v1, v2])
634635
assert len(s.networkVolumes) == 2
635636
assert s.networkVolume is v1
@@ -643,7 +644,7 @@ def test_duplicate_dc_raises(self):
643644

644645
def test_volumes_dc_outside_endpoint_dc_raises(self):
645646
"""Test volume DC not in endpoint's DC list raises."""
646-
vol = NetworkVolume(name="v1", dataCenterId=DataCenter.US_GA_2)
647+
vol = NetworkVolume(name="v1", dataCenterId=DataCenter.US_CA_2)
647648
with pytest.raises(
648649
ValueError,
649650
match="is not in the endpoint's datacenter list",
@@ -657,10 +658,10 @@ def test_volumes_dc_outside_endpoint_dc_raises(self):
657658
def test_volumes_dc_within_endpoint_dc_list(self):
658659
"""Test volume DCs all within endpoint DC list works."""
659660
v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1)
660-
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_GA_2)
661+
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_CA_2)
661662
s = ServerlessResource(
662663
name="test",
663-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
664+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
664665
networkVolumes=[v1, v2],
665666
)
666667
assert len(s.networkVolumes) == 2
@@ -684,7 +685,7 @@ def test_cpu_endpoint_in_unsupported_dc_raises(self):
684685
CpuServerlessEndpoint(
685686
name="test-cpu",
686687
imageName="test/cpu:latest",
687-
datacenter=DataCenter.US_GA_2,
688+
datacenter=DataCenter.US_CA_2,
688689
)
689690

690691
def test_cpu_endpoint_mixed_dcs_raises(self):
@@ -693,7 +694,7 @@ def test_cpu_endpoint_mixed_dcs_raises(self):
693694
CpuServerlessEndpoint(
694695
name="test-cpu",
695696
imageName="test/cpu:latest",
696-
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
697+
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
697698
)
698699

699700
def test_cpu_endpoint_no_datacenter_ok(self):
@@ -708,9 +709,9 @@ def test_gpu_endpoint_any_dc_ok(self):
708709
"""Test GPU endpoint in any datacenter is allowed."""
709710
serverless = ServerlessResource(
710711
name="test-gpu",
711-
datacenter=DataCenter.US_GA_2,
712+
datacenter=DataCenter.US_CA_2,
712713
)
713-
assert serverless.datacenter == [DataCenter.US_GA_2]
714+
assert serverless.datacenter == [DataCenter.US_CA_2]
714715

715716

716717
class TestMinCudaVersion:

0 commit comments

Comments
 (0)