forked from karpathy/autoresearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlog_utils.py
More file actions
124 lines (95 loc) · 3.64 KB
/
log_utils.py
File metadata and controls
124 lines (95 loc) · 3.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Project-wide logging, diagnostics, and structured output for autoresearch-mlx.
Usage:
from log_utils import logger, is_debug
from log_utils import sample_memory, format_step_timings
from log_utils import hardware_info, save_json, build_bench_data, FORMAT_VERSION
Enable debug mode by passing --debug flag to any script, or setting
the AUTORESEARCH_DEBUG=1 environment variable.
"""
import logging
import os
import platform
import sys
import time
import mlx.core as mx
import orjson
# Log line layout: "HH:MM:SS LEVEL message" (level padded to 5 chars).
_LOG_FORMAT = "%(asctime)s %(levelname)-5s %(message)s"
# Time-only timestamps; runs are short enough that the date adds no value.
_LOG_DATE_FORMAT = "%H:%M:%S"
def _check_debug():
"""Check if debug mode is enabled via --debug flag or env var."""
if os.environ.get("AUTORESEARCH_DEBUG", "0") == "1":
return True
if "--debug" in sys.argv:
sys.argv.remove("--debug")
return True
return False
# Evaluated once at import time; note _check_debug() consumes --debug from
# sys.argv as a side effect.
is_debug = _check_debug()
# Single shared logger used across the whole project.
logger = logging.getLogger("autoresearch")
logger.setLevel(logging.DEBUG if is_debug else logging.INFO)
# Only attach a handler on first import; guards against duplicate log lines
# when this module is imported more than once.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stderr)
    handler.setLevel(logging.DEBUG if is_debug else logging.INFO)
    handler.setFormatter(logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT))
    logger.addHandler(handler)
# ---------------------------------------------------------------------------
# Memory diagnostics
# ---------------------------------------------------------------------------
def sample_memory(step, interval=10):
    """Read MLX memory counters on sampling steps.

    On every `interval`-th step (step % interval == 0) returns a tuple
    (active_mb, peak_mb) in MiB rounded to one decimal place; on all
    other steps returns (None, None) so callers can skip the fields.
    """
    if step % interval:
        return None, None

    def _mb(nbytes):
        # bytes -> MiB, one decimal place.
        return round(nbytes / 1024 / 1024, 1)

    return _mb(mx.get_active_memory()), _mb(mx.get_peak_memory())
def format_step_timings(step_timings):
    """Convert step-timing tuples into JSON-serializable dicts.

    Each input tuple has the shape (step, dt, tok_sec, loss, active_mb,
    peak_mb).  Memory fields are emitted only when active_mb is not None,
    i.e. only for steps on which memory was actually sampled.
    """
    formatted = []
    for step, dt, tok_sec, loss, active_mb, peak_mb in step_timings:
        record = {"step": step, "dt": dt, "tok_sec": tok_sec, "loss": loss}
        if active_mb is None:
            formatted.append(record)
        else:
            formatted.append(
                {**record, "active_mb": active_mb, "peak_mb": peak_mb}
            )
    return formatted
# ---------------------------------------------------------------------------
# Structured JSON output (format_version 0.1)
# ---------------------------------------------------------------------------
# Embedded in every saved JSON payload; bump when the emitted structure changes.
FORMAT_VERSION = "0.1"
def hardware_info():
    """Return hardware metadata dict with "chip", "memory_gb", "os" keys.

    memory_gb is total physical RAM in GiB rounded to one decimal place,
    or None when it cannot be determined (previously it was hard-coded to
    None).  Detection uses POSIX sysconf and degrades gracefully on
    platforms where those names are unavailable.
    """
    memory_gb = None
    try:
        # POSIX-only; either name may be missing or return -1 on some OSes.
        page_size = os.sysconf("SC_PAGE_SIZE")
        page_count = os.sysconf("SC_PHYS_PAGES")
        if page_size > 0 and page_count > 0:
            memory_gb = round(page_size * page_count / 1024**3, 1)
    except (ValueError, OSError, AttributeError):
        pass  # Keep None: callers already handle the absent value.
    return {
        # platform.processor() can be "" (e.g. some macOS builds) -> fallback.
        "chip": platform.processor() or "Apple Silicon",
        "memory_gb": memory_gb,
        "os": platform.system(),
    }
def save_json(prefix, data, *, write_latest=False):
    """Serialize `data` to data/<prefix>_<timestamp>.json.

    Creates the data/ directory if needed (the original raised
    FileNotFoundError on a fresh checkout with no data/ directory).
    When write_latest=True, also writes data/last_run.json as a stable
    path for agent metric extraction.

    Returns the timestamped output path.
    """
    out_dir = "data"
    # Ensure the output directory exists; idempotent if already present.
    os.makedirs(out_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(out_dir, f"{prefix}_{timestamp}.json")
    # Serialize once and reuse the bytes for both destinations.
    payload = orjson.dumps(data, option=orjson.OPT_INDENT_2)
    with open(out_path, "wb") as f:
        f.write(payload)
    if write_latest:
        last_run_path = os.path.join(out_dir, "last_run.json")
        with open(last_run_path, "wb") as f:
            f.write(payload)
    print(f"Results saved to {out_path}")
    return out_path
def build_bench_data(configs):
    """Assemble the top-level bench-results dict (format_version 0.1).

    Combines the format version, an ISO-style local timestamp, hardware
    metadata from hardware_info(), and the caller-supplied per-config
    results list/dict.
    """
    bench = {"format_version": FORMAT_VERSION}
    bench["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%S")
    bench["hardware"] = hardware_info()
    bench["configs"] = configs
    return bench