-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_experiment.py
More file actions
122 lines (95 loc) · 3.51 KB
/
run_experiment.py
File metadata and controls
122 lines (95 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# run_experiment.py
"""Main experiment runner. Called by the autoresearch loop.
Usage: python run_experiment.py
Reads config.yaml, runs simulation on all events, outputs scores.
Output format (grep-parseable):
mean_score: 0.4567
time_acc: 0.3890
dir_acc: 0.6500
path_sim: 0.3100
"""
import asyncio
import hashlib
import json
import sys
from pathlib import Path
from simulation.config import load_config
from simulation.runner import run_simulation
from data.schema import load_event
from evaluation.evaluate import score_all_events
EVENTS_DIR = Path("data/events")
CONFIG_PATH = "config.yaml"
REPEATS = 3
def load_all_events() -> list[dict]:
    """Read every event JSON file under data/events/ and return the raw dicts.

    Files are visited in sorted filename order so results are deterministic.
    Filenames starting with "_" are treated as disabled/templates and skipped.
    """
    return [
        json.loads(event_path.read_text())
        for event_path in sorted(EVENTS_DIR.glob("*.json"))
        if not event_path.name.startswith("_")
    ]
def config_hash(path: str) -> str:
    """Return the first 8 hex characters of the MD5 digest of the file at *path*.

    This is a short fingerprint for experiment tracking only — MD5 is fine
    here because nothing security-sensitive depends on it.
    """
    digest = hashlib.md5(Path(path).read_bytes()).hexdigest()
    return digest[:8]
async def run_all(config, events_raw):
    """Simulate every event REPEATS times and build the evaluation inputs.

    For each raw event dict: parse it via load_event, render a headline-style
    prompt, run the simulation once per repeat (seed = repeat index, each run
    with its own scratch SQLite file under /tmp), average the repeated runs,
    and pair the averaged prediction with that event's ground truth.

    Returns a list of {"sim_result": ..., "ground_truth": ...} dicts, one per
    event, in the same order as *events_raw*.
    """
    eval_inputs = []
    for raw_event in events_raw:
        event = load_event(raw_event)
        # Runtime prompt string — must stay stable for comparable experiments.
        event_content = (
            f"BREAKING: FDA has {event.decision} {event.drug_name} "
            f"({event.company}) for {event.indication}. "
            f"Approval type: {event.approval_type}."
        )

        runs = []
        for seed in range(REPEATS):
            # Fixed per-repeat seed keeps repeats reproducible; per-run DB
            # path avoids cross-run contamination.
            run_result = await run_simulation(
                config=config,
                event_content=event_content,
                db_path=f"/tmp/sim_predict_{event.event_id}_{seed}.db",
                seed=seed,
            )
            runs.append(run_result)

        eval_inputs.append({
            "sim_result": _average_results(runs),
            "ground_truth": {
                "price_in_hours": event.price_in_hours,
                "direction": event.direction,
                "social_volume_curve": event.social_volume_curve,
            },
        })
    return eval_inputs
def _average_results(runs: list[dict]) -> dict:
"""Average simulation results across multiple runs."""
avg_hours = sum(r["predicted_price_in_hours"] for r in runs) / len(runs)
directions = [r["predicted_direction"] for r in runs]
direction = max(set(directions), key=directions.count)
max_len = max(len(r["volume_curve"]) for r in runs)
padded = [r["volume_curve"] + [0] * (max_len - len(r["volume_curve"])) for r in runs]
avg_curve = [sum(col) / len(col) for col in zip(*padded)]
return {
"predicted_price_in_hours": avg_hours,
"predicted_direction": direction,
"volume_curve": avg_curve,
}
def main():
    """Entry point: load config and events, simulate, print grep-parseable metrics."""
    config = load_config(CONFIG_PATH)
    events_raw = load_all_events()
    if not events_raw:
        print("ERROR: No event files found in data/events/", file=sys.stderr)
        sys.exit(1)

    eval_inputs = asyncio.run(run_all(config, events_raw))
    scores = score_all_events(eval_inputs)

    # Output contract: one "key: value" per line so the autoresearch loop
    # can grep individual metrics. Order and formatting must stay stable.
    for metric in ("mean_score", "time_acc", "dir_acc", "path_sim"):
        print(f"{metric}: {scores[metric]:.4f}")
    print(f"config_hash: {config_hash(CONFIG_PATH)}")
    print(f"events_count: {len(events_raw)}")


if __name__ == "__main__":
    main()