Skip to content

Commit 7816cb1

Browse files
committed
fix: better styling for plots
1 parent cf856f8 commit 7816cb1

File tree

1 file changed

+116
-58
lines changed

1 file changed

+116
-58
lines changed

mallm/evaluation/plotting/plots.py

Lines changed: 116 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from tqdm import tqdm
1111

1212
# Set the style for beautiful plots
13-
plt.style.use('seaborn-v0_8-whitegrid')
14-
sns.set_palette("husl")
13+
plt.style.use('seaborn-v0_8-pastel')
14+
sns.set_palette("pastel")
1515

1616
# Define a beautiful color palette
1717
COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8']
@@ -24,6 +24,22 @@ def get_colors(n_colors):
2424
# Use a colormap for more colors
2525
return plt.cm.Set3(np.linspace(0, 1, n_colors))
2626

27+
28+
def get_consistent_color_mapping(options):
    """Map each distinct option name to a color.

    The distinct options are sorted before colors are assigned, so the
    mapping depends only on the *set* of options passed in, not on their
    order or on duplicates.

    Args:
        options: Iterable of option names (must be hashable and mutually
            sortable).

    Returns:
        dict mapping each distinct option to a color. Colors are hex
        strings taken from ``COLORS`` when there are few enough options,
        otherwise RGBA rows sampled from a matplotlib colormap.
    """
    # Sort the de-duplicated options so assignment is order-independent.
    unique_options = sorted(set(options))
    n_options = len(unique_options)

    if n_options <= len(COLORS):
        colors = COLORS[:n_options]
    elif n_options <= plt.cm.Set3.N:
        colors = plt.cm.Set3(np.linspace(0, 1, n_options))
    else:
        # Set3 is a qualitative colormap with a fixed number of entries
        # (Set3.N); sampling more points than entries would hand the same
        # color to several options. Fall back to a continuous colormap.
        # endpoint=False because hsv is cyclic: 0 and 1 are both red.
        colors = plt.cm.hsv(np.linspace(0, 1, n_options, endpoint=False))

    return dict(zip(unique_options, colors))
41+
42+
2743
def process_eval_file(file_path: str) -> pd.DataFrame:
    """Load a JSON evaluation file into a DataFrame.

    Args:
        file_path: Path to a JSON file whose top-level value is accepted
            by the ``pd.DataFrame`` constructor.

    Returns:
        A DataFrame built from the parsed JSON content.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Read raw bytes and let json.loads detect the encoding (UTF-8/16/32,
    # with or without BOM) instead of decoding with the locale default,
    # which breaks on non-ASCII content on non-UTF-8 systems.
    raw = Path(file_path).read_bytes()
    return pd.DataFrame(json.loads(raw))
@@ -77,7 +93,7 @@ def aggregate_data(
7793
return eval_df, stats_df
7894

7995

80-
def plot_turns_with_std(df: pd.DataFrame, input_path: str) -> None:
96+
def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
8197
"""Create a beautiful violin plot for turns distribution"""
8298
# Filter out rows with missing or invalid turns data
8399
df = df.dropna(subset=['turns'])
@@ -87,42 +103,68 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str) -> None:
87103
print("Warning: No valid turns data found. Skipping turns plot.")
88104
return
89105

90-
# Create combination labels for better grouping
91-
df['condition'] = df['option'] + '_' + df['dataset']
92-
unique_labels = get_unique_labels_from_conditions(df['condition'].unique())
106+
# Create grouped data like other plots for consistent color assignment
107+
grouped_data = df.groupby(['option', 'dataset']).agg({
108+
'turns': list # Keep all turns values for violin plot
109+
}).reset_index()
110+
111+
# Create unique labels like other plots
112+
unique_labels = get_unique_labels(grouped_data)
113+
grouped_data['label'] = unique_labels
114+
115+
# Use global color mapping if provided, otherwise create local one
116+
if global_color_mapping is None:
117+
color_mapping = get_consistent_color_mapping(grouped_data['option'].unique())
118+
else:
119+
color_mapping = global_color_mapping
120+
121+
# Create color palette based on option order in grouped data
122+
colors = [color_mapping[option] for option in grouped_data['option']]
93123

94-
# Create a mapping from full condition to unique label
95-
condition_to_label = dict(zip(df['condition'].unique(), unique_labels))
96-
df['condition_label'] = df['condition'].map(condition_to_label)
124+
# Expand the grouped data back to individual rows for violin plot
125+
expanded_data = []
126+
for i, row in grouped_data.iterrows():
127+
for turn_value in row['turns']:
128+
expanded_data.append({
129+
'option': row['option'],
130+
'dataset': row['dataset'],
131+
'label': row['label'],
132+
'turns': turn_value
133+
})
134+
135+
plot_df = pd.DataFrame(expanded_data)
97136

98137
plt.figure(figsize=(10, 4))
99138

100-
# Create violin plot with individual points
101-
ax = sns.violinplot(data=df, x='condition_label', y='turns',
102-
hue='condition_label', palette=get_colors(len(df['condition_label'].unique())),
103-
inner=None, alpha=0.7, legend=False)
139+
# Create violin plot with the same label order as other plots
140+
ax = sns.violinplot(data=plot_df, x='label', y='turns',
141+
order=grouped_data['label'], palette=colors,
142+
inner=None, legend=False)
104143

105144
# Add individual points with jitter
106-
sns.stripplot(data=df, x='condition_label', y='turns',
107-
color='white', size=6, alpha=0.8, edgecolor='black', linewidth=0.5)
145+
sns.stripplot(data=plot_df, x='label', y='turns',
146+
order=grouped_data['label'], color='white', size=6,
147+
edgecolor='black', linewidth=0.5)
108148

109-
# Add red diamond mean markers that align correctly with violin plots
110-
unique_conditions = df['condition_label'].unique()
149+
# Set all plot elements above grid
150+
for collection in ax.collections:
151+
collection.set_zorder(4)
111152

112-
for i, condition in enumerate(unique_conditions):
113-
mean_val = df[df['condition_label'] == condition]['turns'].mean()
153+
# Add red diamond mean markers that align correctly with violin plots
154+
for i, label in enumerate(grouped_data['label']):
155+
mean_val = plot_df[plot_df['label'] == label]['turns'].mean()
114156
# Use red diamond markers positioned correctly
115157
ax.plot(i, mean_val, marker='D', color='red', markersize=8,
116-
markeredgecolor='white', markeredgewidth=1, zorder=10)
158+
markeredgecolor='white', markeredgewidth=1, zorder=5)
117159

118160
# Styling
119161
ax.set_xlabel('') # Remove automatic seaborn x-axis label
120-
ax.set_ylabel('Number of Turns', fontsize=14, fontweight='bold')
162+
ax.set_ylabel('Number of Turns', fontsize=14)
121163

122164
# Rotate labels and improve spacing
123165
plt.xticks(rotation=45, ha='right', fontsize=14)
124166
plt.yticks(fontsize=14)
125-
plt.grid(True, alpha=0.3)
167+
ax.grid(True, alpha=0.3, zorder=0)
126168
# Add a subtle background
127169
ax.set_facecolor('#fafafa')
128170

@@ -133,7 +175,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str) -> None:
133175
plt.close()
134176

135177

136-
def plot_clock_seconds_with_std(df: pd.DataFrame, input_path: str) -> None:
178+
def plot_clock_seconds_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
137179
"""Create a beautiful horizontal lollipop chart for clock seconds"""
138180
grouped = (
139181
df.groupby(["option", "dataset"])["clockSeconds"]
@@ -161,30 +203,35 @@ def sort_key(row):
161203

162204
# Create discrete marker chart (no stems)
163205
y_pos = np.arange(len(grouped))
164-
colors = get_colors(len(grouped))
206+
207+
# Use global color mapping if provided, otherwise create local one
208+
if global_color_mapping is None:
209+
color_mapping = get_consistent_color_mapping(grouped['option'].unique())
210+
else:
211+
color_mapping = global_color_mapping
212+
colors = [color_mapping[option] for option in grouped['option']]
165213

166214
# Draw discrete circular markers only
167215
scatter = ax.scatter(grouped['mean'], y_pos,
168216
s=250, c=colors,
169-
alpha=0.9, edgecolors='white', linewidth=3, zorder=10)
217+
edgecolors='white', linewidth=3, zorder=5)
170218

171-
# Add subtle error bars
219+
# Add error bars
172220
ax.errorbar(grouped['mean'], y_pos, xerr=grouped['std'],
173-
fmt='none', color='gray', alpha=0.5, capsize=6, linewidth=2)
221+
fmt='none', color='gray', capsize=6, linewidth=2, zorder=4)
174222

175223
# Add value labels with better positioning to avoid circle overlap
176224
for i, (_, row) in enumerate(grouped.iterrows()):
177225
# Calculate offset to avoid overlap with circle (larger offset)
178226
offset = max(row['std'] + max(grouped['mean']) * 0.08, max(grouped['mean']) * 0.05)
179227
ax.text(row['mean'] + offset, i,
180228
f'{row["mean"]:.1f}s',
181-
va='center', ha='left', fontweight='bold', fontsize=14,
182-
bbox=dict(boxstyle='round,pad=0.15', facecolor='white', alpha=0.8, edgecolor='none'))
229+
va='center', ha='left', fontsize=14, zorder=6)
183230

184231
# Styling
185232
ax.set_yticks(y_pos)
186233
ax.set_yticklabels(grouped['label'], fontsize=14)
187-
ax.set_xlabel('Execution Time (seconds)', fontsize=14, fontweight='bold')
234+
ax.set_xlabel('Execution Time (seconds)', fontsize=14)
188235

189236
# Set x-axis limits with proper margins for labels
190237
max_val = max(grouped['mean'] + grouped['std'])
@@ -196,7 +243,7 @@ def sort_key(row):
196243
ax.spines['left'].set_color('#cccccc')
197244
ax.spines['bottom'].set_color('#cccccc')
198245
ax.tick_params(axis='x', labelsize=14)
199-
ax.grid(True, alpha=0.3, axis='x', linestyle='-', linewidth=0.5)
246+
ax.grid(True, alpha=0.3, axis='x', zorder=0)
200247
ax.set_facecolor('#fafafa')
201248

202249
plt.tight_layout()
@@ -205,7 +252,7 @@ def sort_key(row):
205252
plt.close()
206253

207254

208-
def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
255+
def plot_decision_success_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
209256
"""Create a beautiful horizontal bar chart for decision success rates"""
210257
if "decisionSuccess" not in df.columns:
211258
print(
@@ -234,18 +281,22 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
234281

235282
fig, ax = plt.subplots(figsize=(10, 3))
236283

237-
# Create gradient colors based on success rate
238-
colors = plt.cm.RdYlGn(grouped['mean'])
284+
# Use global color mapping if provided, otherwise create local one
285+
if global_color_mapping is None:
286+
color_mapping = get_consistent_color_mapping(grouped['option'].unique())
287+
else:
288+
color_mapping = global_color_mapping
289+
colors = [color_mapping[option] for option in grouped['option']]
239290

240291
# Create horizontal bars
241292
bars = ax.barh(range(len(grouped)), grouped['mean'],
242-
color=colors, alpha=0.8, height=0.6)
293+
color=colors, height=0.6, zorder=3)
243294

244295
# Add percentage labels on bars
245296
for i, (_, row) in enumerate(grouped.iterrows()):
246297
percentage = row['mean'] * 100
247298
ax.text(row['mean'] + 0.02, i, f'{percentage:.1f}%',
248-
va='center', ha='left', fontweight='bold', fontsize=14)
299+
va='center', ha='left', fontsize=14, zorder=6)
249300

250301
# Add a subtle pattern to bars
251302
for bar, rate in zip(bars, grouped['mean']):
@@ -255,7 +306,7 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
255306
# Styling
256307
ax.set_yticks(range(len(grouped)))
257308
ax.set_yticklabels(grouped['label'], fontsize=14)
258-
ax.set_xlabel('Decision Success Rate', fontsize=14, fontweight='bold')
309+
ax.set_xlabel('Decision Success Rate', fontsize=14)
259310
ax.set_xlim(0, 1.1)
260311

261312
# Add percentage ticks
@@ -265,7 +316,7 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str) -> None:
265316
# Remove spines and add grid
266317
ax.spines['top'].set_visible(False)
267318
ax.spines['right'].set_visible(False)
268-
ax.grid(True, alpha=0.3, axis='x')
319+
ax.grid(True, alpha=0.3, axis='x', zorder=0)
269320
ax.set_facecolor('#fafafa')
270321

271322
plt.tight_layout()
@@ -355,7 +406,7 @@ def get_unique_labels_from_conditions(conditions) -> list[str]:
355406
return unique_labels
356407

357408

358-
def plot_score_distributions_with_std(df: pd.DataFrame, input_path: str) -> None:
409+
def plot_score_distributions_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
359410
"""Create beautiful enhanced bar charts for score distributions"""
360411
print("Shape of stats_df:", df.shape)
361412
print("Columns in stats_df:", df.columns)
@@ -409,16 +460,21 @@ def sort_key(row):
409460
index=False,
410461
)
411462

412-
# Create beautiful bar plot with gradient colors
463+
# Create beautiful bar plot with consistent colors
413464
x = np.arange(len(score_data))
414-
colors = plt.cm.viridis(np.linspace(0, 1, len(score_data)))
465+
466+
# Use global color mapping if provided, otherwise create local one
467+
if global_color_mapping is None:
468+
color_mapping = get_consistent_color_mapping(score_data['option'].unique())
469+
else:
470+
color_mapping = global_color_mapping
471+
colors = [color_mapping[option] for option in score_data['option']]
415472

416473
bars = ax.bar(x, score_data["mean"],
417474
yerr=score_data["std"],
418475
capsize=8,
419-
color=colors, alpha=0.8,
420-
edgecolor='white', linewidth=2,
421-
width=0.6) # Slightly narrower bars for more discrete look
476+
color=colors,
477+
width=0.6, zorder=3) # Slightly narrower bars for more discrete look
422478

423479
# Calculate proper y-axis limits
424480
max_height = max(score_data["mean"] + score_data["std"])
@@ -430,17 +486,10 @@ def sort_key(row):
430486
height = mean_val + std_val
431487
ax.text(bar.get_x() + bar.get_width()/2., height + y_range * 0.05,
432488
f'{mean_val:.3f}', ha='center', va='bottom',
433-
fontweight='bold', fontsize=14,
434-
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8))
435-
436-
# Add gradient effect to bars
437-
gradient = np.linspace(0, 1, 256).reshape(256, -1)
438-
ax.imshow(gradient, extent=[bar.get_x(), bar.get_x() + bar.get_width(),
439-
0, bar.get_height()],
440-
aspect='auto', alpha=0.3, cmap='viridis')
489+
fontsize=14, zorder=6)
441490

442491
# Styling
443-
ax.set_ylabel('Average Score', fontsize=14, fontweight='bold')
492+
ax.set_ylabel('Average Score', fontsize=14)
444493

445494
# Set x-axis with proper spacing and labels
446495
ax.set_xticks(x)
@@ -452,7 +501,7 @@ def sort_key(row):
452501

453502
# Add grid and styling
454503
ax.tick_params(axis='y', labelsize=14)
455-
ax.grid(True, alpha=0.3, axis='y')
504+
ax.grid(True, alpha=0.3, axis='y', zorder=0)
456505
ax.set_facecolor('#fafafa')
457506
ax.spines['top'].set_visible(False)
458507
ax.spines['right'].set_visible(False)
@@ -481,22 +530,31 @@ def create_plots_for_path(input_dir_path: str, output_dir_path: str) -> None:
481530
print("First few rows of eval_df:")
482531
print(eval_df.head())
483532

533+
# Create global color mapping for all options across all plots
534+
all_options = set()
535+
if not eval_df.empty and 'option' in eval_df.columns:
536+
all_options.update(eval_df['option'].unique())
537+
if not stats_df.empty and 'option' in stats_df.columns:
538+
all_options.update(stats_df['option'].unique())
539+
540+
global_color_mapping = get_consistent_color_mapping(list(all_options))
541+
484542
available_columns = eval_df.columns
485543

486544
if "turns" in available_columns:
487-
plot_turns_with_std(eval_df, output_dir_path)
545+
plot_turns_with_std(eval_df, output_dir_path, global_color_mapping)
488546
else:
489547
print("Warning: 'turns' column not found. Skipping turns plot.")
490548

491549
if "clockSeconds" in available_columns:
492-
plot_clock_seconds_with_std(eval_df, output_dir_path)
550+
plot_clock_seconds_with_std(eval_df, output_dir_path, global_color_mapping)
493551
else:
494552
print("Warning: 'clockSeconds' column not found. Skipping clock seconds plot.")
495553

496-
plot_decision_success_with_std(eval_df, output_dir_path)
554+
plot_decision_success_with_std(eval_df, output_dir_path, global_color_mapping)
497555

498556
if not stats_df.empty:
499-
plot_score_distributions_with_std(stats_df, output_dir_path)
557+
plot_score_distributions_with_std(stats_df, output_dir_path, global_color_mapping)
500558
else:
501559
print("Warning: No stats data available. Skipping score distributions plot.")
502560

0 commit comments

Comments
 (0)