Skip to content

Commit 72401c0

Browse files
committed
RHOAIENG-8098 - ClusterConfiguration should support tolerations
1 parent 6b0a3cc commit 72401c0

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
This sub-module exists primarily to be used internally by the Cluster object
1717
(in the cluster sub-module) for RayCluster/AppWrapper generation.
1818
"""
19-
from typing import Union, Tuple, Dict
19+
from typing import List, Union, Tuple, Dict
2020
from ...common import _kube_api_error_handling
2121
from ...common.kubernetes_cluster import get_api_client, config_check
2222
from kubernetes.client.exceptions import ApiException
@@ -40,6 +40,7 @@
4040
V1PodTemplateSpec,
4141
V1PodSpec,
4242
V1LocalObjectReference,
43+
V1Toleration
4344
)
4445

4546
import yaml
@@ -139,7 +140,8 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
139140
"resources": head_resources,
140141
},
141142
"template": {
142-
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
143+
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)],
144+
cluster.config.head_tolerations)
143145
},
144146
},
145147
"workerGroupSpecs": [
@@ -154,7 +156,8 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
154156
"resources": worker_resources,
155157
},
156158
"template": V1PodTemplateSpec(
157-
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
159+
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)],
160+
cluster.config.tolerations)
158161
),
159162
}
160163
],
@@ -243,13 +246,14 @@ def update_image(image) -> str:
243246
return image
244247

245248

246-
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
249+
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers, tolerations):
247250
"""
248251
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
249252
"""
250253
pod_spec = V1PodSpec(
251254
containers=containers,
252255
volumes=VOLUMES,
256+
tolerations=tolerations
253257
)
254258
if cluster.config.image_pull_secrets != []:
255259
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)

src/codeflare_sdk/ray/cluster/config.py

+7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Toleration
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -57,6 +58,8 @@ class ClusterConfiguration:
5758
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
5859
head_extended_resource_requests:
5960
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61+
head_tolerations:
62+
A list of Kubernetes tolerations (V1Toleration objects) to apply to the head node.
6063
min_cpus:
6164
The minimum number of CPUs to allocate to each worker.
6265
max_cpus:
@@ -69,6 +72,8 @@ class ClusterConfiguration:
6972
The maximum amount of memory to allocate to each worker.
7073
num_gpus:
7174
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75+
tolerations:
76+
A list of Kubernetes tolerations (V1Toleration objects) to apply to each worker.
7277
appwrapper:
7378
A boolean indicating whether to use an AppWrapper.
7479
envs:
@@ -105,6 +110,7 @@ class ClusterConfiguration:
105110
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
106111
default_factory=dict
107112
)
113+
head_tolerations: Optional[List[V1Toleration]]
108114
worker_cpu_requests: Union[int, str] = 1
109115
worker_cpu_limits: Union[int, str] = 1
110116
min_cpus: Optional[Union[int, str]] = None # Deprecating
@@ -115,6 +121,7 @@ class ClusterConfiguration:
115121
min_memory: Optional[Union[int, str]] = None # Deprecating
116122
max_memory: Optional[Union[int, str]] = None # Deprecating
117123
num_gpus: Optional[int] = None # Deprecating
124+
tolerations: Optional[List[V1Toleration]]
118125
appwrapper: bool = False
119126
envs: Dict[str, str] = field(default_factory=dict)
120127
image: str = ""

0 commit comments

Comments
 (0)