Skip to content

Commit c60cd40

Browse files
author
Richard Michael
committed
seed filtered viz update, lambo rel.HV included
1 parent 0b0521a commit c60cd40

File tree

2 files changed: +36 / −14 lines changed

experiments/run_optimization_viz.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
find_closest_wildtype_pdb_file_to_mutant
1818
from torch import Tensor
1919

20-
from corel.observers import (ABS_HYPER_VOLUME, BLACKBOX, MIN_BLACKBOX,
20+
from corel.observers import (ABS_HYPER_VOLUME, BLACKBOX,
21+
LAMBO_REL_HYPER_VOLUME, MIN_BLACKBOX,
2122
REL_HYPER_VOLUME, UNNORMALIZED_HV)
2223

2324
TRACKING_URI = "file:/Users/rcml/corel/results/slurm_mlruns/mlruns/"
2425
METRIC_DICT = {ABS_HYPER_VOLUME: "hypervolume",
2526
REL_HYPER_VOLUME: "rel. hypervolume",
2627
UNNORMALIZED_HV: "unnorm. hypervolume",
28+
LAMBO_REL_HYPER_VOLUME: "LamBO rel. hypervolume",
2729
"blackbox_0": r"$f_0$",
2830
"blackbox_1": r"$f_1$",
2931
"min_blackbox_0": r"$\min(f_0)$",
@@ -218,17 +220,34 @@ def optimization_line_figure(df: pd.DataFrame, metric: str, n_steps, title: str=
218220
full_size_df = unpack_observations(df, column=metric)
219221
if n_steps:
220222
full_size_df = full_size_df[full_size_df.step <= n_steps]
221-
batch_size = int(full_size_df.iloc[0,0].split("_")[-1][1:])
223+
# filter incomplete batches
224+
def filter_unique_counts(group):
225+
return group["step"].nunique() == full_size_df.step.max()+1
226+
full_size_df = full_size_df.groupby(["algorithm", "seed"]).filter(filter_unique_counts)
227+
print(f"Filtered seeds to {full_size_df.step.max()+1} completed steps")
228+
# we filter by minimal count of available seeds between COREL or LAMBO - if random has less dont take less!
229+
min_number_seeds = full_size_df[full_size_df.algorithm.str.split("_").str[0].isin(["COREL", "LAMBO"])].groupby(["algorithm"])["seed"].nunique().min()
230+
print(f"Minimal number of seeds: {min_number_seeds}")
231+
subselected_algo_dfs = []
232+
for algo in full_size_df.algorithm.unique(): # filter by minimal amount of overlapping seeds
233+
algo_df = full_size_df[full_size_df.algorithm==algo]
234+
min_seeds_for_algo = algo_df.seed.unique()[:min_number_seeds]
235+
subselected_df = algo_df[algo_df.seed.isin(min_seeds_for_algo)]
236+
subselected_algo_dfs.append(subselected_df)
237+
filtered_results_df = pd.concat(subselected_algo_dfs)
238+
n_seeds = filtered_results_df.groupby(["algorithm"])["seed"].nunique()
239+
print(f"n={n_seeds} seeds remaining for algorithms")
240+
batch_size = int(filtered_results_df.iloc[0,0].split("_")[-1][1:])
222241
# HACK to overlay plots: point and lineplot treat x-axis differently, ensure categorical
223-
full_size_df["step_str"] = full_size_df.step.astype(str)
242+
filtered_results_df["step_str"] = filtered_results_df.step.astype(str)
224243
fig, ax = plt.subplots(figsize=(5, 3.5))
225-
sns.lineplot(full_size_df, x="step_str", y=metric, hue="algorithm", ax=ax, palette=opt_colorscheme)
226-
batched_stats = full_size_df[full_size_df["step"] % batch_size == 0]
227-
sns.pointplot(batched_stats, x="step_str", y=metric, errorbar=("se", 1), capsize=.1, hue="algorithm", ax=ax, join="False", palette=opt_colorscheme)
244+
sns.lineplot(filtered_results_df, x="step_str", y=metric, hue="algorithm", ax=ax, palette=opt_colorscheme)
245+
batch_stats = filtered_results_df[filtered_results_df["step"] % batch_size == 0]
246+
sns.pointplot(batch_stats, x="step_str", y=metric, errorbar=("se", 1), capsize=.1, hue="algorithm", ax=ax, join="False", palette=opt_colorscheme)
228247
for line in ax.lines:
229248
line.set_markersize(3.)
230249
line.set_linewidth(1.)
231-
ax.set_xticks(np.arange(0, full_size_df["step"].max()+1, tick_every_batch*batch_size))
250+
ax.set_xticks(np.arange(0, filtered_results_df["step"].max()+1, tick_every_batch*batch_size))
232251
ax.tick_params(axis="x", labelsize=14, rotation=45)
233252
ax.tick_params(axis="y", labelsize=14)
234253
plt.xlabel("steps", fontsize=16)
@@ -239,8 +258,8 @@ def optimization_line_figure(df: pd.DataFrame, metric: str, n_steps, title: str=
239258
plt.legend(updated_legend.values(), updated_legend.keys())
240259
plt.subplots_adjust(top=0.99, right=0.972, left=0.17, bottom=0.25)
241260
figure_path = Path(__file__).parent.parent.resolve() / "results" / "figures" / "rfp"
242-
plt.savefig(f"{figure_path}/OPT_experiment_{metric.lower()}_{title.split()[0]}_batch{batch_size}.png")
243-
plt.savefig(f"{figure_path}/OPT_experiment_{metric.lower()}_{title.split()[0]}_batch{batch_size}.pdf")
261+
plt.savefig(f"{figure_path}/OPT_experiment_{metric.lower()}_{title.split()[0]}_batch{batch_size}_seeds{min_number_seeds}.png")
262+
plt.savefig(f"{figure_path}/OPT_experiment_{metric.lower()}_{title.split()[0]}_batch{batch_size}_seeds{min_number_seeds}.pdf")
244263
plt.show()
245264

246265

@@ -358,6 +377,7 @@ def load_viz_rfp_experiments(exp_name: str="rfp_foldx_stability_and_sasa",
358377
strict=True,
359378
n_steps: int=180,
360379
pareto_fig=False,
380+
metric_names: List[str]=METRIC_DICT.keys()
361381
):
362382
experiment_combinations = product(seeds, algorithms, starting_n, batch_size)
363383
mlf_client = mlflow.tracking.MlflowClient(tracking_uri=TRACKING_URI)
@@ -367,7 +387,7 @@ def load_viz_rfp_experiments(exp_name: str="rfp_foldx_stability_and_sasa",
367387
if finished_only:
368388
runs = [r for r in runs if r.info.status == "FINISHED"]
369389
run_results = filter_run_results(experiment_combinations, runs)
370-
metric_dict = get_algo_metric_history_from_run(mlf_client, run_results, algorithms=algorithms, seeds=seeds, batch_sizes=batch_size, starting_n=starting_n)
390+
metric_dict = get_algo_metric_history_from_run(mlf_client, run_results, algorithms=algorithms, seeds=seeds, batch_sizes=batch_size, starting_n=starting_n, metric_names=metric_names)
371391
experiment_results_df = pd.concat({k: pd.DataFrame.from_dict(v, 'index') for k,v in metric_dict.items()}, axis=0)
372392
experiment_results_df = experiment_results_df.reset_index().rename(columns={"level_0": "algorithm", "level_1": "seed"})
373393
experiment_combinations = product(seeds, algorithms, starting_n, batch_size)
@@ -388,7 +408,7 @@ def load_viz_rfp_experiments(exp_name: str="rfp_foldx_stability_and_sasa",
388408
for metric in METRIC_DICT.keys():
389409
if exp_name != "foldx_rfp_lambo":
390410
optimization_line_figure(cold_experiments[["algorithm", "seed", "starting_N", metric]], metric=metric, title="cold HV optimization N=6", strict=strict, n_steps=n_steps)
391-
optimization_line_figure(warm_experiments[["algorithm", "seed", "starting_N", metric]], metric=metric, title="warm HV optimization N=50", strict=strict, n_steps=None)
411+
optimization_line_figure(warm_experiments[["algorithm", "seed", "starting_N", metric]], metric=metric, title="warm HV optimization N=50", strict=strict, n_steps=n_steps*2)
392412
else:
393413
optimization_line_figure(ref_experiments[["algorithm", "seed", "starting_N", metric]], metric=metric, title="ref. HV optimization N=512", strict=strict, n_steps=n_steps)
394414
if pareto_fig:
@@ -486,10 +506,11 @@ def load_viz_gfp_experiments(
486506
if __name__ == "__main__":
487507
## LOAD AND VISUALIZE RFP EXPERIMENTS
488508
# RFP base experiments
489-
# load_viz_rfp_experiments(pareto_fig=False)
509+
load_viz_rfp_experiments(pareto_fig=True)
490510
# ## LOAD AND VISUALIZE GFP EXPERIMENTS
491511
load_viz_gfp_experiments()
492-
# # RFP reference experiments # SUPPLEMENTARY TODO
493-
# load_viz_rfp_experiments(exp_name="foldx_rfp_lambo", starting_n=["512"], finished_only=False)
512+
# RFP reference experiments
513+
load_viz_rfp_experiments(exp_name="foldx_rfp_lambo", starting_n=["512"],
514+
metric_names=list(METRIC_DICT.keys()) + [LAMBO_REL_HYPER_VOLUME], finished_only=False)
494515

495516

src/corel/observers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
UNNORMALIZED_HV = "HYPER_VOLUME_UNNORMALIZED"
1616
REL_HYPER_VOLUME = "REL_HYPER_VOLUME"
1717
BATCH_SIZE = "BATCH_SIZE"
18+
LAMBO_REL_HYPER_VOLUME = "LAMBO_REL_HYPER_VOLUME"

0 commit comments

Comments (0)