validate_shape_cluster.py
import json
import re
import warnings

import numpy as np
import pandas as pd
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
from statsmodels.tsa.api import STL

# Suppress specific warnings from statsmodels (frequency inference) and
# scipy/numpy (constant-input correlation).
warnings.filterwarnings("ignore", message="X does not have valid frequency.")
warnings.filterwarnings("ignore", message="An input array is constant; the correlation coefficient is not defined.")


def euclidean_1d(x, y):
    """A wrapper for scipy's euclidean distance that ensures 1-D input."""
    return euclidean(np.ravel(x), np.ravel(y))


def _get_shape_cluster_questions():
    """Generator to yield shape-cluster-inductive questions and their QIDs from the dataset."""
    dataset_path = "TSReasoningDatasets/datasets/ChatTS/evaluation/dataset_b.json"
    with open(dataset_path, 'r') as f:
        data = json.load(f)
    for i, q in enumerate(data):
        if "shape-cluster-inductive" in q.get('ability_types', []):
            yield i, q


def extract_shape_cluster_ground_truth(attributes: list) -> list | None:
    """Extracts the ground-truth cluster (a list of column names) for a shape-cluster-inductive question."""
    for attr in attributes:
        # The ground truth is often in an attribute with a 'cols' key.
        if isinstance(attr, dict) and 'cols' in attr and len(attr['cols']) > 1:
            return attr['cols']
    return None


def _get_source_metric_from_question(question_text: str, all_cols: list[str]) -> str | None:
    """Extracts the source metric named in the question text."""
    # Common patterns: "similar with <metric>", "similar to <metric>",
    # and "similar trend characteristics with <metric>".
    match = re.search(r"similar (?:trend characteristics with|with|to) (.*?)\?", question_text, re.IGNORECASE)
    if not match:
        return None
    potential_metric_name = match.group(1).strip()
    # Return the first column whose name contains the captured metric name.
    for col in all_cols:
        if potential_metric_name in col:
            return col
    return None
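
# Illustrative sketch (not part of the original script): how the extraction
# above is expected to behave on a typical question string. The metric names
# below are hypothetical examples, not values taken from the dataset.
#
#   _get_source_metric_from_question(
#       "Which metrics show similar trend characteristics with metric_0?",
#       ["metric_0", "metric_1", "metric_2"],
#   )
#   # -> "metric_0": the regex captures "metric_0" and the loop returns the
#   #    first column name containing the captured text.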


def find_period(signal, min_period=10, amplitude_threshold=0.1):
    """Finds the dominant period in a signal using autocorrelation."""
    if np.std(signal) < 1e-6:  # Constant signal has no meaningful period
        return None
    acf = np.correlate(signal - np.mean(signal), signal - np.mean(signal), 'full')[-len(signal):]
    # Find peaks in the autocorrelation via sign changes in its first difference
    inflection = np.diff(np.sign(np.diff(acf)))
    peaks = (inflection < 0).nonzero()[0] + 1
    if len(peaks) == 0:
        return None
    max_acf_value = acf[peaks].max()
    # Keep peaks that are strong enough and occur at a lag of at least min_period
    valid_peaks = [p for p in peaks if acf[p] >= amplitude_threshold * max_acf_value]
    valid_peaks = [p for p in valid_peaks if p >= min_period]
    if len(valid_peaks) == 0:
        return None
    return valid_peaks[np.argmax(acf[valid_peaks])]
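
# Illustrative sketch (not part of the original script): on a synthetic signal
# with a known period, find_period should recover a lag close to that period.
# The numbers here are hypothetical and only show the expected behaviour.
#
#   t = np.arange(300)
#   synthetic = np.sin(2 * np.pi * t / 24)  # period of 24 samples
#   find_period(synthetic)                  # expected to return roughly 24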


def compute_shape_clusters(timeseries: list, all_cols: list, source_col: str) -> list:
    """
    Computes clusters of time series based on trend similarity using STL decomposition and Pearson correlation.
    """
    if source_col not in all_cols:
        return []
    source_idx = all_cols.index(source_col)

    # Align time series lengths
    min_len = min(len(ts) for ts in timeseries)
    aligned_series = [np.array(ts)[:min_len] for ts in timeseries]

    # Extract trends using STL decomposition
    trends = []
    for ts in aligned_series:
        # If the signal is constant, the trend is the signal itself
        if np.std(ts) < 1e-6:
            trends.append(ts)
            continue
        period = find_period(ts)
        # Choose a robust default period if none is found or if it's invalid
        if period is not None and 1 < period < len(ts) // 2:
            stl_period = period
        else:
            stl_period = max(2, min(20, len(ts) // 2))
        if len(ts) < 2 * stl_period:  # STL requires len(ts) >= 2 * period
            stl_period = len(ts) // 2
        # Fallback for very short series where decomposition is not possible
        if stl_period < 2:
            trends.append(ts)  # Cannot decompose, use original series as trend
            continue
        stl = STL(ts, period=stl_period, robust=True)
        res = stl.fit()
        trends.append(res.trend)

    source_trend = trends[source_idx]
    similar_series = [source_col]
    for i, trend in enumerate(trends):
        if i == source_idx:
            continue
        # Handle constant trends, which can cause issues with corrcoef
        if np.std(source_trend) < 1e-6 or np.std(trend) < 1e-6:
            # If both are constant and close, they are similar; otherwise, not correlated.
            both_constant = np.std(source_trend) < 1e-6 and np.std(trend) < 1e-6
            corr = 1.0 if both_constant and np.abs(np.mean(source_trend) - np.mean(trend)) < 1e-6 else 0.0
        else:
            corr = float(np.corrcoef(source_trend, trend)[0, 1])
        # Use a high correlation threshold to define "similar shape"
        if corr > 0.8:
            similar_series.append(all_cols[i])
    return similar_series
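
# Illustrative sketch (not part of the original script): a minimal call to
# compute_shape_clusters with hypothetical column names and series. Series
# whose STL trends correlate with the source trend above 0.8 should be
# returned alongside the source column itself.
#
#   cols = ["metric_0", "metric_1", "metric_2"]
#   t = np.arange(200)
#   series = [
#       0.05 * t + np.sin(2 * np.pi * t / 20),   # rising trend (source)
#       0.05 * t + np.cos(2 * np.pi * t / 15),   # rising trend (similar shape)
#       -0.05 * t + np.sin(2 * np.pi * t / 20),  # falling trend (dissimilar)
#   ]
#   compute_shape_clusters(series, cols, "metric_0")
#   # expected to return something like ["metric_0", "metric_1"]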


def validate_shape_cluster_inductive():
    """
    Validates the 'shape-cluster-inductive' ability by comparing predicted
    clusters with ground truth using the F1 score.
    """
    f1_scores = []
    wrong_cases = []
    for qid, question in _get_shape_cluster_questions():
        ground_truth_cluster = extract_shape_cluster_ground_truth(question['attributes'])
        source_metric = _get_source_metric_from_question(question['question'], question['cols'])
        if not ground_truth_cluster or not source_metric:
            continue
        predicted_cluster = compute_shape_clusters(question['timeseries'], question['cols'], source_metric)

        # Calculate F1 score
        gt_set = set(ground_truth_cluster)
        pred_set = set(predicted_cluster)
        tp = len(gt_set & pred_set)
        fp = len(pred_set - gt_set)
        fn = len(gt_set - pred_set)
        denominator = 2 * tp + fp + fn
        if denominator > 0:
            f1_score = 2 * tp / denominator
        else:
            f1_score = 1.0 if not gt_set and not pred_set else 0.0
        f1_scores.append(f1_score)

        if f1_score < 1.0:
            wrong_cases.append({
                'qid': qid,
                'f1_score': f1_score,
                'predicted': sorted(pred_set),
                'ground_truth': sorted(gt_set),
                'source_metric': source_metric
            })

    print("\n--- Shape Cluster Inductive Validation Results ---")
    if f1_scores:
        print(f"Found {len(f1_scores)} questions with 'shape-cluster-inductive' ability.")
        print(f"Average F1 Score: {np.mean(f1_scores):.2%}")
    else:
        print("No questions with 'shape-cluster-inductive' ability found.")

    if wrong_cases:
        print("\n--- Cases with F1 Score < 1.0 ---")
        for case in wrong_cases:
            print(f" QID: {case['qid']}, F1: {case['f1_score']:.2f}, Source: '{case['source_metric']}'")
            print(f" - Predicted: {case['predicted']}")
            print(f" - Ground Truth: {case['ground_truth']}")


if __name__ == "__main__":
    validate_shape_cluster_inductive()
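
# Usage sketch (assuming the dataset path hard-coded above exists relative to
# the current working directory):
#
#   python validate_shape_cluster.py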