17
17
find_closest_wildtype_pdb_file_to_mutant
18
18
from torch import Tensor
19
19
20
- from corel .observers import (ABS_HYPER_VOLUME , BLACKBOX , MIN_BLACKBOX ,
20
+ from corel .observers import (ABS_HYPER_VOLUME , BLACKBOX ,
21
+ LAMBO_REL_HYPER_VOLUME , MIN_BLACKBOX ,
21
22
REL_HYPER_VOLUME , UNNORMALIZED_HV )
22
23
23
24
TRACKING_URI = "file:/Users/rcml/corel/results/slurm_mlruns/mlruns/"
24
25
METRIC_DICT = {ABS_HYPER_VOLUME : "hypervolume" ,
25
26
REL_HYPER_VOLUME : "rel. hypervolume" ,
26
27
UNNORMALIZED_HV : "unnorm. hypervolume" ,
28
+ LAMBO_REL_HYPER_VOLUME : "LamBO rel. hypervolume" ,
27
29
"blackbox_0" : r"$f_0$" ,
28
30
"blackbox_1" : r"$f_1$" ,
29
31
"min_blackbox_0" : r"$\min(f_0)$" ,
@@ -218,17 +220,34 @@ def optimization_line_figure(df: pd.DataFrame, metric: str, n_steps, title: str=
218
220
full_size_df = unpack_observations (df , column = metric )
219
221
if n_steps :
220
222
full_size_df = full_size_df [full_size_df .step <= n_steps ]
221
- batch_size = int (full_size_df .iloc [0 ,0 ].split ("_" )[- 1 ][1 :])
223
+ # filter incomplete batches
224
+ def filter_unique_counts (group ):
225
+ return group ["step" ].nunique () == full_size_df .step .max ()+ 1
226
+ full_size_df = full_size_df .groupby (["algorithm" , "seed" ]).filter (filter_unique_counts )
227
+ print (f"Filtered seeds to { full_size_df .step .max ()+ 1 } completed steps" )
228
+ # we filter by minimal count of available seeds between COREL or LAMBO - if random has less dont take less!
229
+ min_number_seeds = full_size_df [full_size_df .algorithm .str .split ("_" ).str [0 ].isin (["COREL" , "LAMBO" ])].groupby (["algorithm" ])["seed" ].nunique ().min ()
230
+ print (f"Minimal number of seeds: { min_number_seeds } " )
231
+ subselected_algo_dfs = []
232
+ for algo in full_size_df .algorithm .unique (): # filter by minimal amount of overlapping seeds
233
+ algo_df = full_size_df [full_size_df .algorithm == algo ]
234
+ min_seeds_for_algo = algo_df .seed .unique ()[:min_number_seeds ]
235
+ subselected_df = algo_df [algo_df .seed .isin (min_seeds_for_algo )]
236
+ subselected_algo_dfs .append (subselected_df )
237
+ filtered_results_df = pd .concat (subselected_algo_dfs )
238
+ n_seeds = filtered_results_df .groupby (["algorithm" ])["seed" ].nunique ()
239
+ print (f"n={ n_seeds } seeds remaining for algorithms" )
240
+ batch_size = int (filtered_results_df .iloc [0 ,0 ].split ("_" )[- 1 ][1 :])
222
241
# HACK to overlay plots: point and lineplot treat x-axis differently, ensure categorical
223
- full_size_df ["step_str" ] = full_size_df .step .astype (str )
242
+ filtered_results_df ["step_str" ] = filtered_results_df .step .astype (str )
224
243
fig , ax = plt .subplots (figsize = (5 , 3.5 ))
225
- sns .lineplot (full_size_df , x = "step_str" , y = metric , hue = "algorithm" , ax = ax , palette = opt_colorscheme )
226
- batched_stats = full_size_df [ full_size_df ["step" ] % batch_size == 0 ]
227
- sns .pointplot (batched_stats , x = "step_str" , y = metric , errorbar = ("se" , 1 ), capsize = .1 , hue = "algorithm" , ax = ax , join = "False" , palette = opt_colorscheme )
244
+ sns .lineplot (filtered_results_df , x = "step_str" , y = metric , hue = "algorithm" , ax = ax , palette = opt_colorscheme )
245
+ batch_stats = filtered_results_df [ filtered_results_df ["step" ] % batch_size == 0 ]
246
+ sns .pointplot (batch_stats , x = "step_str" , y = metric , errorbar = ("se" , 1 ), capsize = .1 , hue = "algorithm" , ax = ax , join = "False" , palette = opt_colorscheme )
228
247
for line in ax .lines :
229
248
line .set_markersize (3. )
230
249
line .set_linewidth (1. )
231
- ax .set_xticks (np .arange (0 , full_size_df ["step" ].max ()+ 1 , tick_every_batch * batch_size ))
250
+ ax .set_xticks (np .arange (0 , filtered_results_df ["step" ].max ()+ 1 , tick_every_batch * batch_size ))
232
251
ax .tick_params (axis = "x" , labelsize = 14 , rotation = 45 )
233
252
ax .tick_params (axis = "y" , labelsize = 14 )
234
253
plt .xlabel ("steps" , fontsize = 16 )
@@ -239,8 +258,8 @@ def optimization_line_figure(df: pd.DataFrame, metric: str, n_steps, title: str=
239
258
plt .legend (updated_legend .values (), updated_legend .keys ())
240
259
plt .subplots_adjust (top = 0.99 , right = 0.972 , left = 0.17 , bottom = 0.25 )
241
260
figure_path = Path (__file__ ).parent .parent .resolve () / "results" / "figures" / "rfp"
242
- plt .savefig (f"{ figure_path } /OPT_experiment_{ metric .lower ()} _{ title .split ()[0 ]} _batch{ batch_size } .png" )
243
- plt .savefig (f"{ figure_path } /OPT_experiment_{ metric .lower ()} _{ title .split ()[0 ]} _batch{ batch_size } .pdf" )
261
+ plt .savefig (f"{ figure_path } /OPT_experiment_{ metric .lower ()} _{ title .split ()[0 ]} _batch{ batch_size } _seeds { min_number_seeds } .png" )
262
+ plt .savefig (f"{ figure_path } /OPT_experiment_{ metric .lower ()} _{ title .split ()[0 ]} _batch{ batch_size } _seeds { min_number_seeds } .pdf" )
244
263
plt .show ()
245
264
246
265
@@ -358,6 +377,7 @@ def load_viz_rfp_experiments(exp_name: str="rfp_foldx_stability_and_sasa",
358
377
strict = True ,
359
378
n_steps : int = 180 ,
360
379
pareto_fig = False ,
380
+ metric_names : List [str ]= METRIC_DICT .keys ()
361
381
):
362
382
experiment_combinations = product (seeds , algorithms , starting_n , batch_size )
363
383
mlf_client = mlflow .tracking .MlflowClient (tracking_uri = TRACKING_URI )
@@ -367,7 +387,7 @@ def load_viz_rfp_experiments(exp_name: str="rfp_foldx_stability_and_sasa",
367
387
if finished_only :
368
388
runs = [r for r in runs if r .info .status == "FINISHED" ]
369
389
run_results = filter_run_results (experiment_combinations , runs )
370
- metric_dict = get_algo_metric_history_from_run (mlf_client , run_results , algorithms = algorithms , seeds = seeds , batch_sizes = batch_size , starting_n = starting_n )
390
+ metric_dict = get_algo_metric_history_from_run (mlf_client , run_results , algorithms = algorithms , seeds = seeds , batch_sizes = batch_size , starting_n = starting_n , metric_names = metric_names )
371
391
experiment_results_df = pd .concat ({k : pd .DataFrame .from_dict (v , 'index' ) for k ,v in metric_dict .items ()}, axis = 0 )
372
392
experiment_results_df = experiment_results_df .reset_index ().rename (columns = {"level_0" : "algorithm" , "level_1" : "seed" })
373
393
experiment_combinations = product (seeds , algorithms , starting_n , batch_size )
@@ -388,7 +408,7 @@ def load_viz_rfp_experiments(exp_name: str="rfp_foldx_stability_and_sasa",
388
408
for metric in METRIC_DICT .keys ():
389
409
if exp_name != "foldx_rfp_lambo" :
390
410
optimization_line_figure (cold_experiments [["algorithm" , "seed" , "starting_N" , metric ]], metric = metric , title = "cold HV optimization N=6" , strict = strict , n_steps = n_steps )
391
- optimization_line_figure (warm_experiments [["algorithm" , "seed" , "starting_N" , metric ]], metric = metric , title = "warm HV optimization N=50" , strict = strict , n_steps = None )
411
+ optimization_line_figure (warm_experiments [["algorithm" , "seed" , "starting_N" , metric ]], metric = metric , title = "warm HV optimization N=50" , strict = strict , n_steps = n_steps * 2 )
392
412
else :
393
413
optimization_line_figure (ref_experiments [["algorithm" , "seed" , "starting_N" , metric ]], metric = metric , title = "ref. HV optimization N=512" , strict = strict , n_steps = n_steps )
394
414
if pareto_fig :
@@ -486,10 +506,11 @@ def load_viz_gfp_experiments(
486
506
if __name__ == "__main__" :
487
507
## LOAD AND VISUALIZE RFP EXPERIMENTS
488
508
# RFP base experiments
489
- # load_viz_rfp_experiments(pareto_fig=False )
509
+ load_viz_rfp_experiments (pareto_fig = True )
490
510
# ## LOAD AND VISUALIZE GFP EXPERIMENTS
491
511
load_viz_gfp_experiments ()
492
- # # RFP reference experiments # SUPPLEMENTARY TODO
493
- # load_viz_rfp_experiments(exp_name="foldx_rfp_lambo", starting_n=["512"], finished_only=False)
512
+ # RFP reference experiments
513
+ load_viz_rfp_experiments (exp_name = "foldx_rfp_lambo" , starting_n = ["512" ],
514
+ metric_names = list (METRIC_DICT .keys ()) + [LAMBO_REL_HYPER_VOLUME ], finished_only = False )
494
515
495
516
0 commit comments