
Commit 171f19e

adding tabrepo benchmarking
1 parent 7e7844c commit 171f19e

6 files changed: +534 -5 lines

examples/run_scripts_v5/run_simple_benchmark_w_simulator_realmlp.py

+5 -5
@@ -17,7 +17,7 @@
 # If the artifact is present, it will be used and the models will not be re-run.
 if __name__ == '__main__':
     # Load Context
-    context_name = "D244_F3_C1530_200"  # 200 smallest datasets. To run larger, set to "D244_F3_C1530_200"
+    context_name = "D244_F3_C1530_30"  # 30 Datasets. To run larger, set to "D244_F3_C1530_200"
     expname = "./initial_experiment_simple_simulator"  # folder location of all experiment artifacts
     ignore_cache = False  # set to True to overwrite existing caches and re-run experiments from scratch

@@ -26,12 +26,12 @@
     repo_og: EvaluationRepository = EvaluationRepository.from_context(context_name, cache=True)

     # Sample for a quick demo
-    # datasets = repo_og.datasets()[:3]
-    # folds = [0]
+    datasets = ["Australian", "blood-transfusion-service-center"]
+    folds = [0, 1]

     # To run everything:
-    datasets = repo_og.datasets()
-    folds = repo_og.folds
+    # datasets = repo_og.datasets()
+    # folds = repo_og.folds

     # TODO: Why is RealMLP slow when running sequentially / not in a bag? Way slower than it should be. Torch threads?
     methods = [
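
The demo now hard-codes two datasets and two folds instead of running the full context. An illustrative sketch (not part of this commit; it reuses only calls already present in the script) to confirm the hard-coded names exist in the loaded context before launching runs:

# Illustrative sketch, not committed code: verify the demo subset against the loaded context.
from tabrepo import EvaluationRepository

repo_og = EvaluationRepository.from_context("D244_F3_C1530_30", cache=True)
demo_datasets = ["Australian", "blood-transfusion-service-center"]
missing = [d for d in demo_datasets if d not in repo_og.datasets()]
assert not missing, f"Datasets not found in context: {missing}"
print(f"Available folds: {repo_og.folds}")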

scripts/tabflow/Dockerfile_SM

+18
@@ -0,0 +1,18 @@
+FROM 763104351884.dkr.ecr.us-west-2.amazonaws.com/autogluon-training:1.1.1-cpu-py311-ubuntu20.04
+
+# Copy the contents of the tab-benchmark folder to the container - run from tabflow
+COPY ../ .
+COPY ./tabflow/evaluate.py .
+
+RUN pip install autogluon==1.1.1
+
+RUN pip install fire
+# Install the required packages
+RUN pip install -e tabrepo \
+&& pip install -e autogluon-bench \
+&& pip install -e autogluon-benchmark
+
+# Install pytabkit and seaborn
+RUN pip install pytabkit seaborn
+
+RUN chmod +x ./evaluate.py
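
The image builds on an AWS Deep Learning Container for AutoGluon training, and the _SM suffix suggests it is intended for SageMaker. One plausible way to launch it, assuming the image has been built and pushed and a standard SageMaker Estimator is used; the ECR URI, IAM role, and instance type below are placeholders, and this commit does not show how tabflow actually dispatches jobs:

# Hypothetical sketch only: run the image built from Dockerfile_SM as a SageMaker training job.
# The image URI, IAM role, and instance type are placeholders, not values from this commit.
from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="<account-id>.dkr.ecr.us-west-2.amazonaws.com/tabflow:latest",  # placeholder ECR image
    role="<sagemaker-execution-role-arn>",                                    # placeholder IAM role
    instance_count=1,
    instance_type="ml.m5.2xlarge",                                            # placeholder instance type
)
estimator.fit(wait=False)  # evaluate.py's CLI arguments would still need to be wired into the container command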

scripts/tabflow/evaluate.py

+147
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import argparse
+
+import pandas as pd
+import yaml
+
+from experiment_utils import ExperimentBatchRunner
+from tabrepo import EvaluationRepository, EvaluationRepositoryCollection, Evaluator
+from tabrepo.scripts_v5.AutoGluon_class import AGWrapper
+from tabrepo.scripts_v5.ag_models.realmlp_model import RealMLPModel
+
+# If the artifact is present, it will be used and the models will not be re-run.
+if __name__ == '__main__':
+    # Parse args
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--datasets', nargs='+', type=str, required=True, help="List of datasets to evaluate")
+    parser.add_argument('--folds', nargs='+', type=int, required=True, help="List of folds to evaluate")
+    parser.add_argument('--methods', type=str, required=True, help="Path to the YAML file containing methods")
+    args = parser.parse_args()
+
+    # Load Context
+    context_name = "D244_F3_C1530_30"  # 30 Datasets. To run larger, set to "D244_F3_C1530_200"
+    expname = "./initial_experiment_simple_simulator"  # folder location of all experiment artifacts
+    ignore_cache = False  # set to True to overwrite existing caches and re-run experiments from scratch
+
+    repo_og: EvaluationRepository = EvaluationRepository.from_context(context_name, cache=True)
+
+    # Sample for a quick demo
+    # datasets = ["Australian", "blood-transfusion-service-center"]
+    # folds = [0, 1]
+
+    # To run everything:
+    # datasets = repo_og.datasets()
+    # folds = repo_og.folds
+
+    datasets = args.datasets
+    if -1 in args.folds:
+        folds = repo_og.folds  # run on all folds
+    else:
+        folds = args.folds
+
+    # Load methods from YAML file
+    with open(args.methods, 'r') as file:
+        methods_data = yaml.safe_load(file)
+
+    methods = [(method["name"], eval(method["wrapper_class"]), method["fit_kwargs"]) for method in methods_data["methods"]]
+
+    # methods = [
+    #     (
+    #         "RealMLP_c1_BAG_L1_v4_noes_r0",  # Name of the method
+    #         AGWrapper,  # Wrapper class
+    #         {
+    #             "fit_kwargs": {  # Fit kwargs: AutoGluon hyperparameters + custom model hyperparameters
+    #                 "num_bag_folds": 8,
+    #                 "num_bag_sets": 1,
+    #                 "fit_weighted_ensemble": False,
+    #                 "calibrate": False,
+    #                 "verbosity": 2,
+    #                 "hyperparameters": {
+    #                     RealMLPModel: {  # Custom model class and its hyperparameters
+    #                         "random_state": 0,
+    #                         "use_early_stopping": False,
+    #                     },
+    #                 },
+    #             }
+    #         },
+    #     ),
+    # ]
+
+    tids = [repo_og.dataset_to_tid(dataset) for dataset in datasets]
+    repo: EvaluationRepository = ExperimentBatchRunner().generate_repo_from_experiments(
+        expname=expname,
+        tids=tids,
+        folds=folds,
+        methods=methods,
+        task_metadata=repo_og.task_metadata,
+        ignore_cache=ignore_cache,
+        convert_time_infer_s_from_batch_to_sample=True,
+    )
+
+    repo.print_info()
+
+    save_path = "repo_new"
+    repo.to_dir(path=save_path)  # Load the repo later via `EvaluationRepository.from_dir(save_path)`
+
+    print(f"New Configs : {repo.configs()}")
+
+    repo_combined = EvaluationRepositoryCollection(repos=[repo_og, repo], config_fallback="ExtraTrees_c1_BAG_L1")
+    repo_combined = repo_combined.subset(datasets=repo.datasets(), folds=repo.folds)
+
+    repo_combined.print_info()
+
+    comparison_configs_og = [
+        "RandomForest_c1_BAG_L1",
+        "ExtraTrees_c1_BAG_L1",
+        "LightGBM_c1_BAG_L1",
+        "XGBoost_c1_BAG_L1",
+        "CatBoost_c1_BAG_L1",
+        "NeuralNetTorch_c1_BAG_L1",
+        "NeuralNetFastAI_c1_BAG_L1",
+    ]
+
+    comparison_configs = comparison_configs_og + [
+        "RealMLP_c1_BAG_L1_v4_noes_r0",
+    ]
+
+    df_ensemble_results, df_ensemble_weights = repo_combined.evaluate_ensembles(configs=comparison_configs, ensemble_size=40)
+    df_ensemble_results = df_ensemble_results.reset_index()
+    df_ensemble_results["framework"] = "ensemble_with_RealMLP_c1"
+
+    df_ensemble_results_og, df_ensemble_weights_og = repo_combined.evaluate_ensembles(configs=comparison_configs_og, ensemble_size=40)
+    df_ensemble_results_og = df_ensemble_results_og.reset_index()
+    df_ensemble_results_og["framework"] = "ensemble_og"
+
+    results_df = pd.concat([
+        df_ensemble_results,
+        df_ensemble_results_og,
+    ], ignore_index=True)
+
+    baselines = [
+        "AutoGluon_bq_4h8c_2023_11_14",
+    ]
+
+    evaluator = Evaluator(repo=repo_combined)
+
+    p = evaluator.plot_ensemble_weights(df_ensemble_weights=df_ensemble_weights, figsize=(16, 60))
+    p.savefig("ensemble_weights_w_RealMLP_c1")
+
+    metrics = evaluator.compare_metrics(
+        results_df=results_df,
+        datasets=datasets,
+        folds=folds,
+        baselines=baselines,
+        configs=comparison_configs,
+    )
+
+    metrics_tmp = metrics.reset_index(drop=False)
+
+    with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
+        print(f"Config Metrics Example:\n{metrics.head(100)}")
+
+    evaluator_output = evaluator.plot_overall_rank_comparison(
+        results_df=metrics,
+        save_dir=expname,
+        evaluator_kwargs={
+            "treat_folds_as_datasets": True,
+            "frameworks_compare_vs_all": ["RealMLP_c1_BAG_L1_v4_noes_r0"],
+        },
+    )
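
evaluate.py expects --methods to point at a YAML file whose top-level methods list mirrors the commented-out tuple structure above: a name, a wrapper_class string that is eval()'d into a class imported by the script, and a fit_kwargs mapping that becomes the third tuple element (so it nests the inner AutoGluon fit_kwargs). A hypothetical sketch with illustrative values follows; note that the RealMLPModel-keyed hyperparameters in the commented example use a Python class as a dict key, which plain YAML cannot express, so they are omitted here.

# Hypothetical sketch: write a methods YAML in the shape evaluate.py parses. Values are illustrative.
methods_yaml = """\
methods:
  - name: "RealMLP_c1_BAG_L1_v4_noes_r0"
    wrapper_class: "AGWrapper"   # eval()'d by evaluate.py; must name a class imported there
    fit_kwargs:                  # becomes the third tuple element ...
      fit_kwargs:                # ... which nests the AutoGluon fit kwargs
        num_bag_folds: 8
        num_bag_sets: 1
        fit_weighted_ensemble: false
        calibrate: false
        verbosity: 2
"""

with open("methods.yaml", "w") as f:
    f.write(methods_yaml)

The script could then be run with, for example, python evaluate.py --datasets Australian blood-transfusion-service-center --folds 0 1 --methods methods.yaml; passing --folds -1 evaluates all folds of the loaded context.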
