1010from tqdm import tqdm
1111
1212# Set the style for beautiful plots
13- plt .style .use ('seaborn-v0_8-whitegrid ' )
14- sns .set_palette ("husl " )
13+ plt .style .use ('seaborn-v0_8-pastel ' )
14+ sns .set_palette ("pastel " )
1515
1616# Define a beautiful color palette
1717COLORS = ['#FF6B6B' , '#4ECDC4' , '#45B7D1' , '#96CEB4' , '#FFEAA7' , '#DDA0DD' , '#98D8C8' ]
@@ -24,6 +24,22 @@ def get_colors(n_colors):
2424 # Use a colormap for more colors
2525 return plt .cm .Set3 (np .linspace (0 , 1 , n_colors ))
2626
27+
28+ def get_consistent_color_mapping (options ):
29+ """Create consistent color mapping based on option names"""
30+ # Sort options to ensure consistent assignment
31+ unique_options = sorted (set (options ))
32+
33+ # Generate enough colors
34+ if len (unique_options ) <= len (COLORS ):
35+ colors = COLORS [:len (unique_options )]
36+ else :
37+ colors = plt .cm .Set3 (np .linspace (0 , 1 , len (unique_options )))
38+
39+ # Create mapping
40+ return dict (zip (unique_options , colors ))
41+
42+
2743def process_eval_file (file_path : str ) -> pd .DataFrame :
2844 data = json .loads (Path (file_path ).read_text ())
2945 return pd .DataFrame (data )
@@ -77,7 +93,7 @@ def aggregate_data(
7793 return eval_df , stats_df
7894
7995
80- def plot_turns_with_std (df : pd .DataFrame , input_path : str ) -> None :
96+ def plot_turns_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
8197 """Create a beautiful violin plot for turns distribution"""
8298 # Filter out rows with missing or invalid turns data
8399 df = df .dropna (subset = ['turns' ])
@@ -87,42 +103,68 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str) -> None:
87103 print ("Warning: No valid turns data found. Skipping turns plot." )
88104 return
89105
90- # Create combination labels for better grouping
91- df ['condition' ] = df ['option' ] + '_' + df ['dataset' ]
92- unique_labels = get_unique_labels_from_conditions (df ['condition' ].unique ())
106+ # Create grouped data like other plots for consistent color assignment
107+ grouped_data = df .groupby (['option' , 'dataset' ]).agg ({
108+ 'turns' : list # Keep all turns values for violin plot
109+ }).reset_index ()
110+
111+ # Create unique labels like other plots
112+ unique_labels = get_unique_labels (grouped_data )
113+ grouped_data ['label' ] = unique_labels
114+
115+ # Use global color mapping if provided, otherwise create local one
116+ if global_color_mapping is None :
117+ color_mapping = get_consistent_color_mapping (grouped_data ['option' ].unique ())
118+ else :
119+ color_mapping = global_color_mapping
120+
121+ # Create color palette based on option order in grouped data
122+ colors = [color_mapping [option ] for option in grouped_data ['option' ]]
93123
94- # Create a mapping from full condition to unique label
95- condition_to_label = dict (zip (df ['condition' ].unique (), unique_labels ))
96- df ['condition_label' ] = df ['condition' ].map (condition_to_label )
124+ # Expand the grouped data back to individual rows for violin plot
125+ expanded_data = []
126+ for i , row in grouped_data .iterrows ():
127+ for turn_value in row ['turns' ]:
128+ expanded_data .append ({
129+ 'option' : row ['option' ],
130+ 'dataset' : row ['dataset' ],
131+ 'label' : row ['label' ],
132+ 'turns' : turn_value
133+ })
134+
135+ plot_df = pd .DataFrame (expanded_data )
97136
98137 plt .figure (figsize = (10 , 4 ))
99138
100- # Create violin plot with individual points
101- ax = sns .violinplot (data = df , x = 'condition_label ' , y = 'turns' ,
102- hue = 'condition_label' , palette = get_colors ( len ( df [ 'condition_label' ]. unique ())) ,
103- inner = None , alpha = 0.7 , legend = False )
139+ # Create violin plot with the same label order as other plots
140+ ax = sns .violinplot (data = plot_df , x = 'label ' , y = 'turns' ,
141+ order = grouped_data [ 'label' ] , palette = colors ,
142+ inner = None , legend = False )
104143
105144 # Add individual points with jitter
106- sns .stripplot (data = df , x = 'condition_label' , y = 'turns' ,
107- color = 'white' , size = 6 , alpha = 0.8 , edgecolor = 'black' , linewidth = 0.5 )
145+ sns .stripplot (data = plot_df , x = 'label' , y = 'turns' ,
146+ order = grouped_data ['label' ], color = 'white' , size = 6 ,
147+ edgecolor = 'black' , linewidth = 0.5 )
108148
109- # Add red diamond mean markers that align correctly with violin plots
110- unique_conditions = df ['condition_label' ].unique ()
149+ # Set all plot elements above grid
150+ for collection in ax .collections :
151+ collection .set_zorder (4 )
111152
112- for i , condition in enumerate (unique_conditions ):
113- mean_val = df [df ['condition_label' ] == condition ]['turns' ].mean ()
153+ # Add red diamond mean markers that align correctly with violin plots
154+ for i , label in enumerate (grouped_data ['label' ]):
155+ mean_val = plot_df [plot_df ['label' ] == label ]['turns' ].mean ()
114156 # Use red diamond markers positioned correctly
115157 ax .plot (i , mean_val , marker = 'D' , color = 'red' , markersize = 8 ,
116- markeredgecolor = 'white' , markeredgewidth = 1 , zorder = 10 )
158+ markeredgecolor = 'white' , markeredgewidth = 1 , zorder = 5 )
117159
118160 # Styling
119161 ax .set_xlabel ('' ) # Remove automatic seaborn x-axis label
120- ax .set_ylabel ('Number of Turns' , fontsize = 14 , fontweight = 'bold' )
162+ ax .set_ylabel ('Number of Turns' , fontsize = 14 )
121163
122164 # Rotate labels and improve spacing
123165 plt .xticks (rotation = 45 , ha = 'right' , fontsize = 14 )
124166 plt .yticks (fontsize = 14 )
125- plt .grid (True , alpha = 0.3 )
167+ ax .grid (True , alpha = 0.3 , zorder = 0 )
126168 # Add a subtle background
127169 ax .set_facecolor ('#fafafa' )
128170
@@ -133,7 +175,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str) -> None:
133175 plt .close ()
134176
135177
136- def plot_clock_seconds_with_std (df : pd .DataFrame , input_path : str ) -> None :
178+ def plot_clock_seconds_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
137179 """Create a beautiful horizontal lollipop chart for clock seconds"""
138180 grouped = (
139181 df .groupby (["option" , "dataset" ])["clockSeconds" ]
@@ -161,30 +203,35 @@ def sort_key(row):
161203
162204 # Create discrete marker chart (no stems)
163205 y_pos = np .arange (len (grouped ))
164- colors = get_colors (len (grouped ))
206+
207+ # Use global color mapping if provided, otherwise create local one
208+ if global_color_mapping is None :
209+ color_mapping = get_consistent_color_mapping (grouped ['option' ].unique ())
210+ else :
211+ color_mapping = global_color_mapping
212+ colors = [color_mapping [option ] for option in grouped ['option' ]]
165213
166214 # Draw discrete circular markers only
167215 scatter = ax .scatter (grouped ['mean' ], y_pos ,
168216 s = 250 , c = colors ,
169- alpha = 0.9 , edgecolors = 'white' , linewidth = 3 , zorder = 10 )
217+ edgecolors = 'white' , linewidth = 3 , zorder = 5 )
170218
171- # Add subtle error bars
219+ # Add error bars
172220 ax .errorbar (grouped ['mean' ], y_pos , xerr = grouped ['std' ],
173- fmt = 'none' , color = 'gray' , alpha = 0.5 , capsize = 6 , linewidth = 2 )
221+ fmt = 'none' , color = 'gray' , capsize = 6 , linewidth = 2 , zorder = 4 )
174222
175223 # Add value labels with better positioning to avoid circle overlap
176224 for i , (_ , row ) in enumerate (grouped .iterrows ()):
177225 # Calculate offset to avoid overlap with circle (larger offset)
178226 offset = max (row ['std' ] + max (grouped ['mean' ]) * 0.08 , max (grouped ['mean' ]) * 0.05 )
179227 ax .text (row ['mean' ] + offset , i ,
180228 f'{ row ["mean" ]:.1f} s' ,
181- va = 'center' , ha = 'left' , fontweight = 'bold' , fontsize = 14 ,
182- bbox = dict (boxstyle = 'round,pad=0.15' , facecolor = 'white' , alpha = 0.8 , edgecolor = 'none' ))
229+ va = 'center' , ha = 'left' , fontsize = 14 , zorder = 6 )
183230
184231 # Styling
185232 ax .set_yticks (y_pos )
186233 ax .set_yticklabels (grouped ['label' ], fontsize = 14 )
187- ax .set_xlabel ('Execution Time (seconds)' , fontsize = 14 , fontweight = 'bold' )
234+ ax .set_xlabel ('Execution Time (seconds)' , fontsize = 14 )
188235
189236 # Set x-axis limits with proper margins for labels
190237 max_val = max (grouped ['mean' ] + grouped ['std' ])
@@ -196,7 +243,7 @@ def sort_key(row):
196243 ax .spines ['left' ].set_color ('#cccccc' )
197244 ax .spines ['bottom' ].set_color ('#cccccc' )
198245 ax .tick_params (axis = 'x' , labelsize = 14 )
199- ax .grid (True , alpha = 0.3 , axis = 'x' , linestyle = '-' , linewidth = 0.5 )
246+ ax .grid (True , alpha = 0.3 , axis = 'x' , zorder = 0 )
200247 ax .set_facecolor ('#fafafa' )
201248
202249 plt .tight_layout ()
@@ -205,7 +252,7 @@ def sort_key(row):
205252 plt .close ()
206253
207254
208- def plot_decision_success_with_std (df : pd .DataFrame , input_path : str ) -> None :
255+ def plot_decision_success_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
209256 """Create a beautiful horizontal bar chart for decision success rates"""
210257 if "decisionSuccess" not in df .columns :
211258 print (
@@ -234,18 +281,22 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
234281
235282 fig , ax = plt .subplots (figsize = (10 , 3 ))
236283
237- # Create gradient colors based on success rate
238- colors = plt .cm .RdYlGn (grouped ['mean' ])
284+ # Use global color mapping if provided, otherwise create local one
285+ if global_color_mapping is None :
286+ color_mapping = get_consistent_color_mapping (grouped ['option' ].unique ())
287+ else :
288+ color_mapping = global_color_mapping
289+ colors = [color_mapping [option ] for option in grouped ['option' ]]
239290
240291 # Create horizontal bars
241292 bars = ax .barh (range (len (grouped )), grouped ['mean' ],
242- color = colors , alpha = 0.8 , height = 0.6 )
293+ color = colors , height = 0.6 , zorder = 3 )
243294
244295 # Add percentage labels on bars
245296 for i , (_ , row ) in enumerate (grouped .iterrows ()):
246297 percentage = row ['mean' ] * 100
247298 ax .text (row ['mean' ] + 0.02 , i , f'{ percentage :.1f} %' ,
248- va = 'center' , ha = 'left' , fontweight = 'bold' , fontsize = 14 )
299+ va = 'center' , ha = 'left' , fontsize = 14 , zorder = 6 )
249300
250301 # Add a subtle pattern to bars
251302 for bar , rate in zip (bars , grouped ['mean' ]):
@@ -255,7 +306,7 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
255306 # Styling
256307 ax .set_yticks (range (len (grouped )))
257308 ax .set_yticklabels (grouped ['label' ], fontsize = 14 )
258- ax .set_xlabel ('Decision Success Rate' , fontsize = 14 , fontweight = 'bold' )
309+ ax .set_xlabel ('Decision Success Rate' , fontsize = 14 )
259310 ax .set_xlim (0 , 1.1 )
260311
261312 # Add percentage ticks
@@ -265,7 +316,7 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
265316 # Remove spines and add grid
266317 ax .spines ['top' ].set_visible (False )
267318 ax .spines ['right' ].set_visible (False )
268- ax .grid (True , alpha = 0.3 , axis = 'x' )
319+ ax .grid (True , alpha = 0.3 , axis = 'x' , zorder = 0 )
269320 ax .set_facecolor ('#fafafa' )
270321
271322 plt .tight_layout ()
@@ -355,7 +406,7 @@ def get_unique_labels_from_conditions(conditions) -> list[str]:
355406 return unique_labels
356407
357408
358- def plot_score_distributions_with_std (df : pd .DataFrame , input_path : str ) -> None :
409+ def plot_score_distributions_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
359410 """Create beautiful enhanced bar charts for score distributions"""
360411 print ("Shape of stats_df:" , df .shape )
361412 print ("Columns in stats_df:" , df .columns )
@@ -409,16 +460,21 @@ def sort_key(row):
409460 index = False ,
410461 )
411462
412- # Create beautiful bar plot with gradient colors
463+ # Create beautiful bar plot with consistent colors
413464 x = np .arange (len (score_data ))
414- colors = plt .cm .viridis (np .linspace (0 , 1 , len (score_data )))
465+
466+ # Use global color mapping if provided, otherwise create local one
467+ if global_color_mapping is None :
468+ color_mapping = get_consistent_color_mapping (score_data ['option' ].unique ())
469+ else :
470+ color_mapping = global_color_mapping
471+ colors = [color_mapping [option ] for option in score_data ['option' ]]
415472
416473 bars = ax .bar (x , score_data ["mean" ],
417474 yerr = score_data ["std" ],
418475 capsize = 8 ,
419- color = colors , alpha = 0.8 ,
420- edgecolor = 'white' , linewidth = 2 ,
421- width = 0.6 ) # Slightly narrower bars for more discrete look
476+ color = colors ,
477+ width = 0.6 , zorder = 3 ) # Slightly narrower bars for more discrete look
422478
423479 # Calculate proper y-axis limits
424480 max_height = max (score_data ["mean" ] + score_data ["std" ])
@@ -430,17 +486,10 @@ def sort_key(row):
430486 height = mean_val + std_val
431487 ax .text (bar .get_x () + bar .get_width ()/ 2. , height + y_range * 0.05 ,
432488 f'{ mean_val :.3f} ' , ha = 'center' , va = 'bottom' ,
433- fontweight = 'bold' , fontsize = 14 ,
434- bbox = dict (boxstyle = 'round,pad=0.2' , facecolor = 'white' , alpha = 0.8 ))
435-
436- # Add gradient effect to bars
437- gradient = np .linspace (0 , 1 , 256 ).reshape (256 , - 1 )
438- ax .imshow (gradient , extent = [bar .get_x (), bar .get_x () + bar .get_width (),
439- 0 , bar .get_height ()],
440- aspect = 'auto' , alpha = 0.3 , cmap = 'viridis' )
489+ fontsize = 14 , zorder = 6 )
441490
442491 # Styling
443- ax .set_ylabel ('Average Score' , fontsize = 14 , fontweight = 'bold' )
492+ ax .set_ylabel ('Average Score' , fontsize = 14 )
444493
445494 # Set x-axis with proper spacing and labels
446495 ax .set_xticks (x )
@@ -452,7 +501,7 @@ def sort_key(row):
452501
453502 # Add grid and styling
454503 ax .tick_params (axis = 'y' , labelsize = 14 )
455- ax .grid (True , alpha = 0.3 , axis = 'y' )
504+ ax .grid (True , alpha = 0.3 , axis = 'y' , zorder = 0 )
456505 ax .set_facecolor ('#fafafa' )
457506 ax .spines ['top' ].set_visible (False )
458507 ax .spines ['right' ].set_visible (False )
@@ -481,22 +530,31 @@ def create_plots_for_path(input_dir_path: str, output_dir_path: str) -> None:
481530 print ("First few rows of eval_df:" )
482531 print (eval_df .head ())
483532
533+ # Create global color mapping for all options across all plots
534+ all_options = set ()
535+ if not eval_df .empty and 'option' in eval_df .columns :
536+ all_options .update (eval_df ['option' ].unique ())
537+ if not stats_df .empty and 'option' in stats_df .columns :
538+ all_options .update (stats_df ['option' ].unique ())
539+
540+ global_color_mapping = get_consistent_color_mapping (list (all_options ))
541+
484542 available_columns = eval_df .columns
485543
486544 if "turns" in available_columns :
487- plot_turns_with_std (eval_df , output_dir_path )
545+ plot_turns_with_std (eval_df , output_dir_path , global_color_mapping )
488546 else :
489547 print ("Warning: 'turns' column not found. Skipping turns plot." )
490548
491549 if "clockSeconds" in available_columns :
492- plot_clock_seconds_with_std (eval_df , output_dir_path )
550+ plot_clock_seconds_with_std (eval_df , output_dir_path , global_color_mapping )
493551 else :
494552 print ("Warning: 'clockSeconds' column not found. Skipping clock seconds plot." )
495553
496- plot_decision_success_with_std (eval_df , output_dir_path )
554+ plot_decision_success_with_std (eval_df , output_dir_path , global_color_mapping )
497555
498556 if not stats_df .empty :
499- plot_score_distributions_with_std (stats_df , output_dir_path )
557+ plot_score_distributions_with_std (stats_df , output_dir_path , global_color_mapping )
500558 else :
501559 print ("Warning: No stats data available. Skipping score distributions plot." )
502560
0 commit comments