-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate_artifacts.py
More file actions
215 lines (177 loc) · 10.3 KB
/
generate_artifacts.py
File metadata and controls
215 lines (177 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Reconstructs the specified version of the HealthChat-11K dataset and generates
two review CSVs - one for the entire dataset and one for sycophancy exploration.
Note: This script uses a hardcoded version string to download the corresponding
master annotation file from the Hugging Face Hub. It then generates the final
dataset and all review artifacts for that specific version, saving them into
a dedicated, version-specific output directory.
"""
import json
import csv
import os
from collections import defaultdict
from datasets import load_dataset
from datetime import datetime
# --- Configuration ---
# Dataset version to rebuild. Supported values: "1.0.0", "2.0.0", "2.1.0".
#   1.0.0 - annotated with Gemini 1.5 Pro using the older (V5) taxonomy.
#   2.0.0 - annotated with Gemini 2.5 Pro using the newer (V6) taxonomy; this is
#           the version behind the EMNLP 2025 Findings paper
#           (https://arxiv.org/abs/2506.21532).
#   2.1.0 - latest; same data and schema as 2.0.0, with only the license metadata
#           changed (dual-license-by-upstream-source; see the README's Licensing
#           section).
DATASET_VERSION = "2.1.0"

# Location of the master annotation file on the Hugging Face Hub.
ANNOTATIONS_REPO_ID = "yahskapar/HealthChat-11K"
ANNOTATIONS_FILENAME = f"HealthChat-11K_master_annotations_v{DATASET_VERSION}.jsonl"

# Every generated artifact is written into one version-specific directory.
OUTPUT_DIR = f'HealthChat-11K_v{DATASET_VERSION}_artifacts'
OUTPUT_DATASET_PATH = os.path.join(
    OUTPUT_DIR, f'HealthChat-11K_v{DATASET_VERSION}.jsonl')
OUTPUT_FULL_REVIEW_CSV = os.path.join(
    OUTPUT_DIR, f'HealthChat-11K_v{DATASET_VERSION}_full_review.csv')
OUTPUT_SYCOPHANCY_REVIEW_CSV = os.path.join(
    OUTPUT_DIR, f'HealthChat-11K_v{DATASET_VERSION}_sycophancy_review.csv')

# Upstream Hugging Face datasets the conversations are streamed from, keyed by the
# `dataset_source` value found in the annotations. `id_field` is the column of that
# dataset that holds the conversation identifier to match against.
DATASET_MAPPING = {
    'lmsys': dict(path='lmsys/lmsys-chat-1m', id_field='conversation_id'),
    'wildchat': dict(path='allenai/WildChat-1M', id_field='conversation_hash'),
}

# Specialty code -> human-readable name for the review CSVs.
# Codes are assigned 1..22 in list order (see the trailing comments).
SPECIALTIES_MAP = dict(enumerate([
    "General Health",                          # 1
    "Mental Health",                           # 2
    "Allergy and Immunology",                  # 3
    "Cardiology",                              # 4
    "Dermatology",                             # 5
    "Endocrinology",                           # 6
    "Gastroenterology",                        # 7
    "Hematology/Oncology",                     # 8
    "Infectious Disease",                      # 9
    "Nephrology",                              # 10
    "Neurology",                               # 11
    "Obstetrics and Gynecology (OB/GYN)",      # 12
    "Ophthalmology",                           # 13
    "Fitness/Orthopedics/Sports Medicine",     # 14
    "Otolaryngology (ENT)",                    # 15
    "Pediatrics",                              # 16
    "Pulmonology",                             # 17
    "Rheumatology",                            # 18
    "Urology",                                 # 19
    "Dentistry",                               # 20
    "Diet and Nutrition",                      # 21
    "Not a Health Conversation",               # 22
], start=1))
# --- Helper Functions ---
def json_serial_default(o):
    """Fallback encoder passed as ``default=`` to ``json.dumps``.

    ``datetime`` instances are rendered as ISO-8601 strings. Any other type is
    rejected with ``TypeError``, which is how a ``default`` hook tells the json
    encoder that the object cannot be serialized.
    """
    if not isinstance(o, datetime):
        type_name = o.__class__.__name__
        raise TypeError(f"Object of type {type_name} is not JSON serializable")
    return o.isoformat()
def load_annotations(hf_repo_id: str, data_filename: str) -> "tuple[dict, defaultdict]":
    """Download the master annotation file from the Hugging Face Hub and index it.

    Args:
        hf_repo_id: Hugging Face repository ID hosting the annotation files.
        data_filename: Annotation file (JSONL) to load from that repository.

    Returns:
        ``(annotations_by_id, target_ids_by_source)``:

        * ``annotations_by_id`` maps each ``conversation_id`` to its full
          annotation record. A later record with the same ``conversation_id``
          replaces an earlier one.
        * ``target_ids_by_source`` is a ``defaultdict(set)`` mapping each
          ``DATASET_MAPPING`` key to the set of conversation IDs to fetch from
          that source.

        Records lacking a ``conversation_id``, or whose ``dataset_source`` is not
        a key of ``DATASET_MAPPING``, are skipped.

    Raises:
        SystemExit: with status 1 if downloading or iterating the annotations
            fails. Nothing downstream can run without them.
    """
    print(f"Downloading master annotation file from Hugging Face repo: {hf_repo_id}/{data_filename}")
    annotations_by_id = {}
    target_ids_by_source = defaultdict(set)
    try:
        annotation_dataset = load_dataset(hf_repo_id, data_files=data_filename, split="train")
        for annotation_record in annotation_dataset:
            conv_id = annotation_record.get('conversation_id')
            source = annotation_record.get('dataset_source')
            if conv_id and source in DATASET_MAPPING:
                annotations_by_id[conv_id] = annotation_record
                target_ids_by_source[source].add(conv_id)
    except Exception as e:
        # Script-level boundary: report what went wrong, then stop. Raising
        # SystemExit directly avoids the `site`-injected `exit()` helper, which is
        # meant for the interactive interpreter and may not exist (e.g. `python -S`).
        print(" ERROR: Failed to download or process annotations from Hugging Face.")
        print(f" Please check that the version exists in the repo: {hf_repo_id}/{data_filename}")
        print(f" Error details: {e}")
        raise SystemExit(1) from e
    for source, ids in target_ids_by_source.items():
        print(f" - Found {len(ids)} target conversations for dataset '{source}'")
    return annotations_by_id, target_ids_by_source
def write_csv(filepath: str, headers: list, rows: list):
    """Write one header row followed by ``rows`` to ``filepath`` as UTF-8 CSV.

    The file is opened with ``newline=''`` as the csv module requires. An
    ``OSError`` while opening or writing is printed and not re-raised, so a
    failure on one file does not stop the caller from continuing.
    """
    print(f"\n Writing {len(rows)} rows to {filepath}...")
    try:
        with open(filepath, 'w', newline='', encoding='utf-8') as handle:
            out = csv.writer(handle)
            out.writerow(headers)
            for row in rows:
                out.writerow(row)
        print(f" -> Successfully created {filepath}")
    except OSError as err:  # IOError is an alias of OSError
        print(f" ERROR: Could not write to CSV file {filepath}: {err}")
# --- End of Helper Functions Section ---
def _user_turn_ordinals(conversation: list) -> dict:
    """Map each user turn's index in ``conversation`` to its 0-based rank among user turns."""
    ordinals = {}
    for turn_idx, turn in enumerate(conversation):
        if turn.get('role') == 'user':
            ordinals[turn_idx] = len(ordinals)
    return ordinals


def _build_review_rows(final_record: dict) -> tuple:
    """Build the full-review and sycophancy-review CSV rows for one merged record.

    Returns:
        ``(full_review_rows, sycophancy_review_rows)``. Each row is a list whose
        columns match the corresponding CSV headers written by ``main``.
    """
    conv_id = final_record.get('conversation_id')
    web_url = final_record.get('web_url', 'N/A')
    specialty_str = SPECIALTIES_MAP.get(final_record.get('specialty_conversation_classification'), "N/A")
    # The i-th taxonomy entry is used as the codes for the i-th user turn.
    # `or []` tolerates keys that are present but set to None.
    taxonomy_by_user_turn = {
        i: "; ".join(sorted(item.get('taxonomy_codes') or []))
        for i, item in enumerate(final_record.get('taxonomy_messages_classified') or [])
    }
    conversation = final_record.get('conversation') or []
    user_ordinals = _user_turn_ordinals(conversation)

    def tax_codes_for(turn_idx):
        # Only user turns carry taxonomy codes; anything else gets "".
        ordinal = user_ordinals.get(turn_idx)
        return "" if ordinal is None else taxonomy_by_user_turn.get(ordinal, "")

    full_rows = []
    for turn_idx, turn in enumerate(conversation):
        full_rows.append([
            conv_id, web_url, specialty_str, turn_idx,
            turn.get('role'), turn.get('content', ''), tax_codes_for(turn_idx),
        ])

    sycophancy_rows = []
    # `or []` covers a missing key, an explicit None, and an empty list alike.
    for lq_item in final_record.get('leading_question_classifications') or []:
        classification = lq_item.get('classification')
        if classification == 'N':
            continue  # items classified 'N' are excluded from the sycophancy review
        turn_index = lq_item.get('user_message_original_turn_index')
        # Validate the index first. A missing, non-integer, negative, or out-of-range
        # value yields a placeholder row. Otherwise it would abort the whole run
        # (TypeError/IndexError) or wrap around to the end of the conversation.
        index_ok = isinstance(turn_index, int) and 0 <= turn_index < len(conversation)
        if index_ok:
            user_message = conversation[turn_index].get('content')
            prior_assistant_message = conversation[turn_index - 1].get('content') if turn_index > 0 else ""
            lq_tax_codes = tax_codes_for(turn_index)
        else:
            user_message = "TEXT NOT FOUND"
            prior_assistant_message = ""
            lq_tax_codes = ""
        sycophancy_rows.append([
            conv_id, web_url, turn_index, prior_assistant_message, user_message,
            lq_tax_codes, classification,
        ])
    return full_rows, sycophancy_rows


def main():
    """Reconstruct the dataset version named by DATASET_VERSION and write all artifacts.

    Steps:

    1. Download and index the master annotations.
    2. Stream each upstream source dataset. Every matching conversation is merged
       with its annotation and appended to the output JSONL.
    3. Write the full-review and sycophancy-review CSVs.
    4. Report any requested conversations that were not found upstream.
    """
    print(f" Creating output directory: {OUTPUT_DIR}")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    annotations_by_id, target_ids_by_source = load_annotations(ANNOTATIONS_REPO_ID, ANNOTATIONS_FILENAME)
    if not annotations_by_id:
        print(" No valid conversations to process were found. Exiting.")
        return
    full_review_rows = []
    sycophancy_review_rows = []
    found_count = 0
    initial_total_count = len(annotations_by_id)
    print(f"\n Reconstructing final dataset at: {OUTPUT_DATASET_PATH} (version {DATASET_VERSION})")
    with open(OUTPUT_DATASET_PATH, 'w', encoding='utf-8') as fout:
        for source, source_info in DATASET_MAPPING.items():
            # This is the same set object stored in target_ids_by_source. IDs are
            # removed as they are found, so whatever remains is reported as missing below.
            target_ids = target_ids_by_source.get(source)
            if not target_ids:
                continue
            dataset_path = source_info['path']
            id_field = source_info['id_field']
            print(f"\n Streaming and searching '{dataset_path}' for {len(target_ids)} conversations...")
            ds_iter = load_dataset(dataset_path, split="train", streaming=True)
            for source_record in ds_iter:
                record_id = source_record.get(id_field)
                if record_id not in target_ids:
                    continue
                # Annotation fields override same-named upstream fields.
                final_record = source_record.copy()
                final_record.update(annotations_by_id[record_id])
                final_record['dataset_version'] = DATASET_VERSION
                fout.write(json.dumps(final_record, default=json_serial_default) + "\n")
                found_count += 1
                print(f" -> Merged and saved: {record_id} ({found_count}/{initial_total_count})")

                full_rows, sycophancy_rows = _build_review_rows(final_record)
                full_review_rows.extend(full_rows)
                sycophancy_review_rows.extend(sycophancy_rows)

                target_ids.remove(record_id)
                if not target_ids:
                    print(f" All targets for '{source}' found.")
                    break  # stop streaming this source once everything is found
    print("\n---")
    print(" Primary dataset reconstruction complete!")
    print(f" - Wrote {found_count} of {initial_total_count} conversations to {OUTPUT_DATASET_PATH}")
    write_csv(OUTPUT_FULL_REVIEW_CSV,
              ["Conversation ID", "Web URL", "Specialty", "Turn Index", "Role", "Message Text", "Taxonomy Codes"],
              full_review_rows)
    write_csv(OUTPUT_SYCOPHANCY_REVIEW_CSV,
              ["Conversation ID", "Web URL", "User Message Original Turn Index", "Prior Assistant Message Text", "User Message Text", "Taxonomy Codes", "LQST Classification"],
              sycophancy_review_rows)
    any_missing = False
    for source, remaining_ids in target_ids_by_source.items():
        if remaining_ids:
            any_missing = True
            print(f" - Could not find {len(remaining_ids)} conversations from source '{source}':")
            # Sorted so the report is stable across runs (set order is not).
            for missing_id in sorted(remaining_ids, key=str):
                print(f" - {missing_id}")
    if not any_missing:
        print("\n - All requested conversations were found and merged successfully.")


if __name__ == "__main__":
    main()