Commit aa33d91

Implement warmpool
1 parent ba38bbf commit aa33d91
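
SageMaker managed warm pools keep a training instance provisioned for keep_alive_period_in_seconds after a job completes, so the next job with a matching configuration can reuse it instead of cold-starting. This commit raises that period from 300 to 3600 seconds and adds a TrainingJobResourceManager that caps how many jobs run at once. A minimal sketch of the warm-pool setting only, reusing the placeholder image URI and role ARN from the diff below:

import sagemaker

# Minimal sketch (placeholder URI/ARN, not working values): an estimator whose
# instance stays warm for up to an hour after the job ends, as set in this commit.
estimator = sagemaker.estimator.Estimator(
    image_uri="ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/pmdesai:mlflow-tabrepo",
    role="arn:aws:iam::ACCOUNT_ID:role/service-role/AmazonSageMakerRole",
    instance_count=1,
    instance_type="ml.m6i.4xlarge",
    keep_alive_period_in_seconds=3600,  # idle keep-alive that enables warm-pool reuse
)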

File tree

2 files changed: +138 −70 lines changed


Diff for: scripts/tabflow/evaluate.py

+4 −10
@@ -13,28 +13,22 @@
 parser = argparse.ArgumentParser()
 
 parser.add_argument('--experiment_name', type=str, required=True, help="Name of the experiment")
+parser.add_argument('--context_name', type=str, required=True, help="Name of the context")
 parser.add_argument('--datasets', nargs='+', type=str, required=True, help="List of datasets to evaluate")
 parser.add_argument('--folds', nargs='+', type=int, required=True, help="List of folds to evaluate")
 parser.add_argument('--methods', type=str, required=True, help="Path to the YAML file containing methods")
 args = parser.parse_args()
 
 # Load Context
-context_name = "D244_F3_C1530_30" # 30 Datasets. To run larger, set to "D244_F3_C1530_200"
+context_name = args.context_name #"D244_F3_C1530_30" # 30 Datasets. To run larger, set to "D244_F3_C1530_200"
 expname = args.experiment_name # folder location of all experiment artifacts
 ignore_cache = False # set to True to overwrite existing caches and re-run experiments from scratch
 
 #TODO: Download the repo without pred-proba
 repo_og: EvaluationRepository = EvaluationRepository.from_context(context_name, cache=True)
 
-if args.datasets == "run_all":
-    datasets = repo_og.datasets() # run on all datasets
-else:
-    datasets = args.datasets
-
-if -1 in args.folds:
-    folds = repo_og.folds # run on all folds
-else:
-    folds = args.folds
+datasets = args.datasets
+folds = args.folds
 
 # Parse fit_kwargs from JSON string
 fit_kwargs = json.loads(args.fit_kwargs)
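
Because SageMaker script mode forwards each estimator hyperparameter to the entry point as a command-line flag, the launcher's new "context_name" entry (added in launch_jobs.py below) arrives here as the --context_name argument that replaces the hard-coded context. A minimal sketch of that mapping, with illustrative values that are not taken from the commit:

# Minimal sketch (illustrative values only): the launcher sets
job_hyperparameters = {
    "experiment_name": "tabflow-example",
    "context_name": "D244_F3_C1530_30",
}
# and inside the training container this is received roughly as:
#   python evaluate.py --experiment_name tabflow-example --context_name D244_F3_C1530_30 ...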

Diff for: scripts/tabflow/launch_jobs.py

+134 −60
@@ -1,16 +1,41 @@
 import boto3
 import sagemaker
-from pathlib import Path
-import fire
 import re
+import yaml
+import argparse
+import json
+import time
+
 from datetime import datetime
+from pathlib import Path
+from tabrepo import EvaluationRepository
 
 
 DOCKER_IMAGE_ALIASES = {
-    "mlflow-image": "097403188315.dkr.ecr.us-west-2.amazonaws.com/pmdesai:mlflow-tabrepo",
+    "mlflow-image": "ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/pmdesai:mlflow-tabrepo",
 }
 
 
+class TrainingJobResourceManager:
+    def __init__(self, sagemaker_client, max_concurrent_jobs):
+        self.sagemaker_client = sagemaker_client
+        self.max_concurrent_jobs = max_concurrent_jobs
+
+    def get_running_jobs_count(self):
+        response = self.sagemaker_client.list_training_jobs(StatusEquals='InProgress', MaxResults=100)
+        return len(response['TrainingJobSummaries'])
+
+    def wait_for_available_slot(self, poll_interval=30):
+        while True:
+            print("\nTraining Jobs: ", self.get_running_jobs_count())
+            current_jobs = self.get_running_jobs_count()
+            if current_jobs < (self.max_concurrent_jobs):
+                return current_jobs
+            print(f"Currently running {current_jobs}/{self.max_concurrent_jobs} jobs. Waiting...")
+            time.sleep(poll_interval)
+
+
+
 def sanitize_job_name(name: str) -> str:
     """
     Sanitize the job name to meet SageMaker requirements:
@@ -34,16 +59,21 @@ def sanitize_job_name(name: str) -> str:
 
 def launch_jobs(
     experiment_name: str = "tabflow",
+    context_name: str = "D244_F3_C1530_30", # 30 datasets. To run larger, set to "D244_F3_C1530_200"
     entry_point: str = "evaluate.py",
     source_dir: str = ".",
     instance_type: str = "ml.m6i.4xlarge",
     docker_image_uri: str = "mlflow-image",
-    sagemaker_role: str = "arn:aws:iam::097403188315:role/service-role/AmazonSageMaker-ExecutionRole-20250128T153145",
+    sagemaker_role: str = "arn:aws:iam::ACCOUNT_ID:role/service-role/AmazonSageMakerRole",
     aws_profile: str | None = None,
     hyperparameters: dict = None,
     job_name: str = None,
-    keep_alive_period_in_seconds: int = 300,
+    keep_alive_period_in_seconds: int = 3600,
     limit_runtime: int = 24 * 60 * 60,
+    datasets: list = None,
+    folds: list = None,
+    methods_file: str = "methods.yaml",
+    max_concurrent_jobs: int = 30,
 ) -> None:
     """
     Launch multiple SageMaker training jobs.
@@ -58,73 +88,117 @@ def launch_jobs(
         aws_profile: AWS profile name
         hyperparameters: Dictionary of hyperparameters to pass to the training script
         job_name: Name for the training job
-        keep_alive_period_in_seconds: Idle time before terminating the instance
+        keep_alive_period_in_seconds: Idle time before terminating the instance
         limit_runtime: Maximum running time in seconds
+        datasets: List of datasets to evaluate
+        folds: List of folds to evaluate
+        methods_file: Path to the YAML file containing methods
+        max_concurrent_jobs: Maximum number of concurrent jobs, based on account limit
     """
     timestamp = datetime.now().strftime("%d-%b-%Y-%H:%M:%S.%f")[:-3]
     experiment_name = f"{experiment_name}-{timestamp}"
 
+    # Create a SageMaker client session
+    boto_session = boto3.Session(profile_name=aws_profile) if aws_profile else boto3.Session()
+    sagemaker_client = boto_session.client('sagemaker')
+    sagemaker_session = sagemaker.Session(boto_session=boto_session)
+
+    # Initialize the resource manager
+    resource_manager = TrainingJobResourceManager(sagemaker_client=sagemaker_client, max_concurrent_jobs=max_concurrent_jobs)
+
     # Load methods from YAML file
-    with open(args.methods, 'r') as file:
+    with open(methods_file, 'r') as file:
         methods_data = yaml.safe_load(file)
 
-    methods = [(method["name"], eval(method["wrapper_class"]), method["fit_kwargs"]) for method in methods_data["methods"]]
-
-
-    for dataset in datasets:
-        for fold in folds:
-            for method in methods:
-                method_name, wrapper_class, fit_kwargs = method
-                # Create a unique job name
-                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
-                base_name = f"{dataset[:4]}-f{fold}-{method_name[:4]}-{timestamp}"
-                job_name = sanitize_job_name(base_name)
-
-
-                if docker_image_uri in DOCKER_IMAGE_ALIASES:
-                    print(f"Expanding docker_image_uri alias '{docker_image_uri}' -> '{DOCKER_IMAGE_ALIASES[docker_image_uri]}'")
-                    docker_image_uri = DOCKER_IMAGE_ALIASES[docker_image_uri]
-
-                # Create SageMaker session
-                sagemaker_session = (
-                    sagemaker.Session(boto_session=boto3.Session(profile_name=aws_profile))
-                    if aws_profile is not None
-                    else sagemaker.Session()
-                )
-
-                # Update hyperparameters for this job
-                job_hyperparameters = hyperparameters.copy() if hyperparameters else {}
-                job_hyperparameters.update({
-                    "experiment_name": experiment_name,
-                    "dataset": dataset,
-                    "fold": fold, # NOTE: Can be a 'str' as well, refer to Estimators in SM docs
-                    "method_name": method_name,
-                    "wrapper_class": wrapper_class,
-                    "fit_kwargs": f"'{json.dumps(fit_kwargs)}'",
-                })
-
-                # Create the estimator
-                estimator = sagemaker.estimator.Estimator(
-                    entry_point=entry_point,
-                    source_dir=source_dir,
-                    image_uri=docker_image_uri,
-                    role=sagemaker_role,
-                    instance_count=1,
-                    instance_type=instance_type,
-                    sagemaker_session=sagemaker_session,
-                    hyperparameters=job_hyperparameters,
-                    keep_alive_period_in_seconds=keep_alive_period_in_seconds,
-                    max_run=limit_runtime,
-                )
-
-                # Launch the training job
-                estimator.fit(wait=False, job_name=job_name)
-                print(f"Launched training job: {estimator.latest_training_job.name}")
+    methods = [(method["name"], method["wrapper_class"], method["fit_kwargs"]) for method in methods_data["methods"]]
+
+    repo_og: EvaluationRepository = EvaluationRepository.from_context(context_name, cache=True)
+
+    if "run_all" in datasets:
+        datasets = repo_og.datasets()
+    else:
+        datasets = datasets
+
+    if -1 in folds:
+        folds = repo_og.folds
+    else:
+        folds = folds
+
+    total_jobs = len(datasets) * len(folds) * len(methods)
+    total_launched_jobs = 0
+
+    print(f"Preparing to launch {total_jobs} jobs with max concurrency of {max_concurrent_jobs}")
+    print(f"Instance keep-alive period set to {keep_alive_period_in_seconds} seconds to enable warm-starts")
+
+    try:
+        for dataset in datasets:
+            for fold in folds:
+                for method in methods:
+
+                    current_jobs = resource_manager.wait_for_available_slot()
+                    print(f"\nSlot available. Currently running {current_jobs}/{max_concurrent_jobs} jobs")
+
+                    method_name, wrapper_class, fit_kwargs = method
+                    # Create a unique job name
+                    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+                    base_name = f"{dataset[:4]}-f{fold}-{method_name[:4]}-{timestamp}"
+                    job_name = sanitize_job_name(base_name)
+
+
+                    if docker_image_uri in DOCKER_IMAGE_ALIASES:
+                        print(f"Expanding docker_image_uri alias '{docker_image_uri}' -> '{DOCKER_IMAGE_ALIASES[docker_image_uri]}'")
+                        docker_image_uri = DOCKER_IMAGE_ALIASES[docker_image_uri]
+
+                    # Update hyperparameters for this job
+                    job_hyperparameters = hyperparameters.copy() if hyperparameters else {}
+                    job_hyperparameters.update({
+                        "experiment_name": experiment_name,
+                        "context_name": context_name,
+                        "dataset": dataset,
+                        "fold": fold, # NOTE: Can be a 'str' as well, refer to Estimators in SM docs
+                        "method_name": method_name,
+                        "wrapper_class": wrapper_class,
+                        "fit_kwargs": f"'{json.dumps(fit_kwargs)}'",
+                    })
+
+                    # Create the estimator
+                    estimator = sagemaker.estimator.Estimator(
+                        entry_point=entry_point,
+                        source_dir=source_dir,
+                        image_uri=docker_image_uri,
+                        role=sagemaker_role,
+                        instance_count=1,
+                        instance_type=instance_type,
+                        sagemaker_session=sagemaker_session,
+                        hyperparameters=job_hyperparameters,
+                        keep_alive_period_in_seconds=keep_alive_period_in_seconds,
+                        max_run=limit_runtime,
+                    )
+
+                    # Launch the training job
+                    estimator.fit(wait=False, job_name=job_name)
+                    total_launched_jobs += 1
+                    print(f"Launched job {total_launched_jobs} out of a total of {total_jobs}: {job_name}")
+                    # print(f"Launched training job: {estimator.latest_training_job.name}")
+    except Exception as e:
+        print(f"Error launching jobs: {e}")
+        raise
 
 
 def main():
     """Entrypoint for CLI"""
-    fire.Fire(launch_jobs)
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--datasets', nargs='+', type=str, required=True, help="List of datasets to evaluate")
+    parser.add_argument('--folds', nargs='+', type=int, required=True, help="List of folds to evaluate")
+    parser.add_argument('--methods_file', type=str, required=True, help="Path to the YAML file containing methods")
+
+    args = parser.parse_args()
+
+    launch_jobs(
+        datasets=args.datasets,
+        folds=args.folds,
+        methods_file=args.methods_file,
+    )
 
 
 if __name__ == "__main__":
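
With fire replaced by argparse, the launcher is invoked with explicit dataset, fold, and methods-file arguments, for example python launch_jobs.py --datasets run_all --folds -1 --methods_file methods.yaml (the run_all and -1 sentinels are resolved against the repository as shown above). The new TrainingJobResourceManager can also be exercised on its own to throttle launches against an account's concurrent training-job limit; a minimal sketch, assuming the module is importable as launch_jobs:

import boto3

from launch_jobs import TrainingJobResourceManager  # assumed import path

# Minimal sketch: block until fewer than max_concurrent_jobs training jobs are
# 'InProgress', then report how many slots are currently in use.
sagemaker_client = boto3.Session().client("sagemaker")
manager = TrainingJobResourceManager(sagemaker_client=sagemaker_client, max_concurrent_jobs=30)

running = manager.wait_for_available_slot(poll_interval=30)  # polls every 30 seconds
print(f"{running}/30 jobs running; safe to launch the next job")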
