From e78c2b17697bf3585507fba8ef1e59e9db1402f8 Mon Sep 17 00:00:00 2001
From: Mandlin Sarah
Date: Wed, 4 Sep 2024 03:21:32 -0700
Subject: [PATCH] Improve robustness and handling of missing values in
 post_process_pauserate.py

get_pause: treat a missing per-utterance score as a zero contribution to
the weighted sum, and say so explicitly with fillna(0).

get_rate: drop utterances whose source or target speech rate is missing
before computing the Spearman correlation. Both arrays are filtered with
one joint mask so that every remaining source value is still paired with
the target value of the same utterance.
---
Reviewer notes (this section is ignored when the patch is applied):
 - The function header keeps Python's ':'; the earlier revision of this
   patch replaced it with '{' and appended a closing '}', which is a
   SyntaxError in Python.
 - The earlier revision filtered the two rate arrays independently, which
   can leave them with different lengths or with misaligned pairs.
 - The post-image blob id on the 'index' line below is carried over from
   the earlier revision; regenerate the patch with git format-patch if
   the exact id matters (e.g. for a three-way merge).

 .../expressivity/evaluate/post_process_pauserate.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/seamless_communication/cli/expressivity/evaluate/post_process_pauserate.py b/src/seamless_communication/cli/expressivity/evaluate/post_process_pauserate.py
index f421e9b5..670959c7 100644
--- a/src/seamless_communication/cli/expressivity/evaluate/post_process_pauserate.py
+++ b/src/seamless_communication/cli/expressivity/evaluate/post_process_pauserate.py
@@ -26,12 +26,12 @@ def get_pause(pause_data_tsv: str) -> Dict[str, float]:
         "wmean_joint_score",
     ]:
         metrics[score_name] = (
-            utt_pause_align_data[f"{score_name}"] * pause_duration_weight
+            utt_pause_align_data[f"{score_name}"].fillna(0) * pause_duration_weight
         ).sum()
     return metrics
 
 
 def get_rate(target_speech_tsv: str, source_speech_tsv: str) -> float:
     speech_unit = "syllable"
 
     target_speech_df = pd.read_csv(
@@ -44,5 +44,12 @@ def get_rate(target_speech_tsv: str, source_speech_tsv: str) -> float:
     # using "syllable" speech unit for rate computation
     src_speech_rate = source_speech_df[f"speech_rate_{speech_unit}"].to_numpy()
     tgt_speech_rate = target_speech_df[f"speech_rate_{speech_unit}"].to_numpy()
+
+    # Drop utterances where either rate is missing. A single joint mask keeps
+    # the source/target pairs aligned, which the rank correlation requires.
+    valid = ~(pd.isnull(src_speech_rate) | pd.isnull(tgt_speech_rate))
+    src_speech_rate = src_speech_rate[valid]
+    tgt_speech_rate = tgt_speech_rate[valid]
+
     src_tgt_spearman = scipy.stats.spearmanr(src_speech_rate, tgt_speech_rate)
     return src_tgt_spearman.correlation  # type: ignore[no-any-return]