Skip to content

Commit bbf915a

Browse files
committed
Fix results
1 parent aa12cf2 commit bbf915a

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

delphi/log/result_analysis.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,17 @@ def get_agg_metrics(
225225
return pd.DataFrame(processed_rows)
226226

227227

228+
def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame:
229+
f1s = (
230+
latent_df.groupby(["module", "latent_idx"])
231+
.apply(
232+
lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"]
233+
)
234+
.reset_index(name="f1_score") # <- naive (un-weighted) F1
235+
)
236+
return latent_df.merge(f1s, on=["module", "latent_idx"])
237+
238+
228239
def log_results(
229240
scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str]
230241
):

0 commit comments

Comments
 (0)