import boto3
import sagemaker
- from pathlib import Path
- import fire
import re
+ import yaml
+ import argparse
+ import json
+ import time
+
from datetime import datetime
+ from pathlib import Path
+ from tabrepo import EvaluationRepository


DOCKER_IMAGE_ALIASES = {
-     "mlflow-image": "097403188315.dkr.ecr.us-west-2.amazonaws.com/pmdesai:mlflow-tabrepo",
+     "mlflow-image": "ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/pmdesai:mlflow-tabrepo",
}
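+ # NOTE: "ACCOUNT_ID" is a placeholder; substitute the AWS account ID that hosts the ECR image.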
+ class TrainingJobResourceManager:
+     def __init__(self, sagemaker_client, max_concurrent_jobs):
+         self.sagemaker_client = sagemaker_client
+         self.max_concurrent_jobs = max_concurrent_jobs
+
+     def get_running_jobs_count(self):
+         response = self.sagemaker_client.list_training_jobs(StatusEquals='InProgress', MaxResults=100)
+         return len(response['TrainingJobSummaries'])
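+         # NOTE (assumption): list_training_jobs returns at most 100 summaries per page and is
+         # not paginated here, so this count is only reliable while the account stays under
+         # 100 concurrent jobs; for larger limits, follow the response's NextToken.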
+
+     def wait_for_available_slot(self, poll_interval=30):
+         while True:
+             current_jobs = self.get_running_jobs_count()
+             if current_jobs < self.max_concurrent_jobs:
+                 return current_jobs
+             print(f"Currently running {current_jobs}/{self.max_concurrent_jobs} jobs. Waiting...")
+             time.sleep(poll_interval)
+
+
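+ # Usage sketch (assumes a boto3 SageMaker client):
+ #   manager = TrainingJobResourceManager(sagemaker_client, max_concurrent_jobs=30)
+ #   manager.wait_for_available_slot()  # blocks until fewer than 30 jobs are InProgress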
def sanitize_job_name(name: str) -> str:
    """
    Sanitize the job name to meet SageMaker requirements:
@@ -34,16 +59,21 @@ def sanitize_job_name(name: str) -> str:

def launch_jobs(
    experiment_name: str = "tabflow",
+     context_name: str = "D244_F3_C1530_30",  # 30 datasets. To run larger, set to "D244_F3_C1530_200"
    entry_point: str = "evaluate.py",
    source_dir: str = ".",
    instance_type: str = "ml.m6i.4xlarge",
    docker_image_uri: str = "mlflow-image",
-     sagemaker_role: str = "arn:aws:iam::097403188315:role/service-role/AmazonSageMaker-ExecutionRole-20250128T153145",
+     sagemaker_role: str = "arn:aws:iam::ACCOUNT_ID:role/service-role/AmazonSageMakerRole",
    aws_profile: str | None = None,
    hyperparameters: dict = None,
    job_name: str = None,
-     keep_alive_period_in_seconds: int = 300,
+     keep_alive_period_in_seconds: int = 3600,
    limit_runtime: int = 24 * 60 * 60,
+     datasets: list = None,
+     folds: list = None,
+     methods_file: str = "methods.yaml",
+     max_concurrent_jobs: int = 30,
) -> None:
    """
    Launch multiple SageMaker training jobs.
@@ -58,73 +88,117 @@ def launch_jobs(
        aws_profile: AWS profile name
        hyperparameters: Dictionary of hyperparameters to pass to the training script
        job_name: Name for the training job
        keep_alive_period_in_seconds: Idle time before terminating the instance
        limit_runtime: Maximum running time in seconds
+         datasets: List of datasets to evaluate
+         folds: List of folds to evaluate
+         methods_file: Path to the YAML file containing methods
+         max_concurrent_jobs: Maximum number of concurrent jobs, based on the account limit
    """
    timestamp = datetime.now().strftime("%d-%b-%Y-%H:%M:%S.%f")[:-3]
    experiment_name = f"{experiment_name}-{timestamp}"

+     # Create a SageMaker client session
+     boto_session = boto3.Session(profile_name=aws_profile) if aws_profile else boto3.Session()
+     sagemaker_client = boto_session.client('sagemaker')
+     sagemaker_session = sagemaker.Session(boto_session=boto_session)
+
+     # Initialize the resource manager
+     resource_manager = TrainingJobResourceManager(sagemaker_client=sagemaker_client, max_concurrent_jobs=max_concurrent_jobs)
+
    # Load methods from YAML file
-     with open(args.methods, 'r') as file:
+     with open(methods_file, 'r') as file:
        methods_data = yaml.safe_load(file)
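+     # Expected shape of the methods file (illustrative sketch; names are hypothetical):
+     # methods:
+     #   - name: "ExampleMethod"
+     #     wrapper_class: "ExampleWrapper"
+     #     fit_kwargs: {}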

-     methods = [(method["name"], eval(method["wrapper_class"]), method["fit_kwargs"]) for method in methods_data["methods"]]
-
-
-     for dataset in datasets:
-         for fold in folds:
-             for method in methods:
-                 method_name, wrapper_class, fit_kwargs = method
-                 # Create a unique job name
-                 timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
-                 base_name = f"{dataset[:4]}-f{fold}-{method_name[:4]}-{timestamp}"
-                 job_name = sanitize_job_name(base_name)
-
-                 if docker_image_uri in DOCKER_IMAGE_ALIASES:
-                     print(f"Expanding docker_image_uri alias '{docker_image_uri}' -> '{DOCKER_IMAGE_ALIASES[docker_image_uri]}'")
-                     docker_image_uri = DOCKER_IMAGE_ALIASES[docker_image_uri]
-
-                 # Create SageMaker session
-                 sagemaker_session = (
-                     sagemaker.Session(boto_session=boto3.Session(profile_name=aws_profile))
-                     if aws_profile is not None
-                     else sagemaker.Session()
-                 )
-
-                 # Update hyperparameters for this job
-                 job_hyperparameters = hyperparameters.copy() if hyperparameters else {}
-                 job_hyperparameters.update({
-                     "experiment_name": experiment_name,
-                     "dataset": dataset,
-                     "fold": fold,  # NOTE: Can be a 'str' as well, refer to Estimators in SM docs
-                     "method_name": method_name,
-                     "wrapper_class": wrapper_class,
-                     "fit_kwargs": f"'{json.dumps(fit_kwargs)}'",
-                 })
-
-                 # Create the estimator
-                 estimator = sagemaker.estimator.Estimator(
-                     entry_point=entry_point,
-                     source_dir=source_dir,
-                     image_uri=docker_image_uri,
-                     role=sagemaker_role,
-                     instance_count=1,
-                     instance_type=instance_type,
-                     sagemaker_session=sagemaker_session,
-                     hyperparameters=job_hyperparameters,
-                     keep_alive_period_in_seconds=keep_alive_period_in_seconds,
-                     max_run=limit_runtime,
-                 )
-
-                 # Launch the training job
-                 estimator.fit(wait=False, job_name=job_name)
-                 print(f"Launched training job: {estimator.latest_training_job.name}")
+     methods = [(method["name"], method["wrapper_class"], method["fit_kwargs"]) for method in methods_data["methods"]]
+
+     repo_og: EvaluationRepository = EvaluationRepository.from_context(context_name, cache=True)
+
+     # "run_all" expands to every dataset in the context; a fold of -1 expands to all folds
+     if "run_all" in datasets:
+         datasets = repo_og.datasets()
+
+     if -1 in folds:
+         folds = repo_og.folds
+
+     total_jobs = len(datasets) * len(folds) * len(methods)
+     total_launched_jobs = 0
+
+     print(f"Preparing to launch {total_jobs} jobs with max concurrency of {max_concurrent_jobs}")
+     print(f"Instance keep-alive period set to {keep_alive_period_in_seconds} seconds to enable warm-starts")
+
+     try:
+         for dataset in datasets:
+             for fold in folds:
+                 for method in methods:
+                     current_jobs = resource_manager.wait_for_available_slot()
+                     print(f"\nSlot available. Currently running {current_jobs}/{max_concurrent_jobs} jobs")
+
+                     method_name, wrapper_class, fit_kwargs = method
+                     # Create a unique job name
+                     timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+                     base_name = f"{dataset[:4]}-f{fold}-{method_name[:4]}-{timestamp}"
+                     job_name = sanitize_job_name(base_name)
+
+                     if docker_image_uri in DOCKER_IMAGE_ALIASES:
+                         print(f"Expanding docker_image_uri alias '{docker_image_uri}' -> '{DOCKER_IMAGE_ALIASES[docker_image_uri]}'")
+                         docker_image_uri = DOCKER_IMAGE_ALIASES[docker_image_uri]
+
+                     # Update hyperparameters for this job
+                     job_hyperparameters = hyperparameters.copy() if hyperparameters else {}
+                     job_hyperparameters.update({
+                         "experiment_name": experiment_name,
+                         "context_name": context_name,
+                         "dataset": dataset,
+                         "fold": fold,  # NOTE: Can be a 'str' as well, refer to Estimators in SM docs
+                         "method_name": method_name,
+                         "wrapper_class": wrapper_class,
+                         "fit_kwargs": f"'{json.dumps(fit_kwargs)}'",  # JSON-encode and quote so fit_kwargs arrives as a single CLI argument
+                     })
+
+                     # Create the estimator
+                     estimator = sagemaker.estimator.Estimator(
+                         entry_point=entry_point,
+                         source_dir=source_dir,
+                         image_uri=docker_image_uri,
+                         role=sagemaker_role,
+                         instance_count=1,
+                         instance_type=instance_type,
+                         sagemaker_session=sagemaker_session,
+                         hyperparameters=job_hyperparameters,
+                         keep_alive_period_in_seconds=keep_alive_period_in_seconds,
+                         max_run=limit_runtime,
+                     )
+
+                     # Launch the training job without blocking on completion
+                     estimator.fit(wait=False, job_name=job_name)
+                     total_launched_jobs += 1
+                     print(f"Launched job {total_launched_jobs} of {total_jobs}: {job_name}")
+     except Exception as e:
+         print(f"Error launching jobs: {e}")
+         raise


def main():
    """Entrypoint for CLI"""
-     fire.Fire(launch_jobs)
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--datasets', nargs='+', type=str, required=True, help="List of datasets to evaluate")
+     parser.add_argument('--folds', nargs='+', type=int, required=True, help="List of folds to evaluate")
+     parser.add_argument('--methods_file', type=str, required=True, help="Path to the YAML file containing methods")
+
+     args = parser.parse_args()
+
+     launch_jobs(
+         datasets=args.datasets,
+         folds=args.folds,
+         methods_file=args.methods_file,
+     )
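+     # Example invocation (script name is a placeholder; "run_all" and a fold of -1
+     # expand to every dataset/fold in the chosen context):
+     #   python launch_jobs.py --datasets run_all --folds -1 --methods_file methods.yaml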


if __name__ == "__main__":