diff --git a/src/dnet/perf/__init__.py b/src/dnet/perf/__init__.py new file mode 100644 index 00000000..330893e8 --- /dev/null +++ b/src/dnet/perf/__init__.py @@ -0,0 +1,2 @@ + +from .trace import TraceConfig, Tracer diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py new file mode 100644 index 00000000..8afcca47 --- /dev/null +++ b/src/dnet/perf/trace.py @@ -0,0 +1,319 @@ + +from __future__ import annotations + +import os +import io +import sys +import time +import json +import pstats +import cProfile +import threading +import queue + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, List +from contextlib import contextmanager + +import httpx + +from dnet.utils.logger import logger + +@dataclass +class TraceConfig: + file: str = "logs/dnet-trace.jsonl" + streaming: bool = True + include_prefixes: Tuple[str, ...] = ("src/dnet/",) + include_c_calls: bool = False + budget: int = 0 # 0 means unlimited + enabled: bool = True + node_id: Optional[str] = None + record_pid_tid: bool = True + aggregate: bool = False + aggregate_url: Optional[str] = None + agg_max_events: int = 300 + +class _NoopFrame: + def __enter__(self): + return self + def __exit__(self, *a): + return False + def event(self, *a, **k): + pass + def set(self, *a, **k): + pass + +class _Frame: + __slots__ = ("t", "name", "attrs", "_t0") + def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]): + self.t = tracer + self.name = name + self.attrs = dict(attrs or {}) + self._t0 = 0.0 + def __enter__(self): + self._t0 = time.time_ns() # wall-clock ns so spans can be aligned across nodes + self.attrs.update({"t0": self._t0}) + self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) + return self + def __exit__(self, ex_type, ex, tb): + dt_ms = (time.time_ns() - self._t0) * 1e-6 + self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex), "t0": time.time_ns()}) + self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) + return False + def event(self, name: str, **attrs): + out = dict(attrs or {}) + out.setdefault("t_rel_ms", (time.time_ns() - self._t0) * 1e-6) # _t0 is in ns + self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": out}) + def set(self, key: str, val: Any): + self.attrs[key] = val + +class Tracer: + def __init__(self, config: TraceConfig): + self.config = config + self._lock = threading.Lock() + self._fh: Optional[io.TextIOBase] = None + self._events: List[Dict[str, Any]] = [] + self._req_id: Optional[str] = None + self._active = False + + self._agg_enabled: bool = False + self._agg_max_events: int = int(self.config.agg_max_events or 1000) + self._agg_q: queue.Queue[dict] = queue.Queue(maxsize=256) + self._agg_thread: Optional[threading.Thread] = None + + if self.config.aggregate: + self.start_aggregator() + + # Aggregator worker thread + def start_aggregator(self) -> None: + self._agg_enabled = True + self._agg_max_events = max(10, int(self.config.agg_max_events or 1000)) # minimum 10; fall back to 1000 + if not self._agg_thread or not self._agg_thread.is_alive(): + self._agg_thread = threading.Thread(target=self._agg_exec, name="trace-agg", daemon=True) + self._agg_thread.start() + + def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None: + self._agg_enabled = False + if flush and self._events: + try: + self._agg_q.put_nowait({ + "run_id": (self._req_id or "run"), + "node_id": (self.config.node_id or "node"), + "events": list(self._events), }) + except queue.Full: + logger.warning("Trace aggregator queue is full.") + self._events.clear() + if 
self._agg_thread and self._agg_thread.is_alive(): + self._agg_thread.join(timeout) + self._agg_thread = None + + def _agg_exec(self) -> None: + assert self.config.aggregate_url, "aggregate_url must be set before starting the aggregator" + client = httpx.Client(timeout=5.0) + try: + logger.debug(f"Aggregation worker thread {self._agg_enabled}, {self._agg_q.empty()}") + while self._agg_enabled or not self._agg_q.empty(): + try: + batch = self._agg_q.get(timeout=0.2) + except queue.Empty: + continue + logger.info(f"Sending trace buffer to API : {self.config.aggregate_url}") + try: + res = client.post(self.config.aggregate_url, json=batch) + if res.status_code != 200: + logger.error(f"Aggregator POST failed {res.status_code}: {res.text}") + except Exception as e: + logger.warning(f"Unable to POST trace aggregation data to {self.config.aggregate_url}: {e}") + finally: + self._agg_q.task_done() + finally: + try: + client.close() + except Exception: + logger.warning("Unable to close httpx client.") + + def update_api_addr(self, addr: str): + self.config.aggregate_url = addr + logger.debug(f"Updated API Address: {self.config.aggregate_url}") + + def start(self, *, reset: bool = True) -> None: + self._active = bool(self.config.enabled) + if not self._active: + logger.info("Initialized tracer.") + return + if self.config.file: + d = os.path.dirname(self.config.file) or "." + os.makedirs(d, exist_ok=True) + if reset and os.path.exists(self.config.file): + try: + os.remove(self.config.file) + except Exception: + logger.warning(f"Unable to remove existing trace file {self.config.file}") + if self.config.streaming: + with self._lock: + self._fh = open(self.config.file, "a", encoding="utf-8") + logger.info(f"Streaming trace to {self.config.file}.") + if self.config.aggregate and self.config.aggregate_url and self.config.node_id: + self.start_aggregator() + + def stop(self, *, flush_events: bool = True) -> None: + if flush_events: + self.flush() + self._active = False + with self._lock: + if self._fh: + try: + self._fh.flush() + self._fh.close() + except Exception: + logger.warning(f"Unable to flush to file {self.config.file}") + self._fh = None + + # Flush buffered events to the trace file + def flush(self, *, clear: bool = False) -> None: + if not self._active: return + with self._lock: + if not self.config.streaming and self._events: + with open(self.config.file, "a", encoding="utf-8") as f: + for ev in self._events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + if clear: + self._events.clear() + + # Quick dump of the in-memory event buffer to a file + def snapshot(self, path: str) -> None: + with self._lock: + with open(path, "w", encoding="utf-8") as f: + for ev in self._events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + + # Emit a single trace event + def _emit(self, ev: Dict[str, Any]) -> None: + if not self._active: return + ev.setdefault("ts", time.perf_counter()) + if self._req_id is not None: + ev.setdefault("req_id", self._req_id) + if self.config.record_pid_tid: + try: + ev.setdefault("pid", os.getpid()) + ev.setdefault("tid", threading.get_ident()) + except Exception: + logger.warning("Unable to get PID and TID for tracer frame.") + + with self._lock: + if self.config.streaming and self._fh: + self._fh.write(json.dumps(ev, ensure_ascii=False) + "\n") + self._fh.flush() + else: + self._events.append(ev) + + if self._agg_enabled: + if len(self._events) < self._agg_max_events: return + logger.debug(f"Queuing tracer frame batch of {len(self._events)}") + batch = { "run_id": (self._req_id or "NONE"), + "node_id": (self.config.node_id or "NODE"), + "events": list(self._events)} + try: + 
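+                # A minimal sketch of the batch enqueued just below (field values are hypothetical);
+                # TraceAggregator.enqueue on the API side expects exactly these three keys:
+                #   batch = {
+                #       "run_id": "chatcmpl-1234",   # falls back to "NONE" when no req_id is set
+                #       "node_id": "shard-0",        # falls back to "NODE"
+                #       "events": [
+                #           {"type": "B", "name": "compute.forward", "ts": 1.02, "args": {}},
+                #           {"type": "E", "name": "compute.forward", "ts": 1.06, "args": {"ms": 40.0}},
+                #       ],
+                #   }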
self._agg_q.put_nowait(batch) + except queue.Full: + logger.warning(f"Aggregator queue is full. Dropping {len(batch['events'])} frames.") + self._events.clear() + + # Frames + def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): + if not self._active: + return _NoopFrame() + return _Frame(self, f"{scope}.{name}", attrs) + + # Same as a normal frame but signals that this trace is a canonical event (required for runtime stats) + def canonical(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): + return self.frame(scope, name, attrs) + + # Mark an event outside of a frame + def mark(self, name: str, attrs: Optional[Dict[str, Any]] = None) -> None: + self._emit({"type": "I", "name": name, "args": dict(attrs or {})}) + + # Helpers + @contextmanager + def profile_block(self, outfile: Optional[str] = None, sort: str = "cumtime", limit: int = 40): + pr = cProfile.Profile() + pr.enable() + try: + yield pr + finally: + pr.disable() + s = io.StringIO() + pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sort).print_stats(limit) + out = s.getvalue() + if outfile: + d = os.path.dirname(outfile) or "." + os.makedirs(d, exist_ok=True) + with open(outfile, "w", encoding="utf-8") as f: + f.write(out) + else: + self._emit({"type": "PROFILE", "name": "cprofile", "args": {"sort": sort, "limit": limit, "report": out}}) + + @contextmanager + def callgraph( + self, + include_prefixes: Optional[Tuple[str, ...]] = None, + budget_events: Optional[int] = None, + include_c_calls: Optional[bool] = None, + apply_to_new_threads: bool = False, + ): + """ + Interpreter-level tracing (sys.setprofile) for all Python calls/returns + within the scope. Heavy overhead; best for deep debugging runs. + """ + prefixes = include_prefixes if include_prefixes is not None else self.config.include_prefixes + budget = (budget_events if budget_events is not None else self.config.budget) or 0 + inc_c = include_c_calls if include_c_calls is not None else self.config.include_c_calls + + emitted = 0 + stack: list[Tuple[str, float]] = [] + + def prof(frame, event, arg): + nonlocal emitted + if budget and emitted >= budget: + return + if event in ("call", "return"): + code = frame.f_code + filename = code.co_filename or "" + if prefixes and not any(filename.startswith(p) for p in prefixes): + return + name = code.co_name + key = f"{filename}:{code.co_firstlineno}:{name}" + if event == "call": + stack.append((key, time.perf_counter())) + self._emit({"type": "B", "name": f"py.{name}", "args": {"file": filename, "line": code.co_firstlineno}}) + emitted += 1 + else: + if stack and stack[-1][0] == key: + _, t0 = stack.pop() + dt_ms = (time.perf_counter() - t0) * 1000.0 + self._emit({"type": "E", "name": f"py.{name}", "args": {"ms": round(dt_ms, 3)}}) + emitted += 1 + elif inc_c and event in ("c_call", "c_return"): + func = getattr(arg, "__name__", None) + mod = getattr(arg, "__module__", None) + if not func: + return + if event == "c_call": + self._emit({"type": "B", "name": f"c.{mod}.{func}", "args": {}}) + emitted += 1 + else: + self._emit({"type": "E", "name": f"c.{mod}.{func}", "args": {}}) + emitted += 1 + + prev = sys.getprofile() + sys.setprofile(prof) + prev_thread = None + if apply_to_new_threads: + prev_thread = threading.getprofile() + threading.setprofile(prof) + try: + yield + finally: + sys.setprofile(prev) + if apply_to_new_threads: + threading.setprofile(prev_thread) diff --git a/src/dnet/perf/utils/__init__.py b/src/dnet/perf/utils/__init__.py new file mode 100644 index 00000000..0ee2f5e1 --- /dev/null +++ 
b/src/dnet/perf/utils/__init__.py @@ -0,0 +1 @@ +from .aggregators import TraceAggregator, StatsAggregator diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py new file mode 100644 index 00000000..86e07f41 --- /dev/null +++ b/src/dnet/perf/utils/aggregators.py @@ -0,0 +1,502 @@ + +from __future__ import annotations + +import sys +import threading +import statistics +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Optional, DefaultDict +from collections import defaultdict, deque + +#from dnet.utils.logger import logger +from dnet.ring.api.api_logging import get_api_logger +from dnet.ring import LayerAssignment, TopologyInfo + +logger = get_api_logger() + +Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) + +@dataclass +class _ActiveSpan: + """Per-instance active span used for self-time accounting on a call stack.""" + name: str + t0: int + child: int = 0 + + +@dataclass +class _SymbolAgg: + """Aggregated statistics for a single trace symbol (name).""" + total_ms: float = 0.0 + count: int = 0 + durations: deque = field(default_factory=lambda: deque(maxlen=10000)) + + def add(self, self_ms: float) -> None: + self.total_ms += float(self_ms) + self.count += 1 + self.durations.append(float(self_ms)) + +# Sort known frames and compute averages by key +@dataclass +class RunAggregator: + sums_by_name: Dict[str, float] = field(default_factory=dict) + counts_by_name: Dict[str, int] = field(default_factory=dict) + last_batch_seq: Dict[str, int] = field(default_factory=dict) + + stacks: Dict[Key, List[_ActiveSpan]] = field(default_factory=dict) + drops: int = 0 + # Aggregated stats per symbol (primary source of truth) + symbols: Dict[str, _SymbolAgg] = field(default_factory=dict) + # Back-compat mirrors for existing readers (e.g., legacy REPL code) + durations_by_name: Dict[str, deque] = field(default_factory=dict) + + def _key(self, node_id: str, pid: Optional[int], tid: Optional[int], req_id: Optional[str]) -> Key: + return (node_id, pid, tid, req_id or "") + + def _push(self, key: Key, f: _ActiveSpan) -> None: + self.stacks.setdefault(key, []).append(f) + + def _pop(self, key: Key) -> Optional[_ActiveSpan]: + st = self.stacks.get(key) + if not st: return None + return st.pop() + + def _peek(self, key: Key) -> Optional[_ActiveSpan]: + st = self.stacks.get(key) + return st[-1] if st else None + + def _accumulate(self, name: str, self_ms: float) -> None: + sym = self.symbols.get(name) + if sym is None: + sym = _SymbolAgg() + self.symbols[name] = sym + sym.add(self_ms) + + # FIXME: Remove + self.sums_by_name[name] = sym.total_ms + self.counts_by_name[name] = sym.count + dq = self.durations_by_name.get(name) + if dq is None: + dq = deque(maxlen=10000) + self.durations_by_name[name] = dq + dq.append(float(self_ms)) + + def ingest_event(self, node_id: str, ev: Dict[str, Any]) -> None: + if not ev.get("name"): + logger.error(f"Received trace frame without name {ev.get('ts')}") + return + # Normalize timestamp to microseconds + ts_raw = ev.get("ts") + ts = 0 + try: + if isinstance(ts_raw, float): + ts = int(ts_raw * 1_000_000) + elif isinstance(ts_raw, int): + ts = ts_raw + else: + ts = int(ts_raw or 0) + except Exception: + ts = 0 + req_id = ev.get("req_id") or "" + key = self._key(node_id, ev.get("pid"), ev.get("tid"), req_id) + if ev.get("type") == "B": + self._push(key, _ActiveSpan(name=ev.get("name"), t0=ts)) + elif ev.get("type") == "E": + fr = self._pop(key) + if not fr: return + dur_us = max(0, ts - 
fr.t0) + self_us = max(0, dur_us - fr.child) + self_ms = self_us / 1000.0 + self._accumulate(fr.name, self_ms) + parent = self._peek(key) + if parent: + parent.child += dur_us + else: + # TODO: Process other event types + pass + + +class TraceAggregator: + def __init__(self) -> None: + self._req: Dict[str, RunAggregator] = {} + self._lock = threading.Lock() + + def enqueue(self, batch: Dict[str, Any]) -> None: + try: + run_id = batch.get("run_id") + node_id = batch.get("node_id") + events = batch.get("events") or [] + logger.debug(f"Enqueuing trace buffer from {run_id}, {node_id}") + if not run_id or not node_id: return # Drop batch + with self._lock: + agg = self._req.setdefault(run_id, RunAggregator()) + for ev in events: + agg.ingest_event(node_id, ev) + except Exception as e: + logger.error(f"Trace aggregator enqueue error: {e}") + + def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]: + with self._lock: + agg = self._req.get(run_id) + if not agg: + return [] + def _stats(xs: List[float]) -> Dict[str, float]: + if not xs: + return {"mean": 0.0, "p50": 0.0, "p90": 0.0, "p99": 0.0, "min": 0.0, "max": 0.0} + n = len(xs) + srt = sorted(xs) + def q(p: float) -> float: + if n == 1: + return srt[0] + k = int(round(p * (n - 1))) + k = max(0, min(k, n - 1)) + return srt[k] + return { + "mean": (sum(xs) / n), + "p50": q(0.5), + "p90": q(0.9), + "p99": q(0.99), + "min": srt[0], + "max": srt[-1], + } + + rows: List[Dict[str, Any]] = [] + if not mapping: + for name, sym in agg.symbols.items(): + total = sym.total_ms + samples = list(sym.durations) + st = _stats(samples) + rows.append({ + "name": name, + "total": total, + "max": st["max"], + "mean": st["mean"], + "p50": st["p50"], + "p90": st["p90"], + "p99": st["p99"], + "samples": len(samples), + }) + else: + sums: Dict[str, float] = {} + counts: Dict[str, int] = {} + dists: Dict[str, List[float]] = {} + for raw, sym in agg.symbols.items(): + disp = mapping.get(raw, raw) + sums[disp] = sums.get(disp, 0.0) + sym.total_ms + counts[disp] = counts.get(disp, 0) + sym.count + if sym.durations: + dists.setdefault(disp, []).extend(sym.durations) + for name, total in sums.items(): + samples = dists.get(name, []) + st = _stats(samples) + rows.append({ + "name": name, + "total": total, + "max": st["max"], + "mean": st["mean"], + "p50": st["p50"], + "p90": st["p90"], + "p99": st["p99"], + "samples": len(samples), + }) + rows.sort(key=lambda r: r["total"], reverse=True) + return rows + + def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: + # Call-tree storage is disabled to reduce memory; keep API for compatibility. 
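+        # Usage sketch (hypothetical run/node ids): one B/E pair fed through enqueue()
+        # and read back via annotate():
+        #   agg = TraceAggregator()
+        #   agg.enqueue({"run_id": "run-1", "node_id": "node-a", "events": [
+        #       {"type": "B", "name": "compute.forward", "ts": 0.000},
+        #       {"type": "E", "name": "compute.forward", "ts": 0.004},
+        #   ]})
+        #   agg.annotate("run-1")  # -> [{"name": "compute.forward", "total": 4.0, "mean": 4.0, ...}] (ms)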
+ return [] + + +# Runtime statistics +# Use a RunAggregator to get raw frames per request, then transform into ReqStats + +# Track a single request, use multiple for a full benchmark +@dataclass +class ReqStats: + model: str = "" # Model name + tokenizer: str = "" # Tokenizer name + run_id: str = "" # ID of session (for later mapping) + req_id: str = "" # ID of the serviced request + ttft: float = 0.0 # Time to first token + itl: List[float] = None # Inter-token latency per round + prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) + generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) + total_tokens: int = -1 # Total number of tokens processed + nodes: List[str] = None # Nodes that participated in computation + _rounds_t0: List[int] = None # Keep round times for post-processing + + latencies: List[Tuple[str, str, float, float]] = None # List of inter-node latencies: (node0, node1, p50, latency) + latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} + latency_per_shard: Dict[str, float] = None # Map of {shard: 0.0} + total_latency: int = -1 # Total runtime of requests + startup_t: float = 0.0 # Time to start shard (ms) + layer_assignment_t: float = 0.0 # Time to layer assignment (ms) + + # Per-worker data + compute_per_worker: Dict[str, float] = None + network_per_worker: Dict[str, float] = None + memory_per_worker: Dict[str, float] = None + + # Network info + tx_bytes_per_node: Dict[str, int] = None # Volume of traffic per node + rx_bytes_per_node: Dict[str, int] = None + + topo: TopologyInfo = None # Topology information for this request (keep here since it might change) + assignment: LayerAssignment = None # Map of layer to shard IDs + + +# Process stats + handle per-request data +# NOTE: Hardcodes some high-level trace frame symbols +# TODO: Use a bitmap to track the stages for each req and prevent double-count +class StatsAggregator: + def __init__(self) -> None: + self._lock = threading.Lock() + + self._max_inflight_req = 20 # per node FIXME: modify from repl + + # Frames are kept in here while in-flight, then remove the frame objects and append to _stats + self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per req_id, per node_id + + self._req: List[str] = [] # Tracked requests (in-flight or done) + self._req_round_finish: Dict[str, bool] = {} # Track in-flight requests + self._req_prefill: Dict[str, bool] = {} # Track if this request round is prefill + self._open_frames: Dict[str, Dict[str, Dict[str, Any]]] = {} + self._global_memory_per_worker: Dict[str, float] = {} + + # Staging environment for events that arrive before + # the request.start of the request they belong to + self._staging: Dict[str, Any] = {} + + # Finished and running requests + self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames + self._stats: Dict[str, ReqStats] = {} # Finished stat frames + + self.nodes = [] # Keep track of active nodes in the network + + # Ingest raw data from tracer + def add(self, data: Dict[str, Any]) -> None: + events = data.get("events") or [] + if not events: return # Nothing to do + + node_id = data.get("node_id") + if not node_id: return # Drop unknown node + + with self._lock: + + # Update in-flight events or register new ones + for i, e in enumerate(events): + symbol = e["name"].split(".") + + if e["type"] == 'B': + req_id = data.get("req_id") + if not req_id or not node_id: continue + self._open_frames.setdefault(req_id, {}).setdefault(node_id, {})[e["name"]] = e + continue + + req_id = e["args"].get("req_id") + if not req_id: 
+ if symbol[0] == "memory": # Global memory frames are not request-based + if node_id not in self._global_memory_per_worker: + self._global_memory_per_worker[node_id] = 0.0 + self._global_memory_per_worker[node_id] += e["args"]["ms"] + continue # Drop anonymous frames + + if symbol[0] == "request": + if symbol[1] == "start": # Start request, setup buffers and ingest staged frames + self._req.append(req_id) + self._open_frames[req_id] = {} + self._req_prefill[req_id] = True + stats = ReqStats( + model=e["args"]["model"], + tokenizer=e["args"]["tokenizer"], + req_id=req_id, + ttft=0.0, + itl=[], + prompt_tokens=e["args"]["prompt_tokens"], + total_tokens=e["args"]["prompt_tokens"], + latencies=[], + latency_per_layer={}, + latency_per_shard={}, + assignment=None, + compute_per_worker={}, + network_per_worker={}, + memory_per_worker={}, + nodes=[], + _rounds_t0=[], + ) + self._running_stats[req_id] = stats + + # Process all events in staging + if req_id in self._staging: + for pe in self._staging[req_id]: + node_id = pe["args"].get("node_id") + if not node_id: continue + self._process_frame(pe, req_id, node_id, stats) + + del self._staging[req_id] + continue + + elif symbol[1] == "end": # Finish processing request + st_obj = self._running_stats[req_id] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + self._stats[req_id] = st_obj + del self._running_stats[req_id] + # TODO: Handle latency of transfer back to API + continue + + elif symbol[1] == "round": + stats = self._running_stats[req_id] + stats._rounds_t0.append(e["args"]["t0"]) + continue + + if req_id in self._stats: continue # Already finished processing this request + if req_id not in self._req: # If unknown request, stage frames + if req_id not in self._staging: + self._staging[req_id] = [] + self._staging[req_id].append(e) + continue + + #node_id = e["args"].get("node_id") + #if not node_id: return # Drop unknown node + + stats = self._running_stats[req_id] + self._process_frame(e, req_id, node_id, stats) + + + def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): + symbol = e["name"].split(".") + if node_id not in self.nodes: self.nodes.append(node_id) + if node_id not in stats.nodes: + stats.nodes.append(node_id) + stats.compute_per_worker[node_id] = 0.0 + stats.network_per_worker[node_id] = 0.0 + stats.memory_per_worker[node_id] = 0.0 + + if symbol[0] == "compute": + if symbol[1] == "forward": + pass + #_cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] + # Defer stats compute until after we sort the times (async arrival order is not guaranteed) + stats.compute_per_worker[node_id] += e["args"]["ms"] + + elif symbol[0] == "network": + if symbol[1] == "rx": # Time in transport, ingress queue and ingress_worker + _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] + #TODO: change shard in metadata + stats.network_per_worker[node_id] += e["args"]["ms"] + + return + + def _compute_round_stats(self, stats): + rounds = stats._rounds_t0 + #rounds.sort() + assert len(rounds) > 2, "Not enough data." + stats.ttft = (rounds[1] - rounds[0]) * 1e-6 + # Inter-token latency: deltas between consecutive round t0s (ns -> ms) + stats.itl = [(rounds[i] - rounds[i - 1]) * 1e-6 for i in range(1, len(rounds))] + + # Return data for total, per req, worker or model (maybe add per layer too?) 
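+    # Worked example for _compute_round_stats above (hypothetical values): with
+    # _rounds_t0 = [0, 50e6, 80e6, 120e6] in wall-clock ns it yields
+    # ttft = (50e6 - 0) * 1e-6 = 50.0 ms and itl = [50.0, 30.0, 40.0] ms; the
+    # tokens_per_second row printed below is then mean(1 / (itl_ms / 1000)).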
+ def stats( + self, + req_id: Optional[str] = None, + worker: Optional[str] = None, + model: Optional[str] = None + ): + + # FIXME: Allow manual selection of counters (and push to tracer) + fields = [ # 0 is native, 1 compound + (0, "prompt_tokens", ""), + (0, "generated_tokens", ""), + (0, "total_tokens", ""), + (0, -1, ""), # special for empty line + (0, "ttft", "ms"), + (1, "tokens_per_second", "ms"), + (1, "inter_token_latency", "ms"), + (0, -1, ""), + (1, "estimated_compute", "GFLOPs"), + ] + + # FIXME: Allow filtering by these + if req_id: pass + elif worker: pass + elif model: pass + + else: # Sort per model, per request (node info only when requested) + if len(self._stats) < 1: + print("No tracked stats in memory. Track a request first.\n") + return + stats = self._stats[list(self._stats.keys())[-1]] + self._compute_round_stats(stats) + #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n") + sys.stdout.write(f"Performance stats for request '{stats.req_id}':\n\n") + try: + for tag, n, unit in fields: + if tag == 0: # Native counter + if n == -1: + sys.stdout.write("\n") + continue + nr = getattr(stats, n) + if isinstance(nr, int): + nr_str = f"{nr}" + elif isinstance(nr, float): + nr_str = f"{nr:.2f}" + elif isinstance(nr, str): + if len(nr) > 20: + nr_str = nr[:15] + "..." + else: + nr_str = nr + else: + nr_str = str(nr) + sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(5)}\t{n}\n") + + # Compound trackers + elif tag == 1: + match n: + case "tokens_per_second": + tps = [ 1 / (rt/1000) for rt in stats.itl ] + #median = statistics.median(tps) + mean = sum(tps) / len(tps) + sys.stdout.write(f"{mean:.2f}".rjust(20)+" tok/s".rjust(5)+" \ttokens_per_second") + sys.stdout.write(f"\t# {statistics.median(stats.itl)/1000:.3f} s/tok\n") + + case "inter_token_latency": + assert len(stats.itl) > 1, "Not enough trace frames" + itl = stats.itl[:-1] # FIXME: last element is super big + median = statistics.median(itl) + p90 = statistics.quantiles(itl, n=100)[89] + p99 = statistics.quantiles(itl, n=100)[98] + sys.stdout.write(f"{median:.4f}".rjust(20) + " ms".ljust(5) + "\tmedian_inter_token_latency (ITL)\n") + sys.stdout.write(" "*35 + f"{p90:.3f} (p90), {p99:.3f} (p99)\n") + sys.stdout.write(" "*35 +f"{min(itl):.3f} (min), {max(itl):.3f} (max)\n") + + case "estimated_compute": + sys.stdout.write("UNKNOWN".rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") + + case _: + pass + + for i, n in enumerate(self.nodes): + try: + comp = stats.compute_per_worker[n] + net = stats.network_per_worker[n] + req_mem = stats.memory_per_worker[n] + g_mem = self._global_memory_per_worker.get(n, 0.0) + mem = req_mem + g_mem + total = comp + net + mem + sys.stdout.write(f"\n node{i} [{n}]\n") + sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") + sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") + sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + except Exception as e: + print(f"{e}") + + except Exception as e: + logger.error(f"{e}") + + # Per-node information + sys.stdout.write("\n") + return + + diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index 0b46c5be..d1b3b33a 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -34,6 +34,9 @@ message ActivationRequest { int64 timestamp = 3; string node_origin = 4; string callback_url = 5; + float rx_enq_t = 6; + float tx_enq_prev_t = 7; + float rx_inflight_t = 8; } // Response 
message for activation sending diff --git a/src/dnet/protos/shard_api_comm.proto b/src/dnet/protos/shard_api_comm.proto index ec80cf18..6540f65f 100644 --- a/src/dnet/protos/shard_api_comm.proto +++ b/src/dnet/protos/shard_api_comm.proto @@ -35,6 +35,7 @@ message TokenRequest { string nonce = 1; int32 token_id = 2; int64 timestamp = 3; + float tx_enq_prev_t = 4; } // Response for token reception @@ -49,4 +50,4 @@ message RingError { string failed_node = 2; string error_code = 3; string error = 4; -} \ No newline at end of file +} diff --git a/src/dnet/ring/__init__.py b/src/dnet/ring/__init__.py index e69de29b..adb40d20 100644 --- a/src/dnet/ring/__init__.py +++ b/src/dnet/ring/__init__.py @@ -0,0 +1,2 @@ + +from .common import LayerAssignment, TopologyInfo diff --git a/src/dnet/ring/api/api_logging.py b/src/dnet/ring/api/api_logging.py new file mode 100644 index 00000000..48999dac --- /dev/null +++ b/src/dnet/ring/api/api_logging.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import os +import logging +from logging.handlers import RotatingFileHandler +from pathlib import Path + +_CONFIGURED_FLAG = "_dnet_api_logger_configured" + +def get_api_logger() -> logging.Logger: + log = logging.getLogger("dnet.api") + if getattr(log, _CONFIGURED_FLAG, False): + return log + + # Level from env, fallback INFO + level_name = (os.getenv("DNET_API_LOG", "INFO") or "INFO").strip().upper() + level = getattr(logging, level_name, logging.INFO) + #log.setLevel(level) + log.setLevel(logging.DEBUG) + + # Do not bubble to root (console) + log.propagate = False + + # Ensure logs directory + try: + Path("logs").mkdir(parents=True, exist_ok=True) + except Exception: + pass + + # Attach a rotating file handler + try: + fh = RotatingFileHandler("logs/api.log", maxBytes=10000000, backupCount=5) + fmt = logging.Formatter( + "%(asctime)s %(levelname)s [%(threadName)s] %(name)s: %(message)s" + ) + fh.setFormatter(fmt) + log.addHandler(fh) + except Exception: + # As a last resort, attach a NullHandler to avoid 'No handler' warnings + log.addHandler(logging.NullHandler()) + + setattr(log, _CONFIGURED_FLAG, True) + return log + diff --git a/src/dnet/ring/api/models.py b/src/dnet/ring/api/models.py index 7199142a..c9f4b45b 100644 --- a/src/dnet/ring/api/models.py +++ b/src/dnet/ring/api/models.py @@ -8,6 +8,7 @@ from ..common import LayerAssignment +from dnet.perf.trace import _Frame class RoleMapping(BaseModel): """Role mapping for chat formatting.""" @@ -403,3 +404,29 @@ class UnloadModelResponse(BaseModel): message: Optional[str] = Field( default=None, description="Overall status or error message" ) + +# Tracer ingest + +class TraceEvent(BaseModel): + type: str = Field(..., description="Event type/phase") + name: str = Field(..., description="Span/mark name") + ts: float = Field(..., description="Timestamp in microseconds") + args: Dict[str, Any] = Field(default_factory=dict) + req_id: Optional[str] = None + pid: Optional[int] = None + tid: Optional[int] = None + +class TraceIngestBatch(BaseModel): + run_id: str = Field(..., description="Bench run identifier") + node_id: str = Field(..., description="Shard/service identity") + events: List[TraceEvent] = Field(default_factory=list) + #dropped: Optional[int] = Field(default=0, description="Events dropped on sender") + #max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch") + #last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run") + #schema_version: int = Field(default=1) + +class 
TraceIngestResponse(BaseModel): + ok: bool = True + accepted: int = 0 + batch_seq: Optional[int] = None + message: Optional[str] = None diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index a9793f9d..2c08c627 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -6,7 +6,7 @@ import json from dataclasses import asdict from io import StringIO -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Literal, Optional, Tuple, Union, Callable import httpx import mlx.core as mx @@ -39,8 +39,10 @@ add_ShardApiServiceServicer_to_server, ) +from .api_logging import get_api_logger from ...utils.logger import logger from ...utils.banner import print_startup_banner +from ...utils.latency import LatencyResults, calculate_median_latency_seconds from ...utils.model import ( ModelMetadata, get_model_config_json, @@ -72,15 +74,22 @@ UnloadModelResponse, ) from ..shard.models import ( + MeasureLatencyRequest, + MeasureLatencyResponse, ShardProfileRequest, ShardLoadModelRequest, ShardLoadModelResponse, ShardProfileResponse, + TraceIngestBatch, + TraceIngestResponse, + TraceConfigRequest, + TraceConfigResponse, ) from ..data_types import StopCondition from .servicer import ShardApiServicer from ..common import TopologyInfo, LayerAssignment +from dnet.perf import Tracer, TraceConfig async def arange(count: int): """Async range generator.""" @@ -99,6 +108,9 @@ async def azip(*async_iterables): break +logger = get_api_logger() + + class RingApiNode: """API node for distributed inference ring with dynamic topology.""" @@ -140,12 +152,28 @@ def __init__( except Exception: pass + self._trace_ingest_cb: Optional[Callable[[Dict[str, Any]], None]] = None # installed later via set_trace_ingest_callback + cfg = TraceConfig( + file="./trace.json", + streaming=False, + include_prefixes=("src/dnet/",), + include_c_calls=False, + budget=10000, + enabled=True, + record_pid_tid=True, + aggregate=False, + aggregate_url=None, + ) + self.tracer = Tracer(cfg) + self.tracer.start() + logger.info( "API node initialized on HTTP port %s, gRPC port %s", self.http_port, self.grpc_port, ) + async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: """Start the API node. 
@@ -153,6 +180,24 @@ async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: shutdown_trigger: Shutdown trigger function """ self.running = True + # Reduce third‑party library noise in this process (keeps REPL TTY clean) + try: + import logging as _logging + for name in ( + "grpc", + "grpc._cython", + "asyncio", + "hpack", + "h2", + "hypercorn", + "hypercorn.error", + "hypercorn.access", + ): + lg = _logging.getLogger(name) + lg.setLevel(_logging.CRITICAL) + lg.propagate = False + except Exception: + pass await self._start_grpc_server() await self._start_discovery() @@ -201,7 +246,7 @@ async def _start_http_server(self, shutdown_trigger: Any) -> None: config = Config.from_mapping( bind=f"0.0.0.0:{self.http_port}", - log_level="info", + log_level="error", # keep HTTP server quiet on console log_config=None, use_reloader=False, h2c=True, @@ -368,6 +413,72 @@ async def completions(req: CompletionRequestModel): # type: ignore ) return await self._handle_text_completion(req) + # Ingest trace buffers and forward to REPL + @self.app.post("/trace/ingest") + async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: ignore + try: + if self._trace_ingest_cb is not None: + logger.debug("Forwarding trace batch to REPL.") + self._trace_ingest_cb(batch.model_dump()) + + _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) } + self._trace_ingest_cb(_t_batch) # FIXME: Move this + self.tracer._events.clear() + + return TraceIngestResponse(ok=True, accepted=len(batch.events)) + + try: + run_dir = Path("logs/trace/ingest") / batch.run_id + logger.debug(f"Callback not available; dumping trace buffer to {run_dir}") + run_dir.mkdir(parents=True, exist_ok=True) + fpath = run_dir / f"{batch.node_id}.jsonl" + with fpath.open("a", encoding="utf-8") as f: + f.write(batch.model_dump_json() + "\n") + except Exception: + logger.warning(f"Unable to write trace ingest buffer for run {batch.run_id}") + return TraceIngestResponse( + ok=True, + accepted=len(batch.events), + message="no aggregator; appended" + ) + except Exception as e: + logger.warning(f"Unable to ingest trace buffer: {e}") + return TraceIngestResponse(ok=False, accepted=0, message=str(e)) + + async def _forward_trace_config(self, cfg: Any) -> bool: + logger.debug("Forwarding Trace config") + shards = self._get_shards_from_discovery() + this = self.discovery.get_own_properties() + api_endpoint = f"http://{this.local_ip}:{this.server_port}/trace/ingest" + payload = TraceConfigRequest( + file=cfg.file, + streaming=cfg.streaming, + include_prefixes=cfg.include_prefixes, + include_c_calls=cfg.include_c_calls, + budget=cfg.budget, + enabled=cfg.enabled, + node_id=None, + record_pid_tid=cfg.record_pid_tid, + aggregate=cfg.aggregate, + aggregate_url=api_endpoint, + agg_max_events=cfg.agg_max_events + ) + + async with httpx.AsyncClient(timeout=5.0) as client: + for name, props in shards.items(): + url = f"http://{props.local_ip}:{props.server_port}/trace" + logger.debug(f"Forwarding trace config to {url}") + payload.node_id = name + try: + res = await client.post(url, json=payload.model_dump()) + if res.status_code != 200: + logger.error(f"Failed to POST tracer config to {url}: {res.text}") + except Exception as e: + logger.error(f"Failed to POST tracer config: {e}") + return False + return True + + async def _handle_prepare_topology( self, req: PrepareTopologyRequest ) -> TopologyInfo: @@ -383,6 +494,10 @@ async def _handle_prepare_topology( # Load only config.json to avoid weight downloads 
on API node cfg = get_model_config_json(req.model) + + if str(cfg.get("model_type", "")).strip().lower() == "gpt_oss" and req.kv_bits != "fp16": + raise ValueError("GPT-OSS models only support kv_bits='fp16'") + num_layers_raw = cfg.get("num_hidden_layers") if not isinstance(num_layers_raw, int): raise ValueError( @@ -412,12 +527,14 @@ logger.info("Discovered %d shards: %s", len(shards), list(shards.keys())) - shard_profiles, thunderbolt_conns = await self._collect_shard_profiles( + thunderbolt_conns = discover_all_thunderbolt_connections(shards) + shard_profiles = await self._collect_shard_profiles( shards, req.model, embedding_size, req.max_batch_exp, batch_sizes, + thunderbolt_conns ) optimized_device_name_order = optimize_device_ordering( shard_profiles, thunderbolt_conns ) @@ -970,7 +1087,8 @@ async def _collect_shard_profiles( embedding_size: int, max_batch_exp: int, batch_sizes: List[int], - ) -> Tuple[Dict[str, DeviceProfile], Dict[str, Dict[str, ThunderboltConnection]]]: + thunderbolt_conns: dict[str, dict[str, ThunderboltConnection]] = {}, + ) -> Dict[str, DeviceProfile]: """Collect profile data from all shards. Args: @@ -978,6 +1096,8 @@ async def _collect_shard_profiles( repo_id: Model repository ID embedding_size: Model embedding size max_batch_exp: Maximum batch size exponent + batch_sizes: List of batch sizes to profile + thunderbolt_conns: Pre-discovered thunderbolt connections per shard Returns: - Tuple of (collected shard profiles, thunderbolt connections) + Collected shard profiles keyed by shard name @@ -986,8 +1106,6 @@ base_size = embedding_size * 4 # 4*e due to paper payload_sizes = [base_size * batch_size for batch_size in batch_sizes] - this_device = await self.discovery.async_get_own_properties() - logger.info( "Model %s: embedding_size=%d, payload_sizes=%s", repo_id, @@ -995,69 +1113,174 @@ payload_sizes, ) - # Find Thunderbolt connections - all_thunderbolts = discover_all_thunderbolt_connections(shards) - - # Call each shard's /profile endpoint - # FIXME: do this in parallel - shard_profiles: Dict[str, DeviceProfile] = {} async with httpx.AsyncClient() as client: + # health-check all shards in parallel + logger.info("Starting health checks for all shards...") + health_tasks: list[Awaitable[httpx.Response]] = [] + shard_list: list[tuple[str, DnetDeviceProperties]] = [] for shard_name, shard_props in shards.items(): if shard_props.is_manager: - logger.warning( - "Skipping manager node %s in profile collection", shard_name - ) + logger.warning("Skipping manager node %s in profile collection", shard_name) continue - server_port, server_ip = shard_props.server_port, shard_props.local_ip + shard_list.append((shard_name, shard_props)) + health_tasks.append(client.get(f"http://{shard_props.local_ip}:{shard_props.server_port}/health", timeout=5.0)) - try: - shard_url = f"http://{server_ip}:{server_port}/profile" - logger.info( - "Calling /profile endpoint for shard %s at %s", - shard_name, - shard_url, - ) + health_results = await asyncio.gather(*health_tasks, return_exceptions=True) - response = await client.post( - shard_url, - json=ShardProfileRequest( - repo_id=repo_id, - thunderbolts=all_thunderbolts.get(shard_name, {}), - payload_sizes=payload_sizes, - max_batch_exp=max_batch_exp, - devices=shards, - ).model_dump(), - timeout=1000.0, - ) + # filter healthy shards + healthy_shards: list[tuple[str, DnetDeviceProperties]] = [] + for (shard_name, shard_props), health_result in zip(shard_list, 
health_results): + if isinstance(health_result, Exception): + logger.warning("Health check failed for %s: %s", shard_name, health_result) + continue + elif isinstance(health_result, httpx.Response): + if health_result.status_code == 200: + healthy_shards.append((shard_name, shard_props)) + logger.info("Health check passed for %s", shard_name) + else: + logger.warning("Health check failed for %s: status %s", shard_name, health_result.status_code) + else: + pass + + + logger.info("Healthy shards: %d/%d", len(healthy_shards), len(shard_list)) + if not healthy_shards: + logger.error("No healthy shards found!") + return {} + + # measure latencies on all healthy shards in parallel + logger.info("Measuring latencies for all healthy shards...") + latency_tasks: list[Awaitable[httpx.Response]] = [] + for shard_name, shard_props in healthy_shards: + server_port, server_ip = shard_props.server_port, shard_props.local_ip + latency_url = f"http://{server_ip}:{server_port}/measure_latency" + latency_request = MeasureLatencyRequest( + devices=shards, + thunderbolts=thunderbolt_conns.get(shard_name, {}), + payload_sizes=payload_sizes, + ) + latency_tasks.append( + client.post(latency_url, json=latency_request.model_dump(), timeout=1000.0) + ) + latency_results = await asyncio.gather(*latency_tasks, return_exceptions=True) - if response.status_code == 200: - profile_data = ShardProfileResponse.model_validate( - response.json() + # store latency data for each shard + shard_latencies: dict[str, LatencyResults] = {} + final_healthy_shards: list[tuple[str, DnetDeviceProperties]] = [] + for (shard_name, shard_props), latency_result in zip(healthy_shards, latency_results): + if isinstance(latency_result, Exception): + logger.warning( + "Latency measurement failed for %s: %s", shard_name, latency_result + ) + continue + elif isinstance(latency_result, httpx.Response): + if latency_result.status_code == 200: + latency_data = MeasureLatencyResponse.model_validate(latency_result.json()) + shard_latencies[shard_name] = latency_data.latency + final_healthy_shards.append((shard_name, shard_props)) + logger.info("Latency measurement succeeded for %s", shard_name) + else: + logger.warning( + "Latency measurement failed for %s: status %s", + shard_name, + latency_result.status_code, ) - profile = load_device_profile_from_dict(profile_data.profile) + else: + pass # unexpected case + + logger.info("Latencies collected from %d shards", len(shard_latencies)) + + if not final_healthy_shards: + logger.error("No shards with successful latency measurements!") + return {} + + # group healthy shards by local_ip (same device), so that we can profile per-device + shards_by_device: Dict[str, List[Tuple[str, DnetDeviceProperties]]] = {} + for shard_name, shard_props in final_healthy_shards: + local_ip = shard_props.local_ip + if local_ip not in shards_by_device: + shards_by_device[local_ip] = [] + shards_by_device[local_ip].append((shard_name, shard_props)) + logger.info("Grouped %d shards into %d devices", len(final_healthy_shards), len(shards_by_device)) + + # profile devices (parallel per device, sequential per shard within device) + async def profile_device_shards( + device_shards: List[Tuple[str, DnetDeviceProperties]] + ) -> List[Tuple[str, DeviceProfile]]: + profiles = [] + + for shard_name, shard_props in device_shards: + try: + profile_url = f"http://{shard_props.local_ip}:{shard_props.server_port}/profile" + logger.info( - "Successfully collected profile from %s", shard_name + "Calling /profile endpoint for shard %s at %s", + shard_name, 
profile_url, ) - # Mark head device (same local IP as API) - if shard_props.local_ip == this_device.local_ip: - profile.is_head = True - - # FIXME: DeviceProfileInfo to DeviceProfile should be better - shard_profiles[shard_name] = profile - else: - logger.error( - "Failed to get profile from %s: %s", - shard_name, - response.status_code, + response = await client.post( + profile_url, + json=ShardProfileRequest( + repo_id=repo_id, + thunderbolts=thunderbolt_conns.get(shard_name, {}), + payload_sizes=payload_sizes, + max_batch_exp=max_batch_exp, + devices=shards, + ).model_dump(), + timeout=1000.0, ) - except Exception as e: - logger.exception("Error calling /profile for %s: %s", shard_name, e) + if response.status_code == 200: + profile_data = ShardProfileResponse.model_validate(response.json()) + profile = load_device_profile_from_dict(profile_data.profile) + profiles.append((shard_name, profile)) + logger.info("Successfully collected profile from %s", shard_name) + else: + logger.error( + "Failed to get profile from %s: %s", + shard_name, + response.status_code, + ) + + except Exception as e: + logger.exception("Error calling /profile for %s: %s", shard_name, e) + + return profiles + + # run profiling for all devices in parallel + device_tasks = [ + profile_device_shards(device_shards) + for device_shards in shards_by_device.values() + ] + device_results = await asyncio.gather(*device_tasks, return_exceptions=True) + + # merge latency data into device profiles + shard_profiles: Dict[str, DeviceProfile] = {} + for device_result in device_results: + if isinstance(device_result, Exception): + logger.error("Device profiling failed: %s", device_result) + continue + elif isinstance(device_result, list): + for shard_name, profile in device_result: + # set t_comm using median latency + if shard_name in shard_latencies: + median_latency = calculate_median_latency_seconds(shard_latencies[shard_name]) + if median_latency is not None: + profile.t_comm = float(median_latency) + logger.info( + f"Set t_comm for {shard_name} to median latency: {profile.t_comm:.6f}s" + ) + else: + logger.warning( + f"No valid latency measurements for {shard_name}, keeping default t_comm" + ) + + shard_profiles[shard_name] = profile logger.info("Collected profiles from %d shards", len(shard_profiles)) - return shard_profiles, all_thunderbolts + return shard_profiles # FIXME: move this elsewhere async def _run_solver( @@ -1084,6 +1307,10 @@ async def _run_solver( if not sorted_shard_profiles: raise ValueError("No valid shard profiles found") + # mark the first device as head, others as non-head + for i, profile in enumerate(sorted_shard_profiles): + profile.is_head = (i == 0) + logger.info("Running solver with %d shard profiles", len(sorted_shard_profiles)) solution = halda_solve( @@ -1182,6 +1409,16 @@ async def _handle_completion( t_start = time.perf_counter() t_first_token = None nonce = f"chatcmpl-{uuid.uuid4()}" + + self.tracer.mark("request.start", { + "tokenizer": "", + "model": req.model, + "temperature": req.temperature, + "prompt_tokens": prompt.size, + "req_id": nonce, + "t0": time.time_ns(), + }) + detokenizer = self.tokenizer.detokenizer # type: ignore detokenizer.reset() tokens: List[int] = [] @@ -1200,6 +1437,8 @@ ), # type: ignore arange(req.max_tokens or 0), ): + self.tracer.mark("request.round", {"req_id": nonce, "t0": time.time_ns()}) + if profile_enabled and t_first_token is None: t_first_token = time.perf_counter() detokenizer.add_token(token) @@ -1238,6 +1477,12 @@ 
async def _handle_completion( else detokenizer.text[: -len(stop_sequence_suffix)] ) + self.tracer.mark("request.end", { + "generated_tokens": len(tokens), + "req_id": nonce, + "t0": time.time_ns(), + }) + # Build optional metrics metrics = None if profile_enabled: @@ -1581,3 +1826,8 @@ async def shutdown(self) -> None: logger.warning("Discovery service was not running") logger.info("API server shutdown complete") + + # REPL helper to install a trace ingestion callback + def set_trace_ingest_callback(self, cb: Optional[Callable[[Dict[str, Any]], None]]) -> None: + logger.debug("Registered tracer ingest callback.") + self._trace_ingest_cb = cb diff --git a/src/dnet/ring/api/servicer.py b/src/dnet/ring/api/servicer.py index 31b53d35..a1024822 100644 --- a/src/dnet/ring/api/servicer.py +++ b/src/dnet/ring/api/servicer.py @@ -6,7 +6,8 @@ from ...protos import shard_api_comm_pb2 as pb2 from ...protos import shard_api_comm_pb2_grpc as pb2_grpc -from ...utils.logger import logger +from .api_logging import get_api_logger +logger = get_api_logger() if TYPE_CHECKING: pass diff --git a/src/dnet/ring/api/utils.py b/src/dnet/ring/api/utils.py index ad0fee36..002a3ea9 100644 --- a/src/dnet/ring/api/utils.py +++ b/src/dnet/ring/api/utils.py @@ -16,7 +16,8 @@ from .models import ChatParams - +from .api_logging import get_api_logger +logger = get_api_logger() def create_generate_step_for_ring_with_grpc( stub: DnetRingServiceStub, @@ -263,22 +264,10 @@ thunderbolt_conns: Thunderbolt connections mapping (device -> {neighbor -> connection_info}) Returns: - Optimized list of device names with head devices first and Thunderbolt neighbors adjacent + Optimized list of device names """ device_names = list(shard_profiles.keys()) - # Find all head devices (multiple shards can run on same machine as API) - head_devices = [] - for device_name, profile_data in shard_profiles.items(): - if profile_data.is_head: - head_devices.append(device_name) - - if not head_devices: - logger.warning("No head device found in profiles, using first device") - head_devices = [device_names[0]] if device_names else [] - else: - logger.info("Found %d head device(s): %s", len(head_devices), head_devices) - # FIXME: shards on the same machine should be adjacent too! 
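+    # Illustrative sketch (hypothetical devices): with thunderbolt_conns =
+    # {"dev3": {"dev1": conn}, "dev1": {"dev3": conn}}, the graph built below links
+    # dev1 <-> dev3, so the greedy ordering keeps them adjacent in the returned
+    # order, e.g. ["dev3", "dev1", "dev2", ...]; devices without TB links are
+    # appended afterwards.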
# Build adjacency graph of Thunderbolt connections @@ -292,8 +281,8 @@ def optimize_device_ordering( tb_graph[neighbor_name].add(device_name) # Greedy ordering: Start with all head devices, then pick neighbors with most TB connections - ordered = head_devices.copy() - remaining = set(device_names) - set(head_devices) + ordered: list[str] = [] + remaining: set[str] = set(device_names) while remaining: best_candidate = None diff --git a/src/dnet/ring/api/utils_test.py b/src/dnet/ring/api/utils_test.py index b250e29a..1c54665e 100644 --- a/src/dnet/ring/api/utils_test.py +++ b/src/dnet/ring/api/utils_test.py @@ -54,20 +54,15 @@ def test_single_round_postprocess_complex(): def test_optimize_device_ordering(): - from pydantic import BaseModel - - # fake type for the sake of testing - class _DeviceProfileIsHead(BaseModel): - is_head: bool device_profiles: dict[str, DeviceProfile] = { - "dev1": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev2": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev3": _DeviceProfileIsHead(is_head=True), # type: ignore - "dev4": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev5": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev6": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev7": _DeviceProfileIsHead(is_head=False), # type: ignore + "dev1": {}, # type: ignore + "dev2": {}, # type: ignore + "dev3": {}, # type: ignore + "dev4": {}, # type: ignore + "dev5": {}, # type: ignore + "dev6": {}, # type: ignore + "dev7": {}, # type: ignore } thunderbolts: dict[str, dict[str, ThunderboltConnection]] = { "dev3": {"dev1": 1}, # type: ignore @@ -78,10 +73,6 @@ class _DeviceProfileIsHead(BaseModel): optimized_order = optimize_device_ordering(device_profiles, thunderbolts) - # the ordering is not deterministic, but the connectino should be as follows: - # head must be the first - assert optimized_order[0] == "dev3" - # dev1 and dev3 must be next to each other (due to thunderbolt) dev1_index = optimized_order.index("dev1") dev3_index = optimized_order.index("dev3") diff --git a/src/dnet/ring/data_types.py b/src/dnet/ring/data_types.py index 5da3e65a..ee6f81e6 100644 --- a/src/dnet/ring/data_types.py +++ b/src/dnet/ring/data_types.py @@ -25,7 +25,13 @@ class ActivationMessage: recv_perf_t: float = 0.0 enq_perf_t: float = 0.0 # TX queue enqueue time (perf_counter seconds) + rx_enq_perf_t: float = 0.0 tx_enq_perf_t: float = 0.0 + tx_enq_prev_t: float = 0.0 + rx_ingress_t: float = 0.0 + rx_inflight_t: float = 0.0 + ex_enq_t: float = 0.0 + tx_enq_t: float = 0.0 # Final token path (end-shard sampling) is_final: bool = False token_id: int = -1 diff --git a/src/dnet/ring/model/__init__.py b/src/dnet/ring/model/__init__.py index 4f432a7e..93d31dd8 100644 --- a/src/dnet/ring/model/__init__.py +++ b/src/dnet/ring/model/__init__.py @@ -8,6 +8,11 @@ from .llama import LlamaRingModel from .gpt_oss import GptOssRingModel from .qwen3 import Qwen3RingModel +from .llama3 import Llama3RingModel +#from .llama4 import Llama4RingModel +#from .gpt_oss import GptOssRingModel +#from .glm import GLMRingModel +#from .glm4 import GLM4RingModel def get_ring_model( diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py new file mode 100644 index 00000000..0ee26280 --- /dev/null +++ b/src/dnet/ring/model/llama3.py @@ -0,0 +1,216 @@ +from typing import Any, Dict, List, Optional, Tuple + +import mlx.nn as nn +import mlx.core as mx +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.llama import ModelArgs, TransformerBlock + 
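+# Example of the absolute-to-local layer mapping built in __init__ below
+# (values hypothetical): a shard assigned layers [10, 11, 12] instantiates three
+# TransformerBlocks and sets abs_to_local = {10: 0, 11: 1, 12: 2}, so
+# apply_single_layer(11, x, cache) dispatches to local block 1.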
+from .base import BaseRingModel + +import logging +logger = logging.getLogger(__name__) + + +class Llama3RingModel(BaseRingModel): + model_type = "llama" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = [], + is_api_layer: bool = False, + shard_config: Optional[Any] = None, + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError("API service doesn't handle layers") + + self.config = ModelArgs.from_dict(model_config) + self.config.quantization = model_config["quantization"] # ModelArgs.from_dict drops unknown keys; reattach the quantization info + self.is_api_layer = is_api_layer + + self._converted_to_quantized = False + self.runtime_cache: Optional[List[Any]] = None + + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.norm = nn.RMSNorm(self.config.hidden_size, self.config.rms_norm_eps) + + if not self.config.tie_word_embeddings: + self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + self.layers: List[nn.Module] = [] + self.abs_to_local: Dict[int, int] = {} + + for i, l in enumerate(sorted(assigned_layers or [])): + self.layers.append(TransformerBlock(self.config)) + self.abs_to_local[l] = i + + logger.debug(f"Created {len(self.layers)} Transformer layers") + #logger.debug(f"abs_to_local mapping: {self.abs_to_local}") + + @property + def decoding_layers(self): + return self.layers + + @property + def head_dim(self) -> int: + return self.config.head_dim + + @property + def n_kv_heads(self) -> int: + return self.config.num_key_value_heads + + @property + def num_layers(self) -> int: + return len(self.layers) + + def set_runtime_cache(self, cache: Optional[List[Any]]) -> None: + self.runtime_cache = cache # matches the attribute initialized in __init__ + + @staticmethod + def class_predicate(p, m): + return hasattr(m, "to_quantized") + + def embed(self, x: mx.array): + return self.embed_tokens(x) + + def normalize(self, x: mx.array): + return self.norm(x) + + # FIXME: Weird MLX bug, lm_head weights are transposed internally for no reason + def lm_project(self, x: mx.array): + if self.config.tie_word_embeddings: + return self.embed_tokens.as_linear(x) + try: + return self.lm_head(x) + except Exception: + return mx.matmul(x, self.lm_head.weight) + + def quantize_layers(self): + self.quantization = None + logger.debug(f"{self.config}") + if hasattr(self.config, "quantization"): + self.quantization = getattr(self.config, "quantization") + elif hasattr(self.config, "quantization_config"): + self.quantization = getattr(self.config, "quantization_config") + + logger.debug(f"QUANTIZING {self.quantization}") + if self.quantization is not None: + bits = int(self.quantization.get("bits", 4)) + group = int(self.quantization.get("group_size", 64)) + try: + from mlx.nn.layers.quantized import QuantizedEmbedding + self.embed_tokens = QuantizedEmbedding(self.config.vocab_size, + self.config.hidden_size, + group_size=group, bits=bits) + + logger.debug(f"API Service initialized to QuantizedEmbedding:" + f"{self.config.vocab_size}, hidden={self.config.hidden_size}" + f"group_size={group}, bits={bits}") + except Exception as e: + logger.warning(f"Unable to initialize QuantizedEmbedding: {e}") + + try: + nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate) + logger.debug(f"Quantized the model: bits={bits}, group_size={group}") + self._converted_to_quantized = True + except Exception: + self._converted_to_quantized = False + + def forward( + self, + x: mx.array, + cache: Optional[List[Any]] = None + ): + mask = create_attention_mask(x, cache) 
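+        # Note (based on mlx_lm behavior): create_attention_mask returns None for
+        # single-token decode steps, so a mask is only materialized during prefill;
+        # apply_single_layer below caches that prefill mask per sequence length to
+        # avoid rebuilding it for every layer.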
if cache is None: + cache = [None] * len(self.layers) + + for i, l in enumerate(self.layers): + x = l(x, mask, cache[i] if i < len(cache) else None) + + return x + + # TODO: Original implementation is sliding window. Bench to see if it's faster or just do sparse + def apply_single_layer( + self, + layer_idx: int, + x: mx.array, + cache: Optional[List[Any]] = None + ): + if layer_idx not in self.abs_to_local: + raise RuntimeError(f"Attempted execution of foreign layer {layer_idx}") + + mask = None + sqlen = int(x.shape[1]) + if sqlen > 1: + cached = getattr(self, "_cached_mask_len", None) + cached_mask = getattr(self, "_cached_mask", None) + if cached is None or cached != sqlen or cached_mask is None: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_len = sqlen + else: + mask = cached_mask + + local_idx = self.abs_to_local[layer_idx] + logger.debug(f"apply_single_layer: layer:{layer_idx}, local_idx:{local_idx}, input_shape:{x.shape}") + + layer = self.layers[local_idx] + ret = layer(x, mask, cache[local_idx] if cache and local_idx < len(cache) else None) + return ret + + def load_weights(self, weights, strict=False): + weight_keys = [k for k, _ in weights] + has_scales = any(".scales" in k for k in weight_keys) + has_biases = any(".biases" in k for k in weight_keys) + + if has_scales and has_biases: + if not self._converted_to_quantized: + self.quantize_layers() + + shard_weights = {} + for k, v in weights: + if k.startswith("model.layers.") or k.startswith("layers."): + p = k.split(".") + idx_pos = 2 if p[0] == "model" else 1 + try: + idx = int(p[idx_pos]) + except Exception as e: + logger.warning(f"Unable to read weight positions: {e}") + continue + if idx not in self.abs_to_local: + continue + local_idx = self.abs_to_local[idx] + p[idx_pos] = str(local_idx) + if p[0] == "model": + p = p[1:] + new_key = ".".join(p) + logger.debug(f"Mapping weight {k} -> {new_key}") + shard_weights[new_key] = v + + elif k.startswith("lm_head"): + shard_weights[k] = v + elif (k.startswith("embed_tokens") or k.startswith("norm")): + shard_weights[k] = v + + if shard_weights: + try: + super().load_weights(list(shard_weights.items()), strict=strict) + logger.debug(f"Loaded {len(shard_weights.keys())} weights into model") + except Exception as e: + logger.error(f"Failed to load weights: {e}") + logger.error(f"Weight keys: {list(shard_weights.keys())}") + raise + + def unload_layers(self, layers: List[int]): + for l in layers: + local = self.abs_to_local[l] + for name, mod in self.layers[local].named_modules(): + if name in ['self_attn', 'mlp']: + for pname in mod.parameters(): + setattr(mod, pname, None) + logger.debug(f"Unloaded {pname}") + elif name in ['input_layernorm', 'post_attention_layernorm']: + mod.weight = None + logger.debug(f"Unloaded {name}") diff --git a/src/dnet/ring/shard/__init__.py b/src/dnet/ring/shard/__init__.py index ddc69862..00b76b81 100644 --- a/src/dnet/ring/shard/__init__.py +++ b/src/dnet/ring/shard/__init__.py @@ -2,5 +2,6 @@ from .node import RingShardNode from .servicer import ShardServicer +from .config import ShardConfig -__all__ = ["RingShardNode", "ShardServicer"] +__all__ = ["RingShardNode", "ShardServicer", "ShardConfig"] diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 9ca00406..8d55fe7e 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -172,18 +172,22 @@ async def _send_worker(self): ): try: activation_msg = await self.activation_computed_queue.get() - if activation_msg.tx_enq_perf_t and self._profile: - q_wait_ms = ( - 
time.perf_counter() - activation_msg.tx_enq_perf_t - ) * 1000.0 - logger.info( - "[PROFILE][QUEUE-TX] node=%s nonce=%s wait_ms=%.3f size=%s", - self.node_id, - activation_msg.nonce, - q_wait_ms, - self.activation_computed_queue.qsize(), - ) - await self._send_activation(activation_msg) + with self.tracer.frame("network", "tx") as f: + if activation_msg.tx_enq_perf_t and self._profile: + f.set("inwait", (time.perf_counter() - activation_msg.tx_enq_t)*1000) + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + q_wait_ms = ( + time.perf_counter() - activation_msg.tx_enq_perf_t + ) * 1000.0 + logger.info( + "[PROFILE][QUEUE-TX] node=%s nonce=%s wait_ms=%.3f size=%s", + self.node_id, + activation_msg.nonce, + q_wait_ms, + self.activation_computed_queue.qsize(), + ) + await self._send_activation(activation_msg) except asyncio.CancelledError: break except Exception as e: @@ -226,304 +230,340 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) return try: + logger.debug(f"Sending activation") if activation_msg.is_final: - try: - if self._mode == "offload" and self.window_size > 0: - first_window = self._assigned_sorted[: self.window_size] - if first_window: - loop = asyncio.get_running_loop() - fut = loop.run_in_executor( - self.executor, - self._prepare_window_blocking, - list(first_window), - ) - self._prepared_by_nonce[activation_msg.nonce] = ( - list(first_window), - fut, - ) - except Exception: - pass - cb = activation_msg.callback_url or "" - parsed = urlparse(cb) if cb else None - t_rpc = time.perf_counter() - if parsed and parsed.scheme == "grpc": - addr = parsed.netloc - if not addr: - logger.error("Invalid gRPC callback URL for token: %s", cb) - return - # Ensure API channel/stub - if (self.api_channel is None) or (addr != self.api_address): - # close existing channel if any - try: - if self.api_channel is not None: - await self.api_channel.close() - except Exception: - pass - - self.api_address = addr - self.api_channel = aio_grpc.insecure_channel( - addr, options=GRPC_AIO_OPTIONS - ) - self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( - self.api_channel - ) + with self.tracer.frame("network", "send_activation.final") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) try: - req = shard_api_comm_pb2.TokenRequest( - nonce=activation_msg.nonce, - token_id=int(getattr(activation_msg, "token_id", -1)), - timestamp=utc_epoch_now(), + if self._mode == "offload" and self.window_size > 0: + first_window = self._assigned_sorted[: self.window_size] + if first_window: + loop = asyncio.get_running_loop() + fut = loop.run_in_executor( + self.executor, + self._prepare_window_blocking, + list(first_window), + ) + self._prepared_by_nonce[activation_msg.nonce] = ( + list(first_window), + fut, + ) + except Exception: + pass + + cb = activation_msg.callback_url or "" + parsed = urlparse(cb) if cb else None + t_rpc = time.perf_counter() + if parsed and parsed.scheme == "grpc": + addr = parsed.netloc + if not addr: + logger.error("Invalid gRPC callback URL for token: %s", cb) + return + + if (self.api_channel is None) or (addr != self.api_address): # Ensure API channel/stub + try: # close existing channel if any + if self.api_channel is not None: + await self.api_channel.close() + except Exception: + pass + + self.api_address = addr + self.api_channel = aio_grpc.insecure_channel( addr, options=GRPC_AIO_OPTIONS) + self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( self.api_channel) + f.event("reset_api") + + 
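# The final-token path above caches one gRPC channel per callback address and
# rebuilds it only when the address changes. A minimal, self-contained sketch of
# that pattern, assuming only grpc.aio; `ApiChannelCache` and the `make_stub`
# factory are illustrative names, not part of this diff.
from typing import Callable, Optional
from grpc import aio as grpc_aio

class ApiChannelCache:
    def __init__(self, make_stub: Callable):
        self._make_stub = make_stub          # stub factory, e.g. ShardApiServiceStub
        self._addr: Optional[str] = None
        self._channel = None
        self.stub = None

    async def ensure(self, addr: str):
        # Reuse the channel while the callback address is stable
        if self._channel is not None and addr == self._addr:
            return self.stub
        if self._channel is not None:
            try:
                await self._channel.close()  # drop the stale channel
            except Exception:
                pass
        self._addr = addr
        self._channel = grpc_aio.insecure_channel(addr)
        self.stub = self._make_stub(self._channel)
        return self.stub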
with self.tracer.frame("network", "token_request") as fr: + fr.set("req_id", activation_msg.nonce) + fr.set("node", self._instance_name) + try: + req = shard_api_comm_pb2.TokenRequest( + nonce=activation_msg.nonce, + token_id=int(getattr(activation_msg, "token_id", -1)), + timestamp=utc_epoch_now(), + tx_enq_prev_t=time.perf_counter(), + ) + resp = await self.api_stub.SendToken(req) # type: ignore + rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 + if not resp.success: + logger.error( + "API SendToken failed for %s: %s", + activation_msg.nonce, + resp.message, + ) + if self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) + except Exception as e: + logger.exception("Error sending token via gRPC: %s", e) + else: + logger.error(activation_msg) + logger.error( + "No valid gRPC callback for token; cannot deliver nonce=%s", + activation_msg.nonce, ) - resp = await self.api_stub.SendToken(req) # type: ignore - rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 - if not resp.success: - logger.error( - "API SendToken failed for %s: %s", - activation_msg.nonce, - resp.message, - ) - if self._profile: - logger.info( - "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", - self.node_id, - activation_msg.nonce, - int(getattr(activation_msg, "token_id", -1)), - rpc_ms, - ) - except Exception as e: - logger.exception("Error sending token via gRPC: %s", e) - else: - logger.error( - "No valid gRPC callback for token; cannot deliver nonce=%s", - activation_msg.nonce, - ) - return + return used_pool = False # FIXME: shaped var is a bit weird (is it np_array or mlx_array), @andthattoo shall check shaped = activation_msg.tensor - if shaped is None: - output_buffer = self.output_pool.get_buffer(activation_msg.pool_id) - if output_buffer is None: - logger.error( - "Failed to get output buffer %s", activation_msg.pool_id + with self.tracer.frame("network.send_activations.default", "get_buffer") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + if shaped is None: + output_buffer = self.output_pool.get_buffer(activation_msg.pool_id) + if output_buffer is None: + logger.error("Failed to get output buffer %s", activation_msg.pool_id) + return + + if self._profile: + logger.info( + "[PROFILE][SER-START] node=%s nonce=%s", + self.node_id, + activation_msg.nonce, ) - return - data_size = int(np.prod(activation_msg.shape)) - shaped = output_buffer[:data_size].reshape(activation_msg.shape) - used_pool = True - - if self._profile: - logger.info( - "[PROFILE][SER-START] node=%s nonce=%s", - self.node_id, - activation_msg.nonce, - ) - t_ser = time.perf_counter() - t_cast = t_ser - _len_bytes = int(getattr(shaped, "nbytes", 0)) - _do_compress = bool( - self._compress and _len_bytes >= self._compress_min_bytes - ) - if _do_compress: - # Skip compression for decode. 
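# Before serialization the TX path casts activations to the configured wire dtype,
# falling back to float16 when the name is unknown. A sketch of that cast for the
# numpy case; `DTYPE_MAP` is an illustrative stand-in for the dtype_map used above
# (bfloat16 has no numpy equivalent, so widening it to float32 is an assumption).
import numpy as np

DTYPE_MAP = {"float16": np.float16, "float32": np.float32, "bfloat16": np.float32}

def cast_for_wire(arr: np.ndarray, wire_dtype_str: str) -> np.ndarray:
    try:
        wire_dtype = DTYPE_MAP[wire_dtype_str]
    except KeyError:
        wire_dtype = np.float16                    # reasonable default fallback
    if arr.dtype != wire_dtype:
        arr = arr.astype(wire_dtype, copy=False)   # no copy if already correct
    return arr

# Row-major bytes, as sent on the wire
payload = cast_for_wire(np.ones((4, 8), dtype=np.float32), "float16").tobytes(order="C")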
- _do_compress = False - try: - wire_np_dtype = dtype_map[self._wire_dtype_str] - except Exception: - wire_np_dtype = np.float16 # reasonable default fallback - - if isinstance(shaped, np.ndarray): - logger.warning("Activation tensor is a numpy array!!!") - if shaped.dtype != wire_np_dtype: - # FIXME: numpy vs mx array here - shaped = shaped.astype(wire_np_dtype, copy=False) - else: - # MLX array -> cast to desired wire dtype - if str(shaped.dtype) != self._wire_dtype_str: - shaped = shaped.astype(self._wire_mx_dtype) - activation_msg.dtype = self._wire_dtype_str - t_cast = time.perf_counter() - - if isinstance(shaped, np.ndarray): - data = shaped.tobytes(order="C") - else: - data = tensor_to_bytes(shaped) - - ser_ms = (time.perf_counter() - t_ser) * 1000.0 - cast_ms = (t_cast - t_ser) * 1000.0 - - nxt = activation_msg.layer_id + 1 - if (nxt < self.model_metadata.num_layers) and ( - nxt not in self._assigned_set - ): - if self.next_node_stub: - request = activation_msg.to_proto(data) - request.timestamp = utc_epoch_now() - if self._mode == "offload" and self.window_size > 0: - next_window = self._next_local_layers( - activation_msg.layer_id, self.window_size - ) - loop = asyncio.get_running_loop() - if next_window: - fut = loop.run_in_executor( - self.executor, - self._prepare_window_blocking, - list(next_window), - ) - self._prepared_by_nonce[activation_msg.nonce] = ( - list(next_window), - fut, + with self.tracer.frame("network.tx", "cast") as f: + t_ser = time.perf_counter() + t_cast = t_ser + _len_bytes = int(getattr(shaped, "nbytes", 0)) + _do_compress = bool( + self._compress and _len_bytes >= self._compress_min_bytes + ) + if _do_compress: + # Skip compression for decode. + _do_compress = False + try: + wire_np_dtype = dtype_map[self._wire_dtype_str] + except Exception: + wire_np_dtype = np.float16 # reasonable default fallback + + if isinstance(shaped, np.ndarray): + logger.warning("Activation tensor is a numpy array!!!") + if shaped.dtype != wire_np_dtype: + # FIXME: numpy vs mx array here + shaped = shaped.astype(wire_np_dtype, copy=False) + + else: # MLX array -> cast to desired wire dtype + if str(shaped.dtype) != self._wire_dtype_str: + shaped = shaped.astype(self._wire_mx_dtype) + + activation_msg.dtype = self._wire_dtype_str + t_cast = time.perf_counter() + + if isinstance(shaped, np.ndarray): # Cast to target dtype + if shaped.dtype != wire_np_dtype: + shaped = shaped.astype(wire_np_dtype, copy=False) + f.event("ndarray.cast") + data = shaped.tobytes(order="C") + else: + if str(shaped.dtype) != self._wire_dtype_str: # MLX array + shaped = shaped.astype(self._wire_mx_dtype) + f.event("mxarray.cast") + data = tensor_to_bytes(shaped) + + activation_msg.dtype = self._wire_dtype_str + + with self.tracer.frame("memory", "prepare.window") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + + nxt = activation_msg.layer_id + 1 + if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): + if self.next_node_stub: + request = activation_msg.to_proto(data) + request.timestamp = utc_epoch_now() + if self._mode == "offload" and self.window_size > 0: + next_window = self._next_local_layers( + activation_msg.layer_id, self.window_size ) - else: - first_window = self._assigned_sorted[: self.window_size] - if first_window: + loop = asyncio.get_running_loop() + if next_window: fut = loop.run_in_executor( self.executor, self._prepare_window_blocking, - list(first_window), + list(next_window), ) self._prepared_by_nonce[activation_msg.nonce] = 
( - list(first_window), + list(next_window), fut, ) - stream_used = False - ctx = await self._ensure_stream(activation_msg.nonce) - if ( - ctx - and ctx.open - and not ctx.disabled - and hasattr(dnet_ring_pb2, "ActivationFrame") - ): - try: - ctx.last_seq += 1 - frame = dnet_ring_pb2.ActivationFrame( - request=request, - seq=ctx.last_seq, - end_of_request=False, + else: + first_window = self._assigned_sorted[: self.window_size] + if first_window: + fut = loop.run_in_executor( + self.executor, + self._prepare_window_blocking, + list(first_window), + ) + self._prepared_by_nonce[activation_msg.nonce] = ( + list(first_window), + fut, + ) + stream_used = False + ctx = await self._ensure_stream(activation_msg.nonce) + if ( + ctx + and ctx.open + and not ctx.disabled + and hasattr(dnet_ring_pb2, "ActivationFrame") + ): + try: + ctx.last_seq += 1 + frame = dnet_ring_pb2.ActivationFrame( + request=request, + seq=ctx.last_seq, + end_of_request=False, + ) + await ctx.queue.put(frame) + ctx.last_activity_t = asyncio.get_running_loop().time() + stream_used = True + if self._profile: + logger.info( + "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", + activation_msg.nonce, + ctx.last_seq, + ctx.queue.qsize(), + ) + except Exception as e: + logger.warning( + "[STREAM] enqueue failed; fallback to unary: %s", e + ) + ctx.disabled = True + + request.tx_enq_prev_t = time.perf_counter() + + # Prefer streaming if enabled/available; fallback to unary + stream_used = False + ctx = await self._ensure_stream(activation_msg.nonce) + if (ctx and ctx.open and not ctx.disabled and hasattr(dnet_ring_pb2, "ActivationFrame")): + logger.debug(f"Sending activation with stream") + try: + ctx.last_seq += 1 + frame = dnet_ring_pb2.ActivationFrame( + request=request, + seq=ctx.last_seq, + end_of_request=False, + ) + await ctx.queue.put(frame) + ctx.last_activity_t = asyncio.get_running_loop().time() + stream_used = True + except Exception as e: + logger.warning("[STREAM] enqueue failed; fallback to unary: %s", e) + ctx.disabled = True + + if not stream_used: + # In fit mode, avoid long unary stalls: use short deadline and min retries + # Streaming should be the norm; unary is a quick safety valve only. + ring_timeout = 3.0 if self._mode == "fit" else 30.0 + ring_retries = ( + 1 + if self._mode == "fit" + else max(1, int(self._send_retries)) ) - await ctx.queue.put(frame) - ctx.last_activity_t = asyncio.get_running_loop().time() - stream_used = True + # Emit a clear fallback log with reason/context if self._profile: - logger.info( - "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", + if ctx is None: + reason = "no_stream_ctx" + elif not ctx.open: + reason = "stream_closed" + elif ctx.disabled: + reason = "stream_disabled" + else: + reason = "enqueue_failed" + logger.warning( + "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + self.node_id, activation_msg.nonce, - ctx.last_seq, - ctx.queue.qsize(), + reason, + self._mode, + ring_timeout, + ring_retries, ) - except Exception as e: - logger.warning( - "[STREAM] enqueue failed; fallback to unary: %s", e - ) - ctx.disabled = True - - if not stream_used: - # In fit mode, avoid long unary stalls: use short deadline and min retries - # Streaming should be the norm; unary is a quick safety valve only. 
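# Condensed sketch of the stream-first/unary-fallback decision above: enqueue onto
# the per-nonce stream when it is open, otherwise fall back to a unary RPC with a
# short deadline in fit mode. `StreamCtx` mirrors the fields checked above;
# `send_unary` is a hypothetical coroutine.
import asyncio
from dataclasses import dataclass, field

@dataclass
class StreamCtx:
    open: bool = True
    disabled: bool = False
    queue: asyncio.Queue = field(default_factory=asyncio.Queue)

async def send_frame(ctx: StreamCtx, frame, send_unary, mode: str = "fit") -> str:
    if ctx is not None and ctx.open and not ctx.disabled:
        try:
            await ctx.queue.put(frame)        # preferred streaming path
            return "stream"
        except Exception:
            ctx.disabled = True               # poison the stream, fall through
    timeout = 3.0 if mode == "fit" else 30.0  # unary is a quick safety valve
    await send_unary(frame, timeout=timeout)
    return "unary"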
- ring_timeout = 3.0 if self._mode == "fit" else 30.0 - ring_retries = ( - 1 - if self._mode == "fit" - else max(1, int(self._send_retries)) - ) - # Emit a clear fallback log with reason/context - if self._profile: - if ctx is None: - reason = "no_stream_ctx" - elif not ctx.open: - reason = "stream_closed" - elif ctx.disabled: - reason = "stream_disabled" + t0 = time.perf_counter() + last_exc: Optional[Exception] = None + for attempt in range(1, ring_retries + 1): + try: + # FIXME: use response here? + _ = await self.next_node_stub.SendActivation( + request, timeout=ring_timeout + ) # type: ignore + break + except grpc.aio.AioRpcError as e: # type: ignore + last_exc = e + code = e.code() + if code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.DEADLINE_EXCEEDED, + }: + logger.warning( + "SendActivation attempt %s/%s failed (%s); reconnecting...", + attempt, + ring_retries, + code.name, + ) + await self._reconnect_next_node() + await asyncio.sleep(min(0.25 * attempt, 1.0)) + continue + raise else: - reason = "enqueue_failed" - logger.warning( - "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", self.node_id, activation_msg.nonce, - reason, - self._mode, - ring_timeout, - ring_retries, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, ) - t0 = time.perf_counter() - last_exc: Optional[Exception] = None - for attempt in range(1, ring_retries + 1): - try: - # FIXME: use response here? - _ = await self.next_node_stub.SendActivation( - request, timeout=ring_timeout - ) # type: ignore - break - except grpc.aio.AioRpcError as e: # type: ignore - last_exc = e - code = e.code() - if code in { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.CANCELLED, - grpc.StatusCode.DEADLINE_EXCEEDED, - }: - logger.warning( - "SendActivation attempt %s/%s failed (%s); reconnecting...", - attempt, - ring_retries, - code.name, - ) - await self._reconnect_next_node() - await asyncio.sleep(min(0.25 * attempt, 1.0)) - continue - raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) + else: + logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") + + # Final layer not annotated with 'is_final' else: logger.error( - "Cannot forward activation - no next node configured; end shard should sample inline." + "Final activation reached send path unexpectedly; sampling should occur on end shard." ) - else: - logger.error( - "Final activation reached send path unexpectedly; sampling should occur on end shard." - ) + # Clear scheduling at request end + # Sequential offload: prefetch state is unused - # Clear scheduling at request end - # Sequential offload: prefetch state is unused - - # Optional: explicitly end the per-nonce stream on request completion - # Enable by setting RING_EXPLICIT_EOR=1 when you emit a true end-of-request signal. 
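# The unary retry loop above only retries transient gRPC codes, reconnects between
# attempts, and backs off linearly capped at one second. Extracted as a sketch;
# `rpc_call` and `reconnect` are hypothetical coroutines.
import asyncio
import grpc
from grpc import aio as grpc_aio

TRANSIENT = {
    grpc.StatusCode.UNAVAILABLE,
    grpc.StatusCode.CANCELLED,
    grpc.StatusCode.DEADLINE_EXCEEDED,
}

async def call_with_retries(rpc_call, reconnect, retries: int = 3, timeout: float = 3.0):
    last_exc = None
    for attempt in range(1, retries + 1):
        try:
            return await rpc_call(timeout=timeout)
        except grpc_aio.AioRpcError as e:
            last_exc = e
            if e.code() not in TRANSIENT:
                raise                        # non-transient errors surface immediately
            await reconnect()
            await asyncio.sleep(min(0.25 * attempt, 1.0))
    raise last_exc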
- try: - if self._explicit_eor: - if ( - hasattr(self, "_streams") - and activation_msg.nonce in self._streams - ): - await self._end_stream(activation_msg.nonce, eor=True) - except Exception: - pass - - # Release resources at end of send - try: - activation_msg.tensor = None - except Exception: - pass - if used_pool: + # Optional: explicitly end the per-nonce stream on request completion + # Enable by setting RING_EXPLICIT_EOR=1 when you emit a true end-of-request signal. + try: + if self._explicit_eor: + if ( + hasattr(self, "_streams") + and activation_msg.nonce in self._streams + ): + await self._end_stream(activation_msg.nonce, eor=True) + except Exception: + pass + + # Release resources at end of send try: - self.output_pool.release(activation_msg.pool_id) + activation_msg.tensor = None except Exception: pass + if used_pool: + try: + self.output_pool.release(activation_msg.pool_id) + except Exception: + pass except Exception as e: logger.exception("Error sending activation: %s", e) + + async def _connect_next_node(self) -> bool: """Connect to next node in ring. diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 2c5626cd..c92bee94 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -65,48 +65,61 @@ def _delta_swap_eviction( return len(evicted) def _process_activation(self, activation_msg: ActivationMessage): - if ( - not self.model + if (not self.model or not self.model_metadata or not self.weight_cache or not self.input_pool or not self.output_pool ): - logger.error( - "Node %s: Cannot process activation - model not loaded", self.node_id - ) + logger.error("Node %s: Cannot process activation - model not loaded", self.node_id) return + logger.error(f"PROCESS_ACTIVATION: {activation_msg.callback_url}") try: # per-nonce kvcache for concurrent requests - kv = self._get_or_make_kv(activation_msg.nonce) + with self.tracer.frame("compute.thread", "kvcache.init") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + kv = self._get_or_make_kv(activation_msg.nonce) # Get input activation from pool - input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) - if input_buffer is None: - logger.error("Failed to get input buffer %s", activation_msg.pool_id) - return + with self.tracer.frame("compute.thread", "activations.load") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) + if input_buffer is None: + logger.error("Failed to get input buffer %s", activation_msg.pool_id) + return # Prepare input activation - if activation_msg.dtype == "tokens": - # tokens were staged as int32 in the pool; embed locally on start shard - input_size = int(np.prod(activation_msg.shape)) - tok_view = input_buffer[:input_size].reshape(activation_msg.shape) - # Convert robustly to MLX int32 and embed (batch=1) - toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) - x = self.model.embed(toks[None]) - if x.dtype != self._wire_mx_dtype: - x = x.astype(self._wire_mx_dtype) - else: - # Prepare input activation using MLX view operations only - input_size = int(np.prod(activation_msg.shape)) - x = input_buffer[:input_size].reshape(activation_msg.shape) - # Ensure expected dtype without re-materializing when not needed - try: - if str(x.dtype) != activation_msg.dtype: - x = x.astype(mlx_dtype_map[activation_msg.dtype]) - except Exception: - pass + with self.tracer.frame("compute.thread", "activations.process") 
as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + + if activation_msg.dtype == "tokens": # embed locally on start shard + numel = int(np.prod(activation_msg.shape)) + tok_view = input_buffer[:numel].reshape(activation_msg.shape) + toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) + x = self.model.embed(toks[None]) + + self.tracer.mark("embedding", { # NOTE: Used to track start of request in perf stats + "nonce": activation_msg.nonce, + "prompt_tokens": toks.size, + }) + + if x.dtype != self._wire_mx_dtype: + x = x.astype(self._wire_mx_dtype) + + else: # Prepare input activation using MLX view operations only + f.set("activation_dtype", activation_msg.dtype) + numel = int(np.prod(activation_msg.shape)) + x = input_buffer[:numel].reshape(activation_msg.shape) + + try: # Ensure expected dtype + if str(x.dtype) != activation_msg.dtype: + x = x.astype(mlx_dtype_map[activation_msg.dtype]) + except Exception: + logger.warning(f"Unable to update activation dtype") # Compute windows until boundary (stay local as long as possible) current_layer = activation_msg.layer_id + 1 @@ -116,221 +129,192 @@ def _process_activation(self, activation_msg: ActivationMessage): processed = 0 did_early_swap = False - # Determine contiguous local window starting at current_layer - window_layers: List[int] = [] - _tmp_layer = current_layer - while processed < self.window_size and ( - _tmp_layer in self._assigned_set - ): - window_layers.append(_tmp_layer) - _tmp_layer += 1 - processed += 1 - - if self._mode == "offload" and window_layers: - prep = self._prepared_by_nonce.get(activation_msg.nonce) - if prep is not None: - layers, fut = prep - if layers == window_layers and fut is not None: - try: - fut.result(timeout=30) - except Exception: - pass + with self.tracer.frame("compute.thread", "weights.prepare") as fr: + fr.set("req_id", activation_msg.nonce) + fr.set("node", self._instance_name) - # In sliding_fit with a single resident window, proactively evict only the - # non-needed head from the current resident set before loading new weights. - # This prevents LRU from evicting the useful tail during materialization. - if ( - self._mode == "sliding_fit" - and int(self._resident_windows) <= 1 - and window_layers - ): - try: - resident = [] + # Determine contiguous local window starting at current_layer + window_layers: List[int] = [] + _tmp_layer = current_layer + while processed < self.window_size and (_tmp_layer in self._assigned_set): + window_layers.append(_tmp_layer) + _tmp_layer += 1 + processed += 1 + + if self._mode == "offload" and window_layers: + prep = self._prepared_by_nonce.get(activation_msg.nonce) + if prep is not None: + layers, fut = prep + if layers == window_layers and fut is not None: + try: + fut.result(timeout=30) + except Exception: + pass + + # In sliding_fit with a single resident window, proactively evict only the + # non-needed head from the current resident set before loading new weights. + # This prevents LRU from evicting the useful tail during materialization. 
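# The "embedding" mark above and the "lm_head" mark dropped at sampling time bound
# a request in the trace; pairing them by nonce yields end-to-end latency. A rough
# sketch under the assumption that marks surface as (name, ts_seconds, args) tuples.
def request_latencies(marks):
    starts, out = {}, {}
    for name, ts, args in marks:
        nonce = args.get("nonce")
        if name == "embedding":
            starts[nonce] = ts
        elif name == "lm_head" and nonce in starts:
            out[nonce] = (ts - starts.pop(nonce)) * 1000.0   # milliseconds
    return out

print(request_latencies([
    ("embedding", 10.000, {"nonce": "req-1", "prompt_tokens": 7}),
    ("lm_head", 10.250, {"nonce": "req-1"}),
]))  # {'req-1': 250.0}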
+ if ( + self._mode == "sliding_fit" + and int(self._resident_windows) <= 1 + and window_layers + ): try: - resident = self.weight_cache.get_resident_layers() # type: ignore[union-attr] - except Exception: resident = [] - evicted_cnt = self._delta_swap_eviction( - window_layers, resident, activation_msg, early=True - ) - if evicted_cnt > 0: - did_early_swap = True - except Exception: - pass + try: + resident = self.weight_cache.get_resident_layers() # type: ignore[union-attr] + except Exception: + resident = [] + evicted_cnt = self._delta_swap_eviction( + window_layers, resident, activation_msg, early=True + ) + if evicted_cnt > 0: + did_early_swap = True + except Exception: + pass - # Ensure weights for the window are resident and bind only if arrays changed - # if model fits and we've already bound these layers, skip the scan entirely. - fast_fit = ( - self._mode == "fit" - and len(self._assigned_sorted) <= self.window_size - ) - skip_scan = fast_fit and all( - (wl in self._bound_versions) for wl in window_layers - ) - to_bind: Dict[str, mx.array] = {} - if not skip_scan: - t_w_ready = time.perf_counter() - for wl in window_layers: - weights = self.weight_cache.get_weight(wl) - if weights is None: - logger.error("Failed to load weights for layer %s", wl) - self.input_pool.release(activation_msg.pool_id) - return - try: - # Use identity of first array as a cheap version/fingerprint - first_arr = next(iter(weights.values())) - version = id(first_arr) - except StopIteration: - version = -1 - if self._bound_versions.get(wl) != version: - for k, v in weights.items(): - to_bind[k] = v - self._bound_versions[wl] = version - if self._profile: - t_w_ms = (time.perf_counter() - t_w_ready) * 1000.0 - # Only log when non-trivial or binding happened to reduce overhead/noise - if to_bind or t_w_ms > 0.5: + # Ensure weights for the window are resident and bind only if arrays changed + # if model fits and we've already bound these layers, skip the scan entirely. 
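# The sliding_fit delta swap (early eviction above, keep-tail logic further below)
# keeps as much of the previous window's tail as the budget allows next to the
# current window. Its core reduces to this pure function:
from typing import List, Tuple

def delta_swap(prev: List[int], curr: List[int], budget: int) -> Tuple[List[int], List[int]]:
    prev_only = [x for x in prev if x not in curr]
    keep_quota = max(0, budget - len(curr))
    keep_tail = prev_only[-keep_quota:] if keep_quota > 0 else []
    evict = [x for x in prev_only if x not in keep_tail]
    return keep_tail + curr, evict           # (new resident set, layers to evict)

# prev=[4,5,6,7], curr=[8,9], budget=4 -> keep [6,7,8,9], evict [4,5]
assert delta_swap([4, 5, 6, 7], [8, 9], 4) == ([6, 7, 8, 9], [4, 5])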
+ fast_fit = (self._mode == "fit" and len(self._assigned_sorted) <= self.window_size) + skip_scan = fast_fit and all( (wl in self._bound_versions) for wl in window_layers) + + to_bind: Dict[str, mx.array] = {} + if not skip_scan: + t_w_ready = time.perf_counter() + for wl in window_layers: + weights = self.weight_cache.get_weight(wl) + if weights is None: + logger.error("Failed to load weights for layer %s", wl) + self.input_pool.release(activation_msg.pool_id) + return + try: # Use identity of first array as a cheap version/fingerprint + first_arr = next(iter(weights.values())) + version = id(first_arr) + except StopIteration: + version = -1 + if self._bound_versions.get(wl) != version: + for k, v in weights.items(): + to_bind[k] = v + self._bound_versions[wl] = version + if self._profile: + t_w_ms = (time.perf_counter() - t_w_ready) * 1000.0 + # Only log when non-trivial or binding happened to reduce overhead/noise + if to_bind or t_w_ms > 0.5: + logger.info( + "[PROFILE][WAIT-WEIGHTS] node=%s nonce=%s layers=%s ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + t_w_ms, + ) + + # Execute the window + with self.tracer.frame("compute.thread", "execute") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + + if to_bind: # Block prefetch-touch during binding and serialize MLX ops + self._compute_busy.set() + t_bind = time.perf_counter() + with self._mlx_lock: + self.model.load_weights(list(to_bind.items()), strict=False) + bind_ms = (time.perf_counter() - t_bind) * 1000.0 + if self._profile: logger.info( - "[PROFILE][WAIT-WEIGHTS] node=%s nonce=%s layers=%s ms=%.3f", + "[PROFILE][BIND] node=%s nonce=%s layers=%s tensors=%s bind_ms=%.3f", self.node_id, activation_msg.nonce, window_layers, - t_w_ms, + len(to_bind), + bind_ms, ) + t_comp = time.perf_counter() + self._compute_busy.set() + for i, lyr in enumerate(window_layers): + with self._mlx_lock: + x = self.model.apply_single_layer(lyr, x, cache=kv) + try: + if str(x.dtype) != str(self._wire_mx_dtype): + x = x.astype(self._wire_mx_dtype) + except Exception: + pass + + last_layer = ( window_layers[-1] if window_layers else activation_msg.layer_id) + mx.eval(x) - bind_ms = 0.0 - if to_bind: - # Block prefetch-touch during binding and serialize MLX ops - try: - self._compute_busy.set() - except Exception: - pass - t_bind = time.perf_counter() - with self._mlx_lock: - self.model.load_weights(list(to_bind.items()), strict=False) - bind_ms = (time.perf_counter() - t_bind) * 1000.0 if self._profile: + t_comp_done = time.perf_counter() logger.info( - "[PROFILE][BIND] node=%s nonce=%s layers=%s tensors=%s bind_ms=%.3f", + "[PROFILE][WINDOW] node=%s nonce=%s layers=%s compute_ms=%.3f", self.node_id, activation_msg.nonce, window_layers, - len(to_bind), - bind_ms, + (t_comp_done - t_comp) * 1000.0, ) - t_comp = time.perf_counter() - try: - self._compute_busy.set() - except Exception: - pass - for i, lyr in enumerate(window_layers): - with self._mlx_lock: - x = self.model.apply_single_layer(lyr, x, cache=kv) - try: - if str(x.dtype) != str(self._wire_mx_dtype): - x = x.astype(self._wire_mx_dtype) - except Exception: - pass - last_layer = ( - window_layers[-1] if window_layers else activation_msg.layer_id - ) - try: - mx.eval(x) - except Exception: - pass - if self._profile: - t_comp_done = time.perf_counter() - logger.info( - "[PROFILE][WINDOW] node=%s nonce=%s layers=%s compute_ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - (t_comp_done - t_comp) * 1000.0, - ) - for lid in 
window_layers: - self.weight_cache.decrease_reference(lid) + with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) - try: - # Sliding-fit delta swap: maintain a single resident set by evicting - # only what's needed to fit the next window into the budget. Prefer - # keeping the tail of the previous window so we don't thrash weights - # that are likely to be reused. - if self._mode == "sliding_fit": - if int(self._resident_windows) <= 1: - if did_early_swap: - # Early delta-swap already trimmed resident set for this window - pass - elif not self._recent_windows: - # First window in token: seed resident set - self._recent_windows.append(list(window_layers)) - else: - prev = self._recent_windows.pop(0) - self._delta_swap_eviction( - window_layers, prev, activation_msg, early=False - ) - budget = max(1, int(self.window_size or 1)) - curr = list(window_layers) - prev_only = [x for x in prev if x not in curr] - keep_quota = max(0, budget - len(curr)) - keep_tail = ( - prev_only[-keep_quota:] if keep_quota > 0 else [] - ) - combined = list(keep_tail) + curr - self._recent_windows.append(combined) - else: - # resident_windows>1 not expected in sliding_fit; fall back to seeding + for lid in window_layers: + self.weight_cache.decrease_reference(lid) + + try: # Sliding-fit delta swap: maintain a single resident set by evicting + # only what's needed to fit the next window into the budget. Prefer + # keeping the tail of the previous window so we don't thrash weights + # that are likely to be reused. + if self._mode == "sliding_fit": + if int(self._resident_windows) <= 1: + if did_early_swap: # Early delta-swap already trimmed resident set for this window + pass + elif not self._recent_windows: # First window in token: seed resident set + self._recent_windows.append(list(window_layers)) + else: + prev = self._recent_windows.pop(0) + self._delta_swap_eviction( + window_layers, prev, activation_msg, early=False + ) + budget = max(1, int(self.window_size or 1)) + curr = list(window_layers) + prev_only = [x for x in prev if x not in curr] + keep_quota = max(0, budget - len(curr)) + keep_tail = ( + prev_only[-keep_quota:] if keep_quota > 0 else [] + ) + combined = list(keep_tail) + curr + self._recent_windows.append(combined) + + else: # Original eviction policy for other modes self._recent_windows.append(list(window_layers)) - else: - # Original eviction policy for other modes - self._recent_windows.append(list(window_layers)) - if int(self._resident_windows) <= 1: - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: + if int(self._resident_windows) <= 1: + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + if hasattr(self.model, "unload_layers"): self.model.unload_layers(old) # type: ignore[attr-defined] for lid in old: self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: - try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) - except Exception: - pass - else: - if not self._defer_unload: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers( - old - ) - 
except Exception: - evicted_cnt = 0 - try: + else: + if not self._defer_unload: + while len(self._recent_windows) > max( + 1, int(self._resident_windows) + ): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers( + old + ) + except Exception: + evicted_cnt = 0 if hasattr(self.model, "unload_layers"): self.model.unload_layers(old) # type: ignore[attr-defined] for lid in old: self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: - try: + if self._profile: logger.info( "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", self.node_id, @@ -339,48 +323,42 @@ def _process_activation(self, activation_msg: ActivationMessage): evicted_cnt, self._resident_windows, ) - except Exception: - pass - except Exception: - pass + except Exception: + pass - computation_time = time.perf_counter() - start_time - self._prof.info( - "[PROFILE][COMPUTE] node=%s nonce=%s window_layers=%s total_ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - computation_time * 1000.0, - ) - self._prof.info( - "Completed layers up to %s in %.3fs, nonce: %s, result: %s %s", - last_layer, - computation_time, - activation_msg.nonce, - x.shape, - x.dtype, - ) - - # If next layer is still local, continue without staging/tx - nxt = last_layer + 1 - if nxt in self._assigned_set: - current_layer = nxt - continue + computation_time = time.perf_counter() - start_time + self._prof.info( + "[PROFILE][COMPUTE] node=%s nonce=%s window_layers=%s total_ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + computation_time * 1000.0, + ) + self._prof.info( + "Completed layers up to %s in %.3fs, nonce: %s, result: %s %s", + last_layer, + computation_time, + activation_msg.nonce, + x.shape, + x.dtype, + ) + + # If next layer is still local, continue without staging/tx + nxt = last_layer + 1 + if nxt in self._assigned_set: + current_layer = nxt + continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - t_stage = time.perf_counter() - x_cast = ( - x - if x.dtype == self._wire_mx_dtype - else x.astype(self._wire_mx_dtype) - ) - try: + with self.tracer.frame("compute.thread", "staging") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + + t_stage = time.perf_counter() + x_cast = ( x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype)) self._compute_busy.clear() - except Exception: - pass - if self._profile: - try: + if self._profile: logger.info( "[PROFILE][STAGE-DIRECT] node=%s nonce=%s layer_tail=%s stage_ms=%.3f shape=%s dtype=%s", self.node_id, @@ -390,125 +368,123 @@ def _process_activation(self, activation_msg: ActivationMessage): tuple(x_cast.shape), str(self._wire_mx_dtype), ) - except Exception: - pass - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] - else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - 
dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - is_final=True, - token_id=token_id, - ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - output_msg.tx_enq_perf_t = time.perf_counter() - except Exception: - output_msg.tx_enq_perf_t = 0.0 - # Enqueue to appropriate asyncio TX queue from compute thread - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue - ) - fut = asyncio.run_coroutine_threadsafe( - target_q.put(output_msg), self._loop + if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "sampling") as fr: + try: + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + self.tracer.mark("lm_head", {"nonce": activation_msg.nonce}) # NOTE: canonical stats end + + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return + + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + is_final=True, + token_id=token_id, ) - fut.result() else: - raise RuntimeError("Event loop not available for TX queue") - except Exception as e: - logger.error( - "Failed to queue computed activation for sending: %s", e - ) + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, + ) + + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + try: + output_msg.tx_enq_perf_t = time.perf_counter() + except Exception: + output_msg.tx_enq_perf_t = 0.0 - # Optional unload/evict after stage - if self._mode != "sliding_fit": - if self._defer_unload: + # Enqueue to appropriate asyncio TX queue from compute thread + with self.tracer.frame("network.tx", "enque") as fr: + output_msg.tx_enq_t = time.perf_counter() try: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) + if self._loop is not None: + target_q = ( + self.activation_token_queue + if output_msg.is_final + else self.activation_computed_queue + ) + fut = asyncio.run_coroutine_threadsafe( + target_q.put(output_msg), self._loop + ) + fut.result() + else: + raise RuntimeError("Event loop not available for TX queue") + except Exception as e: + logger.error( + "Failed to queue computed activation for sending: %s", e + ) + + # Clean up input resources + 
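# The compute thread above hands finished activations to the event loop's TX queue
# with run_coroutine_threadsafe and blocks on fut.result(), so queue back-pressure
# reaches the compute thread. Minimal runnable sketch of that hand-off.
import asyncio
import threading

def enqueue_from_thread(loop: asyncio.AbstractEventLoop, q: asyncio.Queue, item) -> None:
    if loop is None:
        raise RuntimeError("Event loop not available for TX queue")
    fut = asyncio.run_coroutine_threadsafe(q.put(item), loop)
    fut.result()    # wait until the loop has accepted the item

async def main():
    q: asyncio.Queue = asyncio.Queue(maxsize=2)
    loop = asyncio.get_running_loop()
    t = threading.Thread(target=enqueue_from_thread, args=(loop, q, "activation"))
    t.start()
    print(await q.get())    # -> "activation"
    t.join()

asyncio.run(main())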
with self.tracer.frame("compute.thread", "cleanup") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + if self._mode != "sliding_fit": + if self._defer_unload: + try: + while len(self._recent_windows) > max(1, int(self._resident_windows)): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) + except Exception: + pass except Exception: - evicted_cnt = 0 + pass + + if self._resident_windows <= 1: try: + evicted = self.weight_cache.evict_layers(window_layers) if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) + self.model.unload_layers(window_layers) # type: ignore[attr-defined] + if self._profile: + logger.info( + "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", + self.node_id, + activation_msg.nonce, + window_layers, + evicted, + ) except Exception: pass - if self._profile: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) - except Exception: - pass - - if self._resident_windows <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] - if self._profile: - logger.info( - "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", - self.node_id, - activation_msg.nonce, - window_layers, - evicted, - ) - except Exception: - pass return except Exception as e: logger.exception("Error processing activation: %s", e) + diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index cbd9610e..a44ad608 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -1,6 +1,6 @@ """Shard models for dnet ring topology endpoints.""" -from typing import Any, Dict, List, Optional, Literal +from typing import Any, Dict, List, Optional, Literal, Tuple from pydantic import BaseModel, Field from dnet_p2p import DnetDeviceProperties, ThunderboltConnection @@ -53,6 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" + #api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) @@ -69,9 +70,28 @@ class ShardProfileRequest(BaseModel): class ShardProfileResponse(BaseModel): - """Response from device profiling and latency measurement.""" + """Response from device profiling.""" profile: Dict[str, Any] = Field(..., description="Device profile information") + + +class MeasureLatencyRequest(BaseModel): + """Request to measure latency to other devices.""" + + devices: Dict[str, DnetDeviceProperties] = Field( + ..., description="Device information mapping" + ) + thunderbolts: Dict[str, ThunderboltConnection] = Field( + default={}, description="Thunderbolt connection information from this device" + ) + payload_sizes: List[int] = Field( + default=[1024], description="Payload sizes to test for latency measurement" + ) + + +class MeasureLatencyResponse(BaseModel): + """Response from latency measurement.""" + latency: 
LatencyResults = Field(..., description="Latency measurement results")
@@ -90,3 +110,45 @@ class HealthResponse(BaseModel):
     grpc_port: int = Field(..., description="gRPC server port")
     http_port: int = Field(..., description="HTTP server port")
     instance: Optional[str] = Field(default=None, description="Shard name")
+
+
+# Tracer
+
+class TraceConfigRequest(BaseModel):
+    file: str = Field(..., description="File for trace streaming")
+    streaming: bool = Field(..., description="Toggle for trace streaming to file")
+    include_prefixes: List[str] = Field(default=["src/dnet/"], description="Source path prefixes to trace")
+    include_c_calls: bool = Field(default=False, description="Include C-level calls in traces")
+    budget: int = Field(default=0, description="Max amount of traces in memory")
+    enabled: bool = Field(default=False, description="Start or stop tracing")
+    node_id: Optional[str] = Field(default="NONE", description="Shard/service identity")
+    record_pid_tid: bool = Field(default=True, description="Record PID/TID on each event")
+    aggregate: bool = Field(default=True, description="Forward trace batches to the aggregator")
+    aggregate_url: Optional[str] = Field(default=None, description="Aggregator ingest endpoint")
+    agg_max_events: int = Field(default=300, description="Threshold for sending frames to API")
+
+class TraceConfigResponse(BaseModel):
+    ok: bool = True
+
+class TraceEvent(BaseModel):
+    type: str = Field(..., description="Event type/phase")
+    name: str = Field(..., description="Span/mark name")
+    ts: float = Field(..., description="Timestamp in microseconds")
+    args: Dict[str, Any] = Field(default_factory=dict)
+    req_id: Optional[str] = None
+    pid: Optional[int] = None
+    tid: Optional[int] = None
+
+class TraceIngestBatch(BaseModel):
+    run_id: str = Field(..., description="Bench run identifier")
+    node_id: str = Field(..., description="Shard/service identity")
+    events: List[TraceEvent] = Field(default_factory=list)
+    #dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
+    #max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
+    #last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
+    #schema_version: int = Field(default=1)
+
+class TraceIngestResponse(BaseModel):
+    ok: bool = True
+    accepted: int = 0
+    message: Optional[str] = None
diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index 82a60d40..3f551a62 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -19,7 +19,6 @@
 from dnet_p2p import AsyncDnetP2P, DnetDeviceProperties
 
-from dnet.utils.latency import calculate_median_latency_seconds
 from dnet.utils.serialization import tensor_to_bytes
 
 from .servicer import ShardServicer
@@ -27,11 +26,15 @@
 from .models import (
     HealthResponse,
+    MeasureLatencyRequest,
+    MeasureLatencyResponse,
     ShardLoadModelRequest,
     ShardLoadModelResponse,
     ShardProfileRequest,
     ShardProfileResponse,
     ShardUnloadModelResponse,
+    TraceConfigRequest,
+    TraceConfigResponse,
 )
 
 from ..model.base import BaseRingModel
@@ -62,6 +65,8 @@
 from .comms import CommsMixin
 from ..weight_cache import WeightCache
 
+from dnet.perf import TraceConfig, Tracer
+
 
 class RingShardNode(ComputeMixin, PrefetchMixin, CommsMixin):
     """Single shard node in the distributed inference ring with dynamic model loading."""
@@ -103,7 +108,7 @@ def __init__(
         self._assigned_sorted = sorted(self.assigned_layers or [])
         self._assigned_set = set(self._assigned_sorted)
 
-        # Topology (configured later)
+        # Topology
         self.next_node: Optional[DnetDeviceProperties] = None
         self.total_layers: int = 0  # Total layers in model
         self.api_callback_address: Optional[str] = None
@@ -188,6 +193,7 @@
def __init__(
         # Discovery
         self.discovery = AsyncDnetP2P("lib/dnet-p2p/lib")
+        self._instance_name = ""
 
         # Background tasks
         self.background_tasks: List[asyncio.Task] = []
@@ -197,8 +203,21 @@ def __init__(
         self._sync_per_layer = obs.sync_per_layer
         self._sync_every_n = obs.sync_every_n
         self._prof = make_profiler(self._profile)
-        if self._profile:
-            logger.info("[PROFILE] enabled on shard node %s", self.node_id)
+
+        # Debug tracing
+        cfg = TraceConfig(
+            file="./trace.json",
+            streaming=False,
+            include_prefixes=("src/dnet/",),
+            include_c_calls=False,
+            budget=10000,
+            enabled=True,
+            record_pid_tid=True,
+            aggregate=False,
+            aggregate_url=None,
+        )
+        self.tracer = Tracer(cfg)
+        self.tracer.start()
 
         # Per-nonce KV caches (concurrent requests)
         self._kv_by_nonce: Dict[str, list] = {}
@@ -218,11 +237,8 @@ def __init__(
         )
 
     async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse:
         """Load model with specified layers."""
-        try:
-            start_time = time.perf_counter()
-
-            # Check if already loaded with same configuration
+        try:  # Check if already loaded with same configuration
             if (
                 self.model is not None
                 and self.model_path == req.model_path
@@ -240,22 +256,24 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse
             )
 
             # If model loaded with different config, unload first
-            if self.model is not None and (
-                self.model_path != req.model_path or self.assigned_layers != req.layers
+            if (self.model is not None
+                and (self.model_path != req.model_path or self.assigned_layers != req.layers)
             ):
-                logger.info(
-                    "Node %s: Unloading current model to load new configuration",
-                    self.node_id,
-                )
-                await self.unload_model()
+                logger.info("Node %s: Unloading current model to load new configuration", self.node_id)
+                with self.tracer.frame("memory.model", "unload") as f:
+                    f.set("node", self._instance_name)
+                    await self.unload_model()
 
             # Load model metadata
-            self.model_metadata = get_model_metadata(req.model_path)
+            with self.tracer.frame("memory.model", "load_metadata"):
+                self.model_metadata = get_model_metadata(req.model_path)
+
             self.assigned_layers = req.layers
             self._assigned_sorted = sorted(self.assigned_layers)
             self._assigned_set = set(self._assigned_sorted)
             self.model_path = req.model_path
 
+            # Decide mode dynamically from assignment + requested window
             requested_w = int(max(1, int(req.window_size)))
             local_count = max(1, len(self.assigned_layers))
@@ -351,45 +369,53 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse
             )
 
             # Initialize weight cache
-            self.weight_cache = WeightCache(
-                self.assigned_layers,
-                self.model_metadata,
-                window_size=self.window_size,
-                prefetch_threads=self._prefetch_threads,
-                resident_windows=self._resident_windows,
-                use_mxload_fastpath=self.config.mxload_fastpath,
-                prefetch_mode=self.config.prefetch_mode,
-            )
+            with self.tracer.frame("memory.weights", "cache.init") as f:
+                f.set("node", self._instance_name)
+                self.weight_cache = WeightCache(
+                    self.assigned_layers,
+                    self.model_metadata,
+                    window_size=self.window_size,
+                    prefetch_threads=self._prefetch_threads,
+                    resident_windows=self._resident_windows,
+                    use_mxload_fastpath=self.config.mxload_fastpath,
+                    prefetch_mode=self.config.prefetch_mode,
+                    tracer=self.tracer,
+                )
 
             # Load the model
-            self.model = get_ring_model(
-                self.model_metadata.model_type,
-                self.model_metadata.model_config,
-                assigned_layers=self.assigned_layers,
-                is_api_layer=False,
-            )
-            try:
-                applied = bool(
-
self.model.apply_quantization_from_config( # type: ignore[attr-defined] - self.model_metadata.model_config, - model_metadata=self.model_metadata, - ) - ) - logger.info( - "[QUANT] applied=%s for model=%s", - applied, + with self.tracer.frame("memory.model", "load") as f: + f.set("node", self._instance_name) + self.model = get_ring_model( self.model_metadata.model_type, + self.model_metadata.model_config, + assigned_layers=self.assigned_layers, + is_api_layer=False, ) + try: + applied = bool( + self.model.apply_quantization_from_config( # type: ignore[attr-defined] + self.model_metadata.model_config, + model_metadata=self.model_metadata, + ) + ) + logger.info( + "[QUANT] applied=%s for model=%s", + applied, + self.model_metadata.model_type, + ) - except RuntimeError as e: - logger.warning("[QUANT] apply failed: %s", e) - self.model.eval() - self.cache = make_cache( - self.model, - kv_mode=self.config.kv_cache.mode, - kv_bits=self.config.kv_cache.bits, - kv_group=self.config.kv_cache.group_size, - ) + except RuntimeError as e: + logger.warning("[QUANT] apply failed: %s", e) + self.model.eval() + + with self.tracer.frame("memory.cache", "make_cache") as f: + f.set("node", self._instance_name) + self.cache = make_cache( + self.model, + kv_mode=self.config.kv_cache.mode, + kv_bits=self.config.kv_cache.bits, + kv_group=self.config.kv_cache.group_size, + ) try: has_start = 0 in self.assigned_layers @@ -423,28 +449,31 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - if self.next_node: - await self._connect_next_node() - else: - logger.warning("Node %s: No next node configured", self.node_id) + with self.tracer.frame("network.connect", "next_node") as f: + f.set("node", self._instance_name) + if self.next_node: + await self._connect_next_node() + else: + logger.warning("Node %s: No next node configured", self.node_id) # Warmup: compile hot path and stabilize allocators before first request - if req.warmup and self._mode == "fit": - loop = asyncio.get_running_loop() - try: - await loop.run_in_executor(self.executor, self._warmup_shard) - except Exception: - # Fall back to direct call if executor is unavailable - self._warmup_shard() - elif req.warmup and self._mode != "fit": - # Offload/sliding-fit: perform a small, offload-safe warmup for the first window - loop = asyncio.get_running_loop() - try: - await loop.run_in_executor( - self.executor, self._warmup_shard_offload - ) - except Exception: - self._warmup_shard_offload() + with self.tracer.frame("memory", "warmup"): + if req.warmup and self._mode == "fit": + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(self.executor, self._warmup_shard) + except Exception: + # Fall back to direct call if executor is unavailable + self._warmup_shard() + elif req.warmup and self._mode != "fit": + # Offload/sliding-fit: perform a small, offload-safe warmup for the first window + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor( + self.executor, self._warmup_shard_offload + ) + except Exception: + self._warmup_shard_offload() if self._mode == "offload" and not ( self._warmup_completed and self._warmup_keep_flag @@ -465,10 +494,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse if m % self.window_size != 0: logger.warning( "Window size %s does not divide local layer count %s. 
Rounds per token will vary; consider setting k*w = %s.", - self.window_size, - m, - m, - ) + self.window_size, m, m) else: k = m // self.window_size logger.info( @@ -478,20 +504,12 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse k, ) - load_time_ms = (time.perf_counter() - start_time) * 1000.0 - logger.info( - "Node %s: Successfully loaded model %s with layers %s in %.2fms", - self.node_id, - req.model_path, - req.layers, - load_time_ms, - ) return ShardLoadModelResponse( success=True, message="Model loaded successfully", layers_loaded=req.layers, - load_time_ms=load_time_ms, + load_time_ms=0.0, ) except Exception as e: @@ -537,23 +555,24 @@ async def unload_model(self) -> ShardUnloadModelResponse: self._assigned_set = set() # Clear memory pools - if self.weight_cache: - # Stop any in-flight prefetch and close layer manager resources - try: - self.weight_cache.cancel_all_prefetch() - except Exception: - pass - # Clear all cached weights - for layer_id in list(self._bound_versions.keys()): + with self.tracer.frame("memory", "cache.evict"): + if self.weight_cache: + # Stop any in-flight prefetch and close layer manager resources try: - self.weight_cache.evict_layer(layer_id) + self.weight_cache.cancel_all_prefetch() except Exception: pass - try: - self.weight_cache.layer_manager.close() - except Exception: - pass - self.weight_cache = None + # Clear all cached weights + for layer_id in list(self._bound_versions.keys()): + try: + self.weight_cache.evict_layer(layer_id) + except Exception: + pass + try: + self.weight_cache.layer_manager.close() + except Exception: + pass + self.weight_cache = None self.input_pool = None self.output_pool = None @@ -581,334 +600,332 @@ async def unload_model(self) -> ShardUnloadModelResponse: async def reset_cache(self) -> None: """Reset LLM KV cache.""" if not self.model: - logger.warning( - "Node %s: Cannot reset cache - no model loaded", self.node_id - ) + logger.warning("Node %s: Cannot reset cache - no model loaded", self.node_id) return - try: - self.cache = make_cache( - self.model, # type: ignore[arg-type] - kv_mode=self.config.kv_cache.mode, - kv_bits=self.config.kv_cache.bits, - kv_group=self.config.kv_cache.group_size, - ) - logger.info("Node %s: Cache reset successfully", self.node_id) - except Exception as e: - logger.error("Node %s: Error resetting cache: %s", self.node_id, e) + with self.tracer.frame("memory", "cache.reset"): + try: + self.cache = make_cache( + self.model, # type: ignore[arg-type] + kv_mode=self.config.kv_cache.mode, + kv_bits=self.config.kv_cache.bits, + kv_group=self.config.kv_cache.group_size, + ) + logger.info("Node %s: Cache reset successfully", self.node_id) + except Exception as e: + logger.error("Node %s: Error resetting cache: %s", self.node_id, e) + + # FIXME This seems to still be dead code async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or forward.""" + logger.debug("RECEIVE ACTIVATION") if self.input_pool is None: - logger.error( - "Node %s: Cannot receive activation - input pool not initialized", - self.node_id, - ) + logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id) return - t_recv = time.perf_counter() - await self._connect_next_node() - - try: - activation = request.activation - target_layer = activation.layer_id + 1 + with self.tracer.frame("network.rx", "connect_next_node") as f: + f.set("req_id", request.nonce) + f.set("node", 
self._instance_name) + await self._connect_next_node() + with self.tracer.frame("network.rx", "process_activation") as f: + f.set("req_id", request.nonce) try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - transport_ms = float(utc_epoch_now() - request.timestamp) - logger.info( - "[PROFILE][RX] node=%s nonce=%s target_layer=%s transport_ms=%.1f payload_kb=%.1f", - self.node_id, - request.nonce, - target_layer, - transport_ms, - (payload_bytes / 1024.0), - ) - - # Detect new sequence per node: initialize per-nonce KV - if request.nonce != self._active_nonce: - self._active_nonce = request.nonce - try: - self._get_or_make_kv(request.nonce) - except Exception: - pass - - if target_layer in self._assigned_set: - # Allocate input pool and copy payload (with optional decompression) - t_alloc = time.perf_counter() - if "|" in activation.dtype: + activation = request.activation + target_layer = activation.layer_id + 1 + + # Detect new sequence per node: initialize per-nonce KV + if request.nonce != self._active_nonce: + self._active_nonce = request.nonce try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return + payload_bytes = len(activation.data) + except Exception: + payload_bytes = -1 + f.event("process_payload") - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs - if activation.dtype == "tokens": - try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 - ) - shp = (int(len(tokens)),) - except Exception as e: - logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) - return - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, + if target_layer in self._assigned_set: + # Allocate input pool and copy payload (with optional decompression) + t_alloc = time.perf_counter() + if "|" in activation.dtype: + with self.tracer.frame("grpc.receive", "decompress") as fr: + fr.set("req_id", request.nonce) + fr.set("node", self._instance_name) + try: + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, + ) + except Exception as e: + logger.error( + "Decompression failed for nonce %s: %s", request.nonce, e + ) + return + + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: 
+                            pool_id = self.input_pool.allocate_for_layer(
+                                layer_id=activation.layer_id,
+                                dtype=deq.dtype,
+                                shape=cast(tuple[int, ...], tuple(deq.shape)),
                             )
-                    if pool_id is None:
-                        logger.warning(
-                            "Failed to allocate input pool buffer for nonce %s",
-                            request.nonce,
-                        )
-                        return
-                    buffer = self.input_pool.get_buffer(pool_id)
-                    if buffer is not None:
-                        buffer[: len(tokens)] = tokens
-                        if self._profile:
+                            if pool_id is None:
+                                logger.warning(
+                                    "Failed to allocate input pool buffer for nonce %s",
+                                    request.nonce,
+                                )
+                                return
+                            buffer = self.input_pool.get_buffer(pool_id)
+                            if buffer is not None:
+                                flat = deq.reshape(-1)
+                                buffer[: flat.size] = flat
                                 alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0
                                 logger.info(
-                                    "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)",
+                                    "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)",
                                     self.node_id,
                                     request.nonce,
                                     alloc_copy_ms,
                                 )
-                        activation_msg = ActivationMessage.from_proto(request, pool_id)
-                        # Ensure dtype reflects token payload for compute path
-                        activation_msg.dtype = "tokens"
-                        activation_msg.shape = shp
-                    else:
-                        # Safety: byte length must match shape*dtype
-                        try:
-                            expected = (
-                                int(np.prod(activation.shape))
-                                * np.dtype(dtype_map[activation.dtype]).itemsize
-                            )
-                            actual = len(request.activation.data)
-                        except Exception:
-                            expected = -1
-                            actual = -1
-                        if expected != actual:
-                            logger.error(
-                                "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s",
-                                request.nonce,
-                                expected,
-                                actual,
-                                activation.dtype,
-                                activation.shape,
-                            )
-                            return
-                        pool_id = self.input_pool.allocate_for_layer(
-                            layer_id=activation.layer_id,
-                            dtype=mlx_dtype_map[activation.dtype],
-                            shape=cast(tuple[int, ...], activation.shape),
-                        )
-                        if pool_id is None:
-                            logger.warning(
-                                "Failed to allocate input pool buffer for nonce %s",
-                                request.nonce,
-                            )
-                            return
-                        buffer = self.input_pool.get_buffer(pool_id)
-                        if buffer is not None:
-                            data = request.activation.data
-                            input_data = np.frombuffer(
-                                data, dtype=dtype_map[activation.dtype]
-                            )
-                            buffer[: len(input_data)] = input_data
-                            alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0
-                            logger.info(
-                                "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f",
-                                self.node_id,
-                                request.nonce,
-                                alloc_copy_ms,
-                            )
+                            # Update activation message with true dtype/shape
+                            new_dtype_str = str(deq.dtype)
                             activation_msg = ActivationMessage.from_proto(request, pool_id)
+                            activation_msg.dtype = new_dtype_str
+                            activation_msg.shape = tuple(deq.shape)
+
+                    elif activation.dtype == "tokens":
+                        # Special token stream support: dtype='tokens' carries int32 token IDs
+                        with self.tracer.frame("network.rx", "token_stream") as fr:
+                            fr.set("req_id", request.nonce)
+                            fr.set("node", self._instance_name)
+                            try:
+                                tokens = np.frombuffer(request.activation.data, dtype=np.int32)
+                                shp = (int(len(tokens)),)
+                            except Exception as e:
+                                logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e)
+                                return
+
+                            pool_id = self.input_pool.allocate_for_layer(
+                                layer_id=activation.layer_id,
+                                dtype=mx.int32,
+                                shape=cast(tuple[int, ...], shp))
+
+                            if pool_id is None:
+                                logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce)
+                                return
+
+                            buffer = self.input_pool.get_buffer(pool_id)
+                            if buffer is not None:
+                                buffer[: len(tokens)] = tokens
+                            activation_msg = ActivationMessage.from_proto(request, pool_id)
+
+                            # Ensure dtype reflects token payload for compute path
+                            activation_msg.dtype = "tokens"
+                            activation_msg.shape = shp
+
+                    else:
+                        with self.tracer.frame("network.rx", "default") as fr:
+                            fr.set("node", self._instance_name)
+                            fr.set("req_id", request.nonce)
+                            # Safety: byte length must match shape*dtype
+                            try:
+                                expected = (
+                                    int(np.prod(activation.shape))
+                                    * np.dtype(dtype_map[activation.dtype]).itemsize
+                                )
+                                actual = len(request.activation.data)
+                            except Exception:
+                                expected = -1
+                                actual = -1
+                            if expected != actual:
+                                logger.error(
+                                    "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s",
+                                    request.nonce, expected, actual, activation.dtype, activation.shape)
+                                return
+
+                            pool_id = self.input_pool.allocate_for_layer(
+                                layer_id=activation.layer_id,
+                                dtype=mlx_dtype_map[activation.dtype],
+                                shape=cast(tuple[int, ...], activation.shape))
+
+                            if pool_id is None:
+                                logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce)
+                                return
+
+                            buffer = self.input_pool.get_buffer(pool_id)
+                            if buffer is not None:
+                                data = request.activation.data
+                                input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype])
+                                buffer[: len(input_data)] = input_data
+
+                            activation_msg = ActivationMessage.from_proto(request, pool_id)
+
+                    # Queue for processing — non-blocking back-off loop (cancellable)
+                    while self.running:
+                        try:
+                            logger.debug(f"NETWORK RX: {activation_msg.callback_url}")
+                            activation_msg.ex_enq_t = time.perf_counter()
+                            self.activation_recv_queue.put_nowait(activation_msg)
+                            logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce)
+                            break
+                        except Full:
+                            await asyncio.sleep(0)
+                    else:
+                        logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce)
+                        self.input_pool.release(pool_id)
+
+                else:  # Forward to next node (not our layer)
+                    logger.debug("Forwarding activation (layer %s) to next node, nonce: %s", target_layer, request.nonce)
+                    await self._forward_activation(request)
-                    if self._profile:
-                        activation_msg.recv_perf_t = t_recv
-
-                    # Queue for processing — non-blocking back-off loop (cancellable)
-                    if self._profile:
-                        activation_msg.enq_perf_t = time.perf_counter()
-                    while self.running:
-                        try:
-                            self.activation_recv_queue.put_nowait(activation_msg)
-                            logger.debug(
-                                "Queued activation for processing: nonce %s",
-                                activation_msg.nonce,
-                            )
-                            break
-                        except Full:
-                            await asyncio.sleep(0)
-                    else:
-                        logger.error(
-                            "Failed to queue activation %s (node stopping)",
-                            activation_msg.nonce,
-                        )
-                        self.input_pool.release(pool_id)
-                else:
-                    # Forward to next node (not our layer)
-                    logger.debug(
-                        "Forwarding activation (layer %s) to next node, nonce: %s",
-                        target_layer,
-                        request.nonce,
-                    )
-                    await self._forward_activation(request)
+            except Exception as e:
+                logger.exception("Error receiving activation: %s", e)
-        except Exception as e:
-            logger.exception("Error receiving activation: %s", e)
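The dense path above guards against corrupt frames by comparing the declared shape/dtype against the actual byte count before anything is copied into the pool. The check reduces to a few lines; a standalone sketch (the numpy dtype is assumed to come from a `dtype_map`-style lookup, as elsewhere in this module):

```python
import numpy as np

def payload_matches(shape, np_dtype, payload: bytes) -> bool:
    # Expected size is the element count times the per-element width;
    # any mismatch means the frame was truncated or mis-described.
    try:
        expected = int(np.prod(shape)) * np.dtype(np_dtype).itemsize
    except Exception:
        return False
    return expected == len(payload)

# e.g. payload_matches((1, 4096), np.float16, activation.data) before the pool copy.
```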

    async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None:
-        """
-        Lightweight admission for streaming:
-        enqueue protobuf frame to ingress queue, then return.
-        """
+        """Lightweight admission for streaming: enqueue the protobuf frame to the ingress queue, then return."""
        while self.running:
            try:
+                rx_t = time.perf_counter()
+                request.rx_enq_t = rx_t
+                request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request.tx_enq_prev_t
+
+                logger.debug(f"ADMIT_FRAME: {request.callback_url}")
                self.ingress_q.put_nowait(request)
+                logger.debug("[ENQUEUE] Enqueued activation request")
                return
            except asyncio.QueueFull:
                await asyncio.sleep(0)
-        # If we reached here, node is stopping; drop admission silently
        return
+
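admit_frame stamps two clocks on each frame: `rx_enq_t` (when it entered the ingress queue) and `rx_inflight_t` (time since the previous hop's transmit enqueue). Note that `time.perf_counter()` is only comparable within a single host, so the in-flight delta is only meaningful when `tx_enq_prev_t` was recorded on the same machine. A reduced sketch of the bookkeeping (the `Frame` stand-in is hypothetical; field names follow the code above):

```python
import time

class Frame:
    # Hypothetical stand-in for the protobuf request; only timing fields shown.
    tx_enq_prev_t: float = 0.0
    rx_enq_t: float = 0.0
    rx_inflight_t: float = 0.0

def stamp_admission(frame: Frame) -> None:
    rx_t = time.perf_counter()
    frame.rx_enq_t = rx_t
    # 0.0 is a sentinel for "no previous-hop timestamp"; otherwise record the delay.
    frame.rx_inflight_t = 0.0 if frame.tx_enq_prev_t == 0.0 else rx_t - frame.tx_enq_prev_t

# The ingress worker later derives queue wait as:
#   inwait = time.perf_counter() - frame.rx_enq_t
```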

    async def _ingress_worker(self):
        """Drains ingress queue and processes frames with heavy work offloaded.

        Admission (servicer) is lightweight; this worker performs per-frame
        processing, offloading alloc/copy/decompress to the threadpool, and
-        finally enqueues for compute or forwards to the next shard.
-        """
-        while self.running:
-            try:
-                req = await self.ingress_q.get()
-            except asyncio.CancelledError:
-                break
-            try:
-                t_recv = time.perf_counter()
-                await self._connect_next_node()
+        finally enqueues for compute or forwards to the next shard. """

-                activation = req.activation
-                target_layer = activation.layer_id + 1
+        while self.running:
+            with self.tracer.frame("network.rx", "wait"):  # NOTE: bad counter
+                try:
+                    req = await self.ingress_q.get()
+                    logger.debug("[DEQUEUE] Dequeued activation for processing")
+                except asyncio.CancelledError:
+                    logger.debug("Ingress worker cancelled.")
+                    break
+
+            # Trace processing of request, in-flight and in-wait times
+            with self.tracer.frame("network", "rx") as f:
+                f.set("inwait", time.perf_counter() - req.rx_enq_t)
+                f.set("inflight", req.rx_inflight_t)
+                f.set("node", self._instance_name)
+                f.set("req_id", req.nonce)
                try:
-                    payload_bytes = len(activation.data)
-                except Exception:
-                    payload_bytes = -1
-                transport_ms = float(utc_epoch_now() - req.timestamp)
-                logger.info(
-                    "[PROFILE][RX] node=%s nonce=%s target_layer=%s transport_ms=%.1f payload_kb=%.1f",
-                    self.node_id,
-                    req.nonce,
-                    target_layer,
-                    transport_ms,
-                    (payload_bytes / 1024.0),
-                )
+                    activation = req.activation
+                    target_layer = activation.layer_id + 1

-                # Detect new sequence per node: initialize per-nonce KV
-                if req.nonce != self._active_nonce:
-                    self._active_nonce = req.nonce
-                    try:
-                        self._get_or_make_kv(req.nonce)
-                    except Exception:
-                        pass
+                    try:
+                        payload_bytes = len(activation.data)
+                    except Exception:
+                        payload_bytes = -1
+
+                    # Detect new sequence per node: initialize per-nonce KV
+                    if req.nonce != self._active_nonce:
+                        self._active_nonce = req.nonce
+                        try:
+                            self._get_or_make_kv(req.nonce)
+                        except Exception:
+                            pass

-                if target_layer in self._assigned_set:
-                    # Heavy prep in executor (alloc/copy/decompress)
-                    loop = asyncio.get_running_loop()
-                    try:
-                        activation_msg = await loop.run_in_executor(
-                            self.executor,
-                            self._prepare_activation_message_blocking,
-                            req,
-                        )
-                    except Exception as e:
-                        logger.error(
-                            "Activation prepare failed for nonce %s: %s", req.nonce, e
-                        )
-                        continue
-                    if activation_msg is None:
-                        continue
-                    if self._profile:
-                        activation_msg.recv_perf_t = t_recv
-
-                    # Enqueue for compute (cancellable back-off)
-                    while self.running:
-                        try:
-                            self.activation_recv_queue.put_nowait(activation_msg)
-                            logger.debug(
-                                "Queued activation for processing: nonce %s",
-                                activation_msg.nonce,
-                            )
-                            break
-                        except Full:
-                            await asyncio.sleep(0)
-                    else:
-                        logger.error(
-                            "Failed to queue activation %s (node stopping)",
-                            activation_msg.nonce,
-                        )
-                        try:
-                            if self.input_pool:
-                                # FIXME: !!!
-                                self.input_pool.release(activation_msg.pool_id)
-                        except Exception:
-                            pass
-                else:
-                    # Forward to next node (not our layer)
-                    logger.debug(
-                        "Forwarding activation (layer %s) to next node, nonce: %s",
-                        target_layer,
-                        req.nonce,
-                    )
-                    await self._forward_activation(req)
+                    if target_layer in self._assigned_set:
+                        # Heavy prep in executor (alloc/copy/decompress)
+                        with self.tracer.frame("network.ingress", "prepare") as fr:
+                            loop = asyncio.get_running_loop()
+                            try:
+                                activation_msg = await loop.run_in_executor(
+                                    self.executor,
+                                    self._prepare_activation_message_blocking,
+                                    req,
+                                )
+                            except Exception as e:
+                                logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e)
+                                continue
+                            if activation_msg is None:
+                                continue
+
+                        # Enqueue for compute (cancellable back-off)
+                        with self.tracer.frame("network.rx", "enque") as fr:
+                            fr.set("req_id", req.nonce)
+                            fr.set("node", self._instance_name)
+                            while self.running:
+                                try:
+                                    logger.debug(f"NETWORK RX: {activation_msg.callback_url}")
+                                    self.activation_recv_queue.put_nowait(activation_msg)
+                                    logger.debug(
+                                        "Queued activation for processing: nonce %s",
+                                        activation_msg.nonce,
+                                    )
+                                    break
+                                except Full:
+                                    await asyncio.sleep(0)
+                            else:
+                                logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce)
+                                try:
+                                    if self.input_pool:
+                                        # FIXME: !!!
+                                        self.input_pool.release(activation_msg.pool_id)
+                                except Exception:
+                                    logger.error("Unable to release from input pool")
+
+                    else:  # Forward to next node (not our layer)
+                        logger.debug(
+                            "Forwarding activation (layer %s) to next node, nonce: %s",
+                            target_layer,
+                            req.nonce,
+                        )
+                        await self._forward_activation(req)
+
+                except Exception as e:
+                    logger.error("Ingress worker error: %s", e)

-            except Exception as e:
-                logger.error("Ingress worker error: %s", e)

    def _get_or_make_kv(self, nonce: str) -> list:
        """Return a per-nonce KV cache list for this shard's local layers."""
@@ -947,13 +964,14 @@ def _clear_kv(self, nonce: str) -> None:
            except Exception:
                pass

+
    def _prepare_activation_message_blocking(
        self, request: dnet_ring_pb2.ActivationRequest
    ) -> Optional[ActivationMessage]:
+        """Blocking heavy prep: allocate pool buffer, copy/decompress payload, build ActivationMessage.
+        Returns None on failure. """
-        Returns None on failure.
- """ if self.input_pool is None: logger.error( "Node %s: Cannot prepare activation - input pool not initialized", @@ -963,113 +981,123 @@ def _prepare_activation_message_blocking( try: activation = request.activation - if "|" in activation.dtype: - # Compressed path: decompress to MLX array and copy to pool - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return None + if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - return activation_msg - elif activation.dtype == "tokens": - # Tokens path: parse int32 token IDs and stage them - try: - tokens = np.frombuffer(activation.data, dtype=np.int32) - shp = (int(len(tokens)),) - except Exception as e: - logger.error( - "Failed to parse tokens for nonce %s: %s", request.nonce, e - ) - return None - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = "tokens" - activation_msg.shape = shp - return activation_msg - else: - # Dense path: validate size and copy raw bytes view into pool buffer - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize + with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("req_id", request.nonce) + f.set("node", self._instance_name) + try: + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, + ) + except Exception as e: + logger.error( + "Decompression failed for nonce %s: %s", request.nonce, e + ) + return None + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) - actual = len(activation.data) - except Exception: - expected = -1 - actual = -1 - if expected != actual: - logger.error( - "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", - request.nonce, - expected, - actual, - activation.dtype, - activation.shape, + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + flat = deq.reshape(-1) + buffer[: flat.size] = flat + # Update activation message with true dtype/shape + new_dtype_str = 
str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + return activation_msg + + elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them + with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("req_id", request.nonce) + f.set("node", self._instance_name) + try: + tokens = np.frombuffer(activation.data, dtype=np.int32) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", request.nonce, e + ) + return None + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), ) - return None + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = "tokens" + activation_msg.shape = shp + return activation_msg + + else: # Dense path: validate size and copy raw bytes view into pool buffer + with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("req_id", request.nonce) + f.set("node", self._instance_name) + try: + expected = ( + int(np.prod(activation.shape)) + * np.dtype(dtype_map[activation.dtype]).itemsize + ) + actual = len(activation.data) + except Exception: + expected = -1 + actual = -1 + if expected != actual: + logger.error( + "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", + request.nonce, + expected, + actual, + activation.dtype, + activation.shape, + ) + return None - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mlx_dtype_map[activation.dtype], + shape=cast(tuple[int, ...], activation.shape), ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - data = request.activation.data - input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) - buffer[: len(input_data)] = input_data - activation_msg = ActivationMessage.from_proto(request, pool_id) - return activation_msg + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + data = request.activation.data + input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) + buffer[: len(input_data)] = input_data + activation_msg = ActivationMessage.from_proto(request, pool_id) + return activation_msg except Exception as e: logger.error("Activation prep error: %s", e) return None + def _next_local_layers(self, after_layer: int, count: int) -> List[int]: """Get next local layers after specified layer. 
@@ -1086,6 +1114,7 @@ def _next_local_layers(self, after_layer: int, count: int) -> List[int]: i = _bisect_left(s, after_layer + 1) return s[i : i + count] + def _compute_worker(self) -> None: """Compute thread worker.""" while self.running: @@ -1094,13 +1123,26 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - self._process_activation(activation_msg) + with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + if activation_msg.ex_enq_t == 0.0: # FIXME float comparison + f.set("inwait", 0.0) + else: + f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + + if (self.model_metadata.num_layers - 1) in self.assigned_layers: + f.set("lm_head", True) + + self._process_activation(activation_msg) + f.set("t0", time.perf_counter()) except Empty: continue except Exception as e: logger.error("Compute worker error: %s", e) + async def shutdown(self) -> None: """Shutdown the node.""" self.running = False @@ -1191,6 +1233,7 @@ async def _start_discovery(self) -> None: hostname = gethostname() # TODO: optionally take shard name from CLI instance = f"shard-{token_hex(4)}-{hostname}" + self._instance_name = instance self.discovery.create_instance( instance, self.http_port, @@ -1235,9 +1278,7 @@ def _warmup_serialization(self): pass def _warmup_shard(self): - logger.info( - "[WARMUP] Starting shard warmup with window size %s", self.window_size - ) + logger.info("[WARMUP] Starting shard warmup with window size %s", self.window_size) if not self.model or not self.model_metadata or not self.weight_cache: logger.warning("[WARMUP] No model loaded; skipping warmup") return @@ -1259,10 +1300,9 @@ def _warmup_shard(self): max_windows = max(1, self.config.warmup_windows) windows: list[list[int]] = [] for window_start in range(0, len(self._assigned_sorted), self.window_size): - window_end = min( - window_start + self.window_size, len(self._assigned_sorted) - ) + window_end = min(window_start + self.window_size, len(self._assigned_sorted)) windows.append(self._assigned_sorted[window_start:window_end]) + for wi, window_layers in enumerate(windows[:max_windows]): weights_to_bind = {} for layer_id in window_layers: @@ -1270,6 +1310,7 @@ def _warmup_shard(self): if weights: for k, v in weights.items(): weights_to_bind[k] = v + if weights_to_bind: # Serialize MLX parameter binding with self._mlx_lock: @@ -1283,6 +1324,7 @@ def _warmup_shard(self): mx.eval(_s) except Exception: pass + try: for lid in window_layers: self.weight_cache.decrease_reference(lid) @@ -1438,37 +1480,60 @@ async def health() -> HealthResponse: async def profile(req: ShardProfileRequest) -> ShardProfileResponse: logger.info("Received /profile request") try: - # Measure latencies - latency_results = await self._measure_latency_to_devices( - req.devices, req.thunderbolts, req.payload_sizes - ) - # Profile device using dperf device_profile = await self._profile_device( req.repo_id, req.max_batch_exp ) + logger.debug(device_profile) - # Overwrite `t_comm` with median latency (subprocess returns a dict) - median_latency = calculate_median_latency_seconds(latency_results) - if median_latency is not None: - device_profile["t_comm"] = float(median_latency) - logger.info( - f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" - ) - else: - logger.warning( - "No valid latency measurements, keeping default t_comm" - ) + return 
ShardProfileResponse(profile=device_profile) + except Exception as e: + logger.error(f"Error in /profile endpoint: {e}") + raise - # Return the dict payload directly - return ShardProfileResponse( - profile=device_profile, - latency=latency_results, + @self.app.post("/measure_latency") + async def measure_latency( + req: MeasureLatencyRequest, + ) -> MeasureLatencyResponse: + logger.info("Received /measure_latency request") + try: + # Measure latencies to other devices + latency_results = await self._measure_latency_to_devices( + req.devices, req.thunderbolts, req.payload_sizes ) + + return MeasureLatencyResponse(latency=latency_results) except Exception as e: - logger.error(f"Error in /profile endpoint: {e}") + logger.error(f"Error in /measure_latency endpoint: {e}") raise + @self.app.post("/trace") + async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: + logger.debug("Updating trace config") + try: + cfg = TraceConfig( + file=req.file, + streaming=req.streaming, + include_prefixes=req.include_prefixes, + include_c_calls=req.include_c_calls, + budget=req.budget, + enabled=req.enabled, + node_id=req.node_id, + record_pid_tid=req.record_pid_tid, + aggregate=req.aggregate, + aggregate_url=req.aggregate_url, + agg_max_events=req.agg_max_events, + ) + self.tracer.config = cfg + logger.info("Updated tracer config.") + self.api_address = cfg.aggregate_url + self.tracer.start_aggregator() + return TraceConfigResponse(ok=True) + except Exception as e: + logger.error(f"Unable to setup tracing on shard: {e}") + return TraceConfigResponse(ok=False) + + @self.app.post("/load_model") async def load_model_endpoint( req: ShardLoadModelRequest, @@ -1481,7 +1546,10 @@ async def load_model_endpoint( f"total_layers={req.total_layers}, kv_bits={req.kv_bits or 'default'}, " f"api_callback={req.api_callback_address or 'none'}" ) - result = await self.load_model(req) + self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name + with self.tracer.frame("memory", "model.load") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) + result = await self.load_model(req) return result except Exception as e: @@ -1498,7 +1566,9 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: """Unload current model.""" try: logger.info("HTTP /unload_model") - result = await self.unload_model() + with self.tracer.frame("memory", "model.unload") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) + result = await self.unload_model() return result except Exception as e: @@ -1512,29 +1582,33 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: # FIXME: add pydantic type here async def warm(request: Request) -> JSONResponse: try: - body = await request.json() - start = int(body.get("start", -1)) - window = int(body.get("window", self.window_size)) - if start < 0: - return JSONResponse( - status_code=400, content={"error": "missing/invalid start"} - ) - start_idx = 0 - for i, lyr in enumerate(self._assigned_sorted): - if lyr >= start: - start_idx = i - break - else: - return JSONResponse(content={"prefetched": []}) - window_layers = self._assigned_sorted[ - start_idx : start_idx + max(1, window) - ] - for wl in window_layers: - # Prefetch disabled in fit mode; allow only when non-fit and enabled - if self._mode != "fit" and self.config.prefetch_mode != "off": - self._prefetch_to_ram(wl) - self._enqueue_weight_prefetch(wl) - return JSONResponse(content={"prefetched": window_layers}) + # 
FIXME: Append warmup config? Or something to distinguish
+                with self.tracer.frame("memory", "model.warm") as f:  # NOTE: Symbol hardcoded for runtime stats
+                    f.set("node", self._instance_name)
+                    body = await request.json()
+                    start = int(body.get("start", -1))
+                    window = int(body.get("window", self.window_size))
+                    if start < 0:
+                        return JSONResponse(
+                            status_code=400, content={"error": "missing/invalid start"}
+                        )
+                    start_idx = 0
+                    for i, lyr in enumerate(self._assigned_sorted):
+                        if lyr >= start:
+                            start_idx = i
+                            break
+                    else:
+                        return JSONResponse(content={"prefetched": []})
+                    window_layers = self._assigned_sorted[
+                        start_idx : start_idx + max(1, window)
+                    ]
+                    for wl in window_layers:
+                        # Prefetch disabled in fit mode; allow only when non-fit and enabled
+                        if self._mode != "fit" and self.config.prefetch_mode != "off":
+                            self._prefetch_to_ram(wl)
+                            self._enqueue_weight_prefetch(wl)
+                    return JSONResponse(content={"prefetched": window_layers})
            except Exception as e:
                logger.error("/warm failed: %s", e)
                return JSONResponse(status_code=500, content={"error": str(e)})
@@ -1576,9 +1650,11 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict:
        Returns:
            Device profile information as a plain dict
        """
-        profile_dict = profile_device_via_subprocess(
-            repo_id, max_batch_exp=max_batch_exp, debug=0
-        )
+        with self.tracer.frame("startup", "profile.device") as f:  # NOTE: Symbol hardcoded for runtime stats
+            f.set("node", self._instance_name)
+            profile_dict = profile_device_via_subprocess(
+                repo_id, max_batch_exp=max_batch_exp, debug=0
+            )
        logger.info("Device profiling completed for node %s", self.node_id)
        return profile_dict
diff --git a/src/dnet/ring/shard/servicer.py b/src/dnet/ring/shard/servicer.py
index 93e8627e..006f38a8 100644
--- a/src/dnet/ring/shard/servicer.py
+++ b/src/dnet/ring/shard/servicer.py
@@ -35,6 +35,7 @@ async def SendActivation(
            request.activation.layer_id,
        )

+        logger.debug(f"SERVICER: {request.callback_url}")
        await self.node.admit_frame(request)

        return ActivationResponse(
diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py
index f7c9ae93..3749522d 100644
--- a/src/dnet/ring/weight_cache.py
+++ b/src/dnet/ring/weight_cache.py
@@ -26,19 +26,19 @@ def __init__(
        model_metadata: ModelMetadata,
        window_size: Optional[int] = None,
        prefetch_threads: int = 2,
+        tracer=None,
        *,
        resident_windows: int = 2,
        use_mxload_fastpath: bool = False,
        prefetch_mode: str = "off",
    ):
        self.assigned_layers = assigned_layers
-        # Resident budget: enforce up to N windows resident
-        resident_windows = max(1, int(resident_windows))
+        resident_windows = max(1, int(resident_windows))  # Resident budget
        if window_size is not None and window_size > 0:
            self.max_weights = min(
-                len(self.assigned_layers), max(1, resident_windows * int(window_size))
-            )
+                len(self.assigned_layers),
+                max(1, resident_windows * int(window_size)))
        else:
            self.max_weights = len(self.assigned_layers)
        self.cache = {}  # layer_id -> (data, access_time)
@@ -51,112 +51,106 @@ def __init__(
            prefetch_mode=prefetch_mode,
        )
        self.lock = threading.Lock()
+
+        if not tracer:
+            logger.error("Invalid tracer object passed to WeightCache.")
+        self.tracer = tracer
+
        # Track in-flight materializations so compute can wait on prefetch
        self.loading_futures: Dict[int, Future] = {}
        self.prefetch_futures: Dict[int, Future] = {}
        logger.info("WeightCache resident budget: max_weights=%d", self.max_weights)

-    def get_weight(
-        self, layer_id: int, *, inc_ref: bool = True
-    ) ->
Optional[Dict[str, mx.array]]: + + def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" # First, fast path under lock for cache hit or in-flight load - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - # refresh LRU timestamp - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - return data - - # If a load is in-flight, wait on it outside the lock - inflight = self.loading_futures.get(layer_id) - if inflight is None: - # Prepare eviction decision now to avoid overfilling once loaded - need_evict = len(self.cache) >= self.max_weights - if need_evict: - # Evict under lock, then proceed to load - self._evict_lru() - # Install a new future marker so others wait - fut = Future() - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False - - if creator: - # Perform the blocking load without holding the cache lock - try: - t0 = time.perf_counter() - data = self.layer_manager.load_layer_to_gpu(layer_id) - dt_ms = (time.perf_counter() - t0) * 1000.0 - # Estimate bytes by summing tensor sizes for the layer - try: - winfo = self.layer_manager.weight_info.get(layer_id, {}) - total_bytes = sum(w.size_bytes for w in winfo.values()) - except Exception: - total_bytes = 0 - # Commit to cache under lock - with self.lock: + with self.tracer.frame("memory.weights", "cache.search") as f: + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + # refresh LRU timestamp self.cache[layer_id] = (data, time.time()) if inc_ref: self.reference_counts[layer_id] = ( self.reference_counts.get(layer_id, 0) + 1 ) - else: - self.reference_counts.setdefault(layer_id, 0) - # Resolve future and clear from tracking - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.info( - "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", - layer_id, - dt_ms, - (total_bytes / 1_048_576), - ) - return data - except Exception as e: - # Signal failure to any waiters - with self.lock: - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) + return data + + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False + + if creator: # Perform the blocking load without holding the cache lock + with self.tracer.frame("memory.weights", "cache.load") as f: + try: + data = self.layer_manager.load_layer_to_gpu(layer_id) + f.event("load") + + try: # Estimate bytes by summing tensor sizes for the layer + winfo = self.layer_manager.weight_info.get(layer_id, {}) + total_bytes = sum(w.size_bytes for w in winfo.values()) + f.set("bytes", total_bytes) except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None + total_bytes = 0 + + with self.lock: # Commit to cache under lock + self.cache[layer_id] = (data, 
time.time()) + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + else: + self.reference_counts.setdefault(layer_id, 0) + + try: # Resolve future and clear from tracking + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_result(True) + except Exception: + self.loading_futures.pop(layer_id, None) + return data + + except Exception as e: + with self.lock: # Signal failure to any waiters + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_exception(e) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.exception("Failed to load weight %s: %s", layer_id, e) + return None else: # Not the creator: wait for the in-flight load to complete - t0w = time.perf_counter() - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) - return None - wait_ms = (time.perf_counter() - t0w) * 1000.0 - logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) - # Return from cache (now populated) and update ref/LRU - with self.lock: - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: + with self.tracer.frame("memory.weights", "cache.wait") as f: + t0w = time.perf_counter() + try: + inflight.result() # block until the creator completes + except Exception as e: + logger.error("Wait for layer %s load failed: %s", layer_id, e) return None - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - else: - self.reference_counts.setdefault(layer_id, 0) - return data + wait_ms = (time.perf_counter() - t0w) * 1000.0 + logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) + # Return from cache (now populated) and update ref/LRU + with self.lock: + data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] + if data is None: + return None + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) + else: + self.reference_counts.setdefault(layer_id, 0) + return data def decrease_reference(self, layer_id: int): """Decrease reference count for layer""" @@ -184,6 +178,7 @@ def prefetch_to_ram(self, layer_id: int): except Exception: return None + def cancel_all_prefetch(self): """Cancel any in-flight prefetch tasks and clear tracking.""" with self.lock: @@ -195,6 +190,7 @@ def cancel_all_prefetch(self): pass self.prefetch_futures.clear() + def _evict_lru(self): """Evict least recently used weight with zero references""" candidates = [ diff --git a/src/dnet/utils/logger.py b/src/dnet/utils/logger.py index 19f09c46..bac462f9 100644 --- a/src/dnet/utils/logger.py +++ b/src/dnet/utils/logger.py @@ -14,7 +14,7 @@ def get_logger() -> logging.Logger: Returns: Configured logger instance """ - logLevelEnv = os.getenv("DNET_LOG", "INFO").strip().upper() + logLevelEnv = os.getenv("DNET_LOG", "DEBUG").strip().upper() logLevel = logging.INFO # default if logLevelEnv in logging._nameToLevel: logLevel = logging._nameToLevel[logLevelEnv] diff --git a/src/repl.py b/src/repl.py new file mode 100644 index 00000000..c033ccb9 --- /dev/null +++ b/src/repl.py @@ -0,0 +1,865 @@ + +import io +import os +import sys +import json +import logging +import cmd +import time +import argparse +import subprocess +import contextlib +from 
dataclasses import dataclass
+from typing import Optional, List, Any, Dict
+
+import asyncio
+import inspect
+import threading
+import concurrent.futures
+
+from dnet.ring.api import RingApiNode
+from dnet.ring.shard import RingShardNode
+from dnet.utils.network import NodeAddress
+from dnet.ring.api.api_logging import get_api_logger
+from dnet.utils.model import (
+    ModelMetadata,
+    get_model_metadata,
+    load_api_layer_weights,
+    get_safetensor_details,
+)
+
+logger = get_api_logger()
+
+from dnet.perf.trace import TraceConfig, Tracer
+from dnet.perf.utils import TraceAggregator, StatsAggregator
+from dnet.ring.common import TopologyInfo
+
+from dnet.ring.api.models import (
+    PrepareTopologyManualRequest,
+    PrepareTopologyRequest,
+    APILoadModelRequest,
+    APILoadModelResponse,
+    ChatParams,
+    ChatMessage,
+    ChatRequestModel,
+)
+
+# Handle restricted repos
+from importlib import import_module
+import huggingface_hub as hb
+from huggingface_hub import snapshot_download, try_to_load_from_cache
+try:
+    hf_errors = import_module("huggingface_hub.errors")
+except ModuleNotFoundError:
+    hf_errors = import_module("huggingface_hub.utils")
+GatedRepoError = getattr(hf_errors, "GatedRepoError", Exception)
+HfHubHTTPError = getattr(hf_errors, "HfHubHTTPError", Exception)
+LocalEntryNotFoundError = getattr(hf_errors, "LocalEntryNotFoundError", Exception)
+
+
+def dprint(msg):
+    sys.stdout.write(msg)
+    sys.stdout.flush()
+
+@dataclass
+class ChatInterface:
+    id: str = ""
+    model: str = ""
+    params: Optional[ChatParams] = None
+    messages: Optional[List[ChatMessage]] = None
+
+@dataclass
+class REPLState:
+    model: Optional[str] = None
+    model_info: Optional[ModelMetadata] = None
+    num_local_nodes: int = 1
+    running_port: int = 50501
+    running_httpport: int = 8091
+    api_http_port: int = 8080
+    api_grpc_port: int = 50500
+    window_size: int = 2  # Number of layers per node per visit (also number resident in cache)
+    topo: Optional[TopologyInfo] = None
+
+
+class REPL(cmd.Cmd):
+    PS1 = "dnet > "
+    WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Type 'help' for usage hints.\n\n"
+
+    def __init__(self, model=None, nodes=1):
+        assert 1 <= nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10."
+        super().__init__()
+
+        # State
+        self.state = REPLState()
+        self.state.model = model
+        self.state.running_port += 2
+        self.state.num_local_nodes = nodes
+
+        # API Thread
+        self._node: Optional[RingApiNode] = None
+        self._api_thread: Optional[threading.Thread] = None
+        self._api_ready = threading.Event()
+        self._api_running = threading.Event()
+        self._api_searching = threading.Event()  # Track mDNS searching
+        self._api_loop: Optional[asyncio.AbstractEventLoop] = None
+        self._api_shutdown_e: Optional[asyncio.Event] = None
+        self._api_exc: Optional[BaseException] = None
+
+        # Tracing
+        self._trace_cfg = TraceConfig(
+            enabled=False,
+            streaming=False,
+            budget=3000,
+            aggregate=True,
+            agg_max_events=50,
+        )
+        self._tracing = threading.Event()
+        self._tracer = None
+        self._trace_file = f"trace-{time.strftime('%Y%m%d-%H%M%S')}"
+        self._trace_cursor = 0  # keep track of registered events in buffer
+        self._trace_agg = TraceAggregator()
+
+        # Runtime stats (ingests data from tracer)
+        self._stats_agg = StatsAggregator()
+        self._stats = threading.Event()
+        self._stats.set()  # Trace runtime information by default
+
+
+    def loop(self):  # Main tty loop
+        sys.stdout.write(self.WELCOME)
+        while True:
+            dprint(self.PS1)
+            cmd = sys.stdin.readline().strip().split(" ")
+
+            match cmd[0]:
+                case "": continue
+                case s if s in ["exit", "quit", "q"]: self.handle_terminate_signal()
+                case s if s in ["help", "h"]: self.print_help()
+                case s if s in ["api", "server"]: self.do_api(cmd)
+                case s if s in ["search", "s"]: self.do_search(cmd)
+                case s if s in ["nodes", "n"]: self.print_mdns_nodes()
+                case s if s in ["load", "l"]: self.load_model(self.state.model)
+                case s if s in ["trace", ".trace", "t"]: self.do_trace(cmd)
+                case s if s in ["perf", ".perf", "p"]: self.do_perf(cmd)
+                case s if s in ["topo", "topology"]: self.do_topo(cmd)  # "t" is taken by trace above
+                case s if s in ["model", "m"]: self.do_model(cmd)
+                case "chat": self.do_chat(cmd)
+
+    def do_api(self, cmd: List[str]) -> None:
+        if len(cmd) < 2:
+            dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
+            return
+        match cmd[1]:  # TODO Maybe allow kwargs here? > api start grpc_port=8080
+            case s if s in ["start", "run"]:
+                http_port, grpc_port = None, None
+                if len(cmd) > 2:
+                    try:
+                        http_port = int(cmd[2])
+                        grpc_port = int(cmd[3])
+                    except (IndexError, ValueError):
+                        pass
+                self.start_api(http_port or self.state.api_http_port, grpc_port or self.state.api_grpc_port)
+                self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0)
+
+            case "stop": self.stop_api()
+            case "status": dprint("Running\n" if self._api_running.is_set() else "Stopped.\n")
+            case "log": dprint("Log print is not yet supported.\n")
+            case _: dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
+
+    def do_search(self, cmd: List[str]) -> None:
+        if len(cmd) != 2:
+            dprint("mDNS search is " + ("ON\n\n" if self._api_searching.is_set() else "OFF\n\n"))
+            return
+        match cmd[1]:  # NOTE on by default
+            case "on":
+                if self._api_searching.is_set():
+                    return
+                if not (self._api_ready.is_set() and self._api_running.is_set()):
+                    dprint("Starting API Server thread.\n")
+                    self.start_api()
+                self.api_call("_start_discovery", timeout=10)
+                self._api_searching.set()
+                dprint("Starting mDNS search for worker nodes.\n")
+            case "off": dprint("Stop discovery not yet implemented in the API node.\n")
+            case _: dprint("Invalid search command. Start searching with 'search on'.\n")
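A caution that applies throughout this class: a threading.Event object is always truthy regardless of its state, so status checks must go through is_set() rather than testing the Event directly. For reference:

```python
import threading

flag = threading.Event()
assert bool(flag) is True      # the object itself is always truthy
assert flag.is_set() is False  # the state must be queried with is_set()
flag.set()
assert flag.is_set() is True
```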
+    def do_topo(self, cmd: List[str]) -> None:
+        if len(cmd) < 2:
+            dprint("Invalid topology command. Type 'help' for a list of valid commands.\n")
+            return
+        match cmd[1]:
+            case "search": self.print_mdns_nodes()
+            case s if s in ["auto", "build", "b"]: self.prepare_topo(self.state.model)
+            case "add": dprint("Not implemented.\n")
+            case s if s in ["remove", "rm"]: dprint("Not implemented.\n")
+
+    # TODO: standardize ANSI escape codes for easy use
+    def print_help(self):
+        def _print_hf(cmd, desc, examples=[""]):
+            pcmd = "  " + cmd.ljust(30, '.')
+            sys.stdout.write(f"{pcmd} {desc}\n")
+            for e in examples:
+                pex = e.rjust(len(e)+37)+"\n" if e != "" else ""
+                sys.stdout.write(f"{pex}")
+
+        sys.stdout.write("\033[1m\nAvailable commands:\n\033[0m")
+        dprint("\033[1m\n  Common:\n\033[0m")
+        _print_hf("model list ", "List locally available models.")
+        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository.")
+        _print_hf("api start [HTTP] [GRPC]", "Start the API server in a separate thread. Use provided ports if given.")
+        _print_hf("topo auto ", "Automatically optimize topology from available nodes.")
+        _print_hf("load ", "Load model into memory on all nodes.")
+        _print_hf("api log [LEVEL]", "Output live logs to current terminal.")
+        dprint("\033[1m\n  Controlling the API Server:\n\033[0m")
+        _print_hf("api start [HTTP] [GRPC]", "Start the API server in a separate thread. Use provided ports if given.")
+        _print_hf("api stop ", "Stop the API server.")
+        _print_hf("api log [LEVEL]", "Output live logs to current terminal.")
+        dprint("\033[1m\n  Topology management:\n\033[0m")
+        _print_hf("search [on/off] ", "Toggle mDNS search across the local network. Default is ON.")
+        _print_hf("nodes list ", "List all nodes in the current topology (including local ones).")
+        _print_hf("nodes all ", "List all mDNS discovered nodes (including local ones).")
+        _print_hf("topo auto ", "Automatically optimize topology from available nodes.")
+        _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
+        _print_hf("topo [remove|rm] [NODE]", "Remove [NODE] from the topology.")
+        _print_hf("topo assign [NODE] [LAYERS] [ROUNDS] ", "Assign [LAYERS] to [NODE]. e.g:",
+                  ["> topo assign dnet0 1-10"])
+        sys.stdout.write("\033[1m\n  Scheduling:\n\033[0m")
+        _print_hf("sched auto ", "Automatic search for best schedule given the active topology and the loaded model.")
+        _print_hf("sched assign [INDEX] [NODE]", "Assign the layer range between [START] and [END] to [NODE]. e.g:",
+                  ["> sched assign 10 benny_234",
+                   "> sched assign 0-12 benny_234"])
+        sys.stdout.write("\033[1m\n  Benchmarking, Tracing and Profiling:\n\033[0m")
+        _print_hf("perf ", "Prints the current state of runtime performance tracking.")
+        _print_hf("perf stat [-n node] [-r req] ", "Prints the runtime statistics of target system.")
+        _print_hf("perf vm [interval]", "Prints virtual memory information for worker thread at [INTERVAL].")
+        _print_hf("bench [repo_id]", "Benchmark the system using the model from [REPO]")
+        _print_hf("bench [kernel]", "Benchmark the system using base kernel [KERNEL]")
+        _print_hf("bench [node]", "Benchmark the network latency between the current system and [NODE]")
+        _print_hf("bench [node0, node1]", "Benchmark the network latency between [NODE0] and [NODE1]")
+        _print_hf("bench ", "Benchmark the system using base library kernels")
+        _print_hf("trace [on|off] ", "Enable system tracing.")
e.g:", + ["> trace on -p memory,network", + "> trace on -o ./trace_output.txt"]) + _print_hf("trace list ", "List available trace probes.") + _print_hf("trace add [probe] ", "Activate [probe].") + _print_hf("trace stream [on|off] ", "Stream the trace events to the current terminal.") + _print_hf("trace [budget|b] [limit]", "Set the maximum budget for recoded events. Default is 1000.") + _print_hf("trace stat ", "See status of the trace, eg. number of frames captured") + sys.stdout.write("\033[1m\n System control:\n\033[0m") + _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource. e.g:", + ["> limit SYSMEM 12000 (MB)", + "> limit CPU_CORE_COUNT 4", + "> limit GPU_SM 128"]) + sys.stdout.write("\n") + sys.stdout.flush() + + # ===== Handle Model input and pull from server + + def prompt_model(self): + while True: + dprint("Target model > ") + model = sys.stdin.readline().strip() + try: + path = self._handle_model_pull(model) + return path + #self.model_info = ModelMetadata() + except Exception as e: + dprint(f"Unable to load model {model}. Target needs to be a valid HF repository. Try again:{e}\n") + + def do_model(self, cmd): + if len(cmd) < 2: + if self.state.model is None: dprint("No target model.\n") + else: dprint(f"Target model: {self.state.model}.\n") + return + + match cmd[1]: + case "list": # List locally available models + lists = self._list_local_models() + dprint("\nLocally available weights:\n") + for x in lists[0]: + dprint(f" {x.replace("--", "/")}\n") + dprint("\nMetadata only:\n") + for x in lists[1]: + dprint(f" {x.replace("--", "/")}\n") + dprint("\n") + case _: # Treat unknown commands as model repos + self._handle_model_pull(cmd[1]) + + def _are_weights_local(self, repo_id, revision="main"): + try: + import re, os + from pathlib import Path + root = Path.home() / ".cache/huggingface/hub" + if os.getenv("HF_HOME") is not None: root = Path(f"{os.getenv("HF_HOME")}/huggingface/hub") + if os.getenv("HF_HUB_CACHE") is not None: root = Path(f"{os.getenv("HF_HUB_CACHE")}/huggingface/hub") + files = os.scandir(root / f"models--{repo_id}/snapshots") + commit_hash = next(files).name + files = os.scandir(root / f"models--{repo_id}/snapshots/{commit_hash}") + files = [x.name for x in files] + except Exception as e: + dprint(f"Failed to search files in local model cache: {e}\n") + return False + PATTERNS = [ + re.compile(r".*?model.*\.safetensors$"), + re.compile(r"^(model|pytorch_model)([._-]\d+([-_]of[-_]\d+)?)?[._-]?\.safetensors$"), + re.compile(r"pytorch_model[-_]\d{3,}[.]bin$"), + re.compile(r"model[.]safetensors$"), + re.compile(r"pytorch_model[.]bin$"), + re.compile(r"tf_model[.]h5$"), + re.compile(r"flax_model[.]msgpack$"), + ] + for pat in PATTERNS: + if any(pat.search(f) for f in files): + return True + return False + + def _list_local_models(self): + import os + from pathlib import Path + root = Path.home() / ".cache/huggingface/hub" + if os.getenv("HF_HOME") is not None: root = Path(f"{os.getenv("HF_HOME")}/huggingface/hub") + if os.getenv("HF_HUB_CACHE") is not None: root = Path(f"{os.getenv("HF_HUB_CACHE")}/huggingface/hub") + models = [x.name.replace("models--", "") for x in os.scandir(root) if x.name.startswith("models--")] + weights_local = sorted([x for x in models if self._are_weights_local(x)], key=len) + config_local = sorted([x for x in models if x not in weights_local], key=len) + return [weights_local, config_local] + + # Require a HF access token for restricted repositories + # Ask user for HF access token until they have a valid one 
+    # Require a HF access token for restricted repositories
+    # Ask user for HF access token until they have a valid one
+    def _handle_model_pull(self, repo_path):
+        local = self._are_weights_local(repo_path.replace("/", "--"))
+        try:
+            if not local:
+                dprint(f"Weights for {repo_path} not found in local cache. Downloading.\n")
+                path = snapshot_download(repo_path)
+                dprint("Download complete\n")
+            else:
+                dprint("Target model found in local registry.\n")
+
+            self.state.model = repo_path
+            return repo_path
+
+        except GatedRepoError as e:
+            tok = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
+            while True:
+                if not tok:
+                    dprint("\nRestricted model. Enter the Hugging Face access token > ")
+                    tok = sys.stdin.readline().strip()
+                try:
+                    ret = snapshot_download(repo_id=repo_path, token=tok)
+                    self.state.model = repo_path
+                    return ret
+                except GatedRepoError as e:
+                    print(e)
+                    tok = None
+                    continue
+                except Exception as e:
+                    raise RuntimeError("Unknown error during HF snapshot_download")
+
+        except HfHubHTTPError as e:
+            dprint(f"Repository {repo_path} not found in Hugging Face registry: {e}")
+            return None
+
+        except Exception as e:
+            raise RuntimeError(f"Unable to pull model {repo_path} locally")
+
+    def _parse_model_metadata(self, model_path):
+        import glob
+        from pathlib import Path
+        from collections import defaultdict
+
+        if isinstance(model_path, tuple):
+            model_path = model_path[0] if model_path else None
+        if model_path is None:
+            raise ValueError(f"Could not resolve model path {model_path}")
+        path = Path(model_path) if not isinstance(model_path, Path) else model_path
+
+        with open(path / "config.json", "r") as f:
+            config = json.load(f)
+
+        weight_files = glob.glob(str(path / "*.safetensors"))
+        weight_info: Dict[int, Dict[str, Any]] = defaultdict(dict)
+        embed_tokens, lm_head, norm = {}, {}, {}
+        for weight in weight_files:
+            details = get_safetensor_details(weight)
+            for key, val in details.items():
+                if m := EMBED_TOKENS_RE.match(key):
+                    embed_tokens[m.group(1)] = val
+                elif m := LM_HEAD_RE.match(key):
+                    lm_head[m.group(1)] = val
+                elif m := NORM_RE.match(key):
+                    norm[m.group(1)] = val
+                elif m := LAYERS_RE.match(key):
+                    layer_idx, suffix = m.groups()
+                    weight_info[int(layer_idx)][suffix] = val
+                else:
+                    raise RuntimeError(f"Unexpected key {key}")
+        num_layers = max(weight_info.keys()) + 1
+        if not (set(weight_info.keys()) == set(range(num_layers))):
+            raise RuntimeError("Inconsistent weights")
+        return ModelMetadata(path, weight_info, embed_tokens, lm_head, norm, config)
+
+    # ===== Handle termination signals
+
+    def handle_terminate_signal(self):
+        # Handle worker/api shutdown
+        if self._api_running.is_set():
+            self.stop_api()
+        else:
+            dprint("No workers to shut down. Terminating.\n")
+        sys.exit()
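_parse_model_metadata above relies on EMBED_TOKENS_RE, LM_HEAD_RE, NORM_RE and LAYERS_RE, which are not defined in this file. Illustrative definitions matching the usual HF checkpoint key layout are sketched below; these are assumptions about the real patterns, which presumably live in a shared module:

```python
import re

# Hypothetical patterns; group(1) is the per-tensor suffix ("weight", "scales", ...),
# and LAYERS_RE additionally captures the decoder layer index in group(1).
EMBED_TOKENS_RE = re.compile(r"^model\.embed_tokens\.(.+)$")
LM_HEAD_RE = re.compile(r"^lm_head\.(.+)$")
NORM_RE = re.compile(r"^model\.norm\.(.+)$")
LAYERS_RE = re.compile(r"^model\.layers\.(\d+)\.(.+)$")
```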
Terminating.\n") + sys.exit() + + # ===== Handle Shard worker servers + + # TODO: Redirect output logs to different files + def handle_start_worker(self): + bin = os.path.join(".venv", "bin", "piped_mlx_ring_shard") + cmd = ["uv", "run", bin] + cmd.append(f" -m {self.state.model}") + cmd.append(f" -p {self.state.running_port}") + cmd.append(f" -httpport {self.state.running_httpport}") + cmd.append(f" -l [{0}]") + cmd.append(f" --prefetch-window {2}") + proc = subprocess.Popen(cmd) + + self.state.running_port += 1 # increment the running port + self.state.running_httpport += 1 + + # ===== Handle API server + + async def _api_main(self) -> None: # main thread loop + #logging.disable(logging.CRITICAL) + self._api_loop = asyncio.get_running_loop() + self._api_shutdown_e = asyncio.Event() + self._node = RingApiNode( + http_port=self.state.api_http_port, + grpc_port=self.state.api_grpc_port + ) + + try: + await self._node.start(shutdown_trigger=self._api_shutdown_e.wait) + self._api_searching.set() + self._api_running.set() + self._api_ready.set() + await self._api_shutdown_e.wait() + except Exception as e: + self._api_exc = e + self._api_running.set() + self._api_ready.set() + finally: + try: + await self._node.shutdown() + except Exception: + pass + self._api_running.clear() + + def _api_running_loop(self): + try: + asyncio.run(self._api_main()) + except BaseException as e: + self._api_exc = e + self._api_ready.set() + self._api_running.clear() + + def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10): + if self._api_thread and self._api_thread.is_alive(): return + self._api_exc = None + self._api_ready.clear() + self._api_running.clear() + self._api_thread = threading.Thread(target=self._api_running_loop, name="api_server", daemon=True) + self._api_thread.start() + if not self._api_ready.wait(timeout): + raise RuntimeError("API Server Timeout.") + if self._api_exc is not None: + raise RuntimeError(f"API Server failed to start: {self._api_exc}") + # Register REPL aggregator callback on the API node + try: + self.api_call("set_trace_ingest_callback", self._trace_agg.enqueue, timeout=5) + except Exception: + pass + + # Silence API server logs on the REPL console: drop records emitted from the API thread + try: + class _DropApiOnConsole(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + # Only drop records coming from the API thread so other threads keep logging + tname = getattr(record, "threadName", "") or "" + return tname != "api_server" + + root = logging.getLogger() + for h in list(root.handlers): + if isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stdout, sys.stderr): + if not any(isinstance(f, _DropApiOnConsole) for f in getattr(h, "filters", [])): + h.addFilter(_DropApiOnConsole()) + + # Also quiet Hypercorn logs explicitly (HTTP server used by API) + logging.getLogger("hypercorn").setLevel(logging.CRITICAL) + logging.getLogger("hypercorn.error").setLevel(logging.CRITICAL) + logging.getLogger("hypercorn.access").setLevel(logging.CRITICAL) + except Exception: + pass + + def stop_api(self, timeout: float = 5.0) -> None: + if not self._api_thread: return + if self._api_loop and self._api_shutdown_e: + self._api_loop.call_soon_threadsafe(self._api_shutdown_e.set) + if self._api_loop and self._node: + f = asyncio.run_coroutine_threadsafe(self._node.shutdown(), self._api_loop) + try: + f.result(timeout=timeout) + except Exception: + pass + self._api_thread.join(timeout=timeout) + self._api_thread = None + 
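The API server runs its own asyncio loop on a daemon thread, and api_call below bridges into it with asyncio.run_coroutine_threadsafe. The core pattern, reduced to a self-contained sketch:

```python
import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def ping() -> str:
    await asyncio.sleep(0.1)
    return "pong"

# Submit a coroutine to the background loop; block the calling thread (with a
# timeout) on the returned concurrent.futures.Future until the result is ready.
future = asyncio.run_coroutine_threadsafe(ping(), loop)
print(future.result(timeout=5))  # -> pong
loop.call_soon_threadsafe(loop.stop)
```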
+    # Call an API function from the REPL thread
+    def api_call(self, method: str, *args: Any, timeout: float = 30.0, **kwargs: Any) -> Any:
+        if not self._api_loop or not self._node:
+            raise RuntimeError("API Thread not set up correctly.")
+
+        target = getattr(self._node, method, None)
+        if target is None:
+            raise AttributeError(f"RingApiNode has no method {method}")
+
+        # method is async
+        if inspect.iscoroutinefunction(target):
+            coroutine = target(*args, **kwargs)
+            f = asyncio.run_coroutine_threadsafe(coroutine, self._api_loop)
+            return f.result(timeout)
+
+        f = concurrent.futures.Future()  # method is sync
+
+        def runner():
+            try:
+                ret = target(*args, **kwargs)
+                if inspect.isawaitable(ret):
+
+                    async def _await_then_set():
+                        try:
+                            val = await ret
+                            f.set_result(val)
+                        except BaseException as e:
+                            f.set_exception(e)
+
+                    asyncio.create_task(_await_then_set())
+                else:
+                    f.set_result(ret)
+            except BaseException as e:
+                f.set_exception(e)
+
+        self._api_loop.call_soon_threadsafe(runner)
+        return f.result(timeout)
+
+    # ------- Trace aggregation helpers
+
+    def do_trace(self, cmd):
+        if len(cmd) < 2:
+            dprint(f"Tracing is currently {'ON' if self._trace_cfg.enabled else 'OFF'}\n")
+            return
+
+        match cmd[1]:
+            case "on" | "ON":
+                self._tracing.set()
+                self._trace_cfg.enabled = True
+                dprint("Tracing is now ON\n")
+
+            case "off" | "OFF":
+                self._tracing.clear()
+                self._trace_cfg.enabled = False
+                dprint("Tracing is now OFF\n")
+
+            case "status":
+                dprint(f"Captured {len(self._trace_agg._req)} frames.\n")
+
+            case "focus":
+                dprint("Subsystems not yet implemented.\n")
+
+            case "stream":
+                if len(cmd) == 2:
+                    if self._trace_cfg.streaming:
+                        dprint(f"Trace is streaming to file: {self._trace_cfg.file}\n")
+                    else:
+                        dprint("Trace is not streaming.\n")
+                elif cmd[2] == "on":
+                    self._trace_cfg.streaming = True
+                    dprint(f"Streaming trace frames to {self._trace_cfg.file}\n")
+                elif cmd[2] == "off":
+                    self._trace_cfg.streaming = False
+                    dprint("Trace streaming is OFF.\n")
+
+            case "set":
+                if len(cmd) == 2:
+                    dprint("Use: trace set [BUDGET], eg. 2000\n")
+                else:
+                    dprint("Not implemented yet\n")
+
+            case "annotate":
+                if len(self._trace_agg._req) < 1:
+                    dprint("No trace frames captured. Is tracing enabled?\n")
+                    return
+                last = list(self._trace_agg._req.keys())[-1]
+                self.print_trace_annotate(last)
+
+            case _:
+                dprint("Unknown trace command. Type 'help' for a list of available commands.\n")
+
+        if self._api_running.is_set() and self._api_ready.is_set():
+            self.api_call("_forward_trace_config", self._trace_cfg)  # Send trace config to all shards
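+    # Illustrative trace commands as typed at the REPL prompt (handled by
+    # do_trace above):
+    #
+    #   trace             -> report whether tracing is ON or OFF
+    #   trace on          -> enable tracing and forward the config to all shards
+    #   trace stream on   -> stream frames to self._trace_cfg.file as JSONL
+    #   trace annotate    -> print latency stats for the most recent request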
+    # Performance trackers
+    def do_perf(self, cmd):
+        if len(cmd) < 2 or cmd[1] == "stat":
+            dprint("Runtime performance metrics are ON by default.\n")
+            dprint("Turn tracking off with 'perf off'. Do 'perf stat' for statistics on previous requests or 'help' for more commands.\n\n")
+            return
+
+        match cmd[1]:
+            case "stats":
+                self._stats_agg.stats()
+            case _:
+                pass
+
+    # Trace callback registered with API Thread
+    # This forwards the tracer frames back to the REPL for printing
+    def __trace_cb(self, data):
+        try:
+            if self._tracing.is_set():
+                self._trace_agg.enqueue(data)
+            if self._stats.is_set():
+                self._stats_agg.add(data)
+        except Exception as e:
+            print(f"Unable to ingest trace buffer into REPL: {e}")
+
+    def __print_tr(self, row):
+        symbol, ms, counts = row  # assumes row = (name, latency-ms, count)
+        sym = " " + symbol.ljust(40, ' ')
+        pms = f"{ms:.3f}".ljust(10, ' ')
+        cns = f"{counts}".ljust(4, ' ')
+        sys.stdout.write(f"{sym} {pms} {cns}\n")
+
+    def print_trace_annotate(
+        self,
+        run_id: str = "run",
+        mapping: Optional[Dict[str, str]] = None,
+        repeats: int = 0,
+    ) -> List[Dict[str, Any]]:
+
+        try:
+            rows = self._trace_agg.annotate(run_id)
+            logger.debug(f"rows: {rows}")
+            headers = ["name", "total", "max", "mean", "p50", "p90", "p99", "samples"]
+            limits = {"name": 50}
+            w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers}
+            w["name"] = max(w["name"], 35)
+
+            line = " ".join(h.ljust(w[h]) for h in headers)
+            sys.stdout.write("\n")
+            sys.stdout.write(line + "\n")
+            sys.stdout.write(" ".join("." * w[h] for h in headers))
+            sys.stdout.write("\n")
+            for r in rows:
+                name = str(r["name"])
+                if len(name) > w["name"]:
+                    name = name[:w["name"] - 1] + "..."
+                vals = {
+                    "name": name,
+                    "total": r["total"],
+                    "max": r["max"],
+                    "mean": r["mean"],
+                    "p50": r["p50"],
+                    "p90": r["p90"],
+                    "p99": r["p99"],
+                    "samples": r["samples"],
+                }
+                sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]]))
+                sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:]))
+                sys.stdout.write("\n")
+            sys.stdout.write("\n\n")
+            sys.stdout.flush()
+            return rows
+        except Exception as e:
+            logger.error(f"{e}")
+            return []
+
+    def _print_nodes_table(self, rows: List[Any]) -> None:
+        headers = ["name", "role", "addr", "http", "grpc", "status", "head"]
+        limits = {"name": 36, "addr": 15}
+        w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers}
+        line = " ".join(h.ljust(w[h]) for h in headers)
+        sys.stdout.write("\n")
+        sys.stdout.write(line + "\n")
+        sys.stdout.write(" ".join("-" * w[h] for h in headers))
+        sys.stdout.write("\n")
+        for r in rows:
+            name = str(r["name"])
+            addr = str(r["addr"])
+            if len(name) > w["name"]:
+                name = name[:w["name"] - 1] + "..."
+            if len(addr) > w["addr"]:
+                addr = addr[:w["addr"] - 1] + "..."
+            vals = {
+                "name": name,
+                "role": r["role"],
+                "addr": addr,
+                "http": r["http"],
+                "grpc": r["grpc"],
+                "status": "yes" if r["status"] else "no",
+                "head": "head" if r["head"] else "no",
+            }
+            sys.stdout.write(" ".join(str(vals[h]).ljust(w[h]) for h in headers))
+            sys.stdout.write("\n")
+        sys.stdout.write("\n\n")
+        sys.stdout.flush()
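+    # Column-width rule shared by print_trace_annotate and _print_nodes_table:
+    # width = max(len(header), min(per-column cap, widest cell)). E.g. with
+    # limits = {"name": 36}, a 50-char node name renders in a 36-char column
+    # and is "..."-truncated, while short columns never collapse below their
+    # header width.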
+    # Print a table of discovered nodes
+    def print_mdns_nodes(self) -> None:
+        try:
+            shards = self.api_call("_get_shards_from_discovery", timeout=10)
+            if not shards:
+                dprint("No worker nodes discovered. Is the API searching?\n")
+                return
+
+            rows = []
+            for name, props in shards.items():
+                addr = getattr(props, "local_ip", getattr(props, "host", ""))
+                http = getattr(props, "server_port", 0)
+                grpc = getattr(props, "shard_port", 0)
+                busy = bool(getattr(props, "is_busy", False))
+                head = addr == "127.0.0.1"  # TODO: FIX
+                rows.append({
+                    "name": name,
+                    "role": "worker",
+                    "addr": addr,
+                    "http": http,
+                    "grpc": grpc,
+                    "status": busy,
+                    "head": head,
+                })
+            self._print_nodes_table(rows)
+        except Exception as e:
+            dprint(f"Could not list nodes: {e}\n")
+
+    def print_topo(self, topo):
+        line = "=" * 20 + " Topology " + "=" * 20
+        sys.stdout.write(f"{line}\nModel: {topo.model}\nLayers: {topo.num_layers}\n")
+        sys.stdout.write(f"Devices: {topo.devices}\n\n")
+        # TODO: Better print here
+
+    def prepare_topo(self, model):
+        req = PrepareTopologyRequest(model=model)
+        try:
+            topo = self.api_call("_handle_prepare_topology", req, timeout=120)
+        except Exception as e:
+            dprint(f"Unable to create topology: {e}\n\n")
+            return False
+        self.state.topo = topo
+        #self.print_topo(topo)
+        return True
+
+    def load_model(self, model):
+        req = APILoadModelRequest(model=model)
+        try:
+            self.api_call("_handle_load_model", req, timeout=30)
+            return True
+        except Exception as e:
+            dprint(f"Failed to load model: {e}\n\n")
+            return False
+
+    # ===== Handle chat
+
+    def _p_msg(self, msg: str, role: str):
+        match role:
+            case "user":
+                ps = f"\n@\033[97;1m{role}\033[0m".rjust(10)
+                sys.stdout.write(f"{ps} > ")
+                sys.stdout.flush()
+            case "llm" | "think":
+                ps = f"@\033[97;1m{role}\033[0m".rjust(10)
+                sys.stdout.write(f"{ps} > ")
+                sys.stdout.flush()
+            case _:
+                ps = f"@\033[1m{role}\033[0m".rjust(10)
+                sys.stdout.write(f"{ps} > {msg}\n")
+
+    def _chat_loop(self, ci: ChatInterface):
+        while True:
+            self._p_msg(" ", "user")
+            prompt = sys.stdin.readline().strip()
+
+            match prompt:  # Match meta commands, else prompt
+                case ".quit" | ".q":
+                    break
+                case _:
+                    msg = ChatMessage(role="user", content=prompt)
+                    ci.messages.append(msg)
+
+            req = ChatRequestModel(
+                messages=ci.messages,
+                model=self.state.model,
+                max_tokens=5000,
+            )
+            # eval async generator in api thread
+            agen = self.api_call("_stream_chat", req)
+            msg = []
+            #self._p_msg("", "think")
+            self._p_msg("", "llm")
+            while True:
+                fut = asyncio.run_coroutine_threadsafe(agen.__anext__(), self._api_loop)
+                line = fut.result()
+                if not line.startswith("data: "):
+                    sys.stdout.write("[stream error] Invalid data returned from API thread.\n")
+                    break
+                payload = line[6:].strip()
+                if payload == "[DONE]":
+                    dprint("\n")
+                    break
+                obj = json.loads(payload)
+                choices = obj.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    text = delta.get("content", "")
+                    match text:
+                        case "<think>":
+                            sys.stdout.write("\033[90;3m")  # dim italics while the model "thinks"
+                        case "</think>":
+                            sys.stdout.write("\033[0m")  # reset styling after the think block
+                        case "\n":
+                            pass
+                        case _:
+                            sys.stdout.write(text)
+                            sys.stdout.flush()
+                            msg.append(text)
+
+        # TODO: Capture usage and metrics
+
+    def do_chat(self, cmd):
+        if len(cmd) < 2:
+            if not self.state.model or self.state.model == "":
+                self.prompt_model()
+            sys.stdout.write("-" * 80 + "\n\n")
+            if not getattr(self.state, "topo", None):  # prepare_topo stores the result on state.topo
+                self._p_msg("Initializing topology", "system")
+                if not self.prepare_topo(self.state.model):
+                    raise RuntimeError("Unable to create topology.")
+            self._p_msg("Loading weights into memory", "system")
+            if not self.load_model(self.state.model):
+                raise RuntimeError("Unable to load model.")
+
+            self._p_msg("New session initialized. Welcome :3", "system")
+            ci = ChatInterface(messages=[])
+            self._chat_loop(ci)
+
+        # TODO: Start default chat with selected model when arguments are given
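+    # The stream consumed by _chat_loop is assumed to be OpenAI-style SSE
+    # chunks (hypothetical sample lines; actual payloads come from RingApiNode):
+    #
+    #   data: {"choices": [{"delta": {"content": "Hel"}}]}
+    #   data: {"choices": [{"delta": {"content": "lo"}}]}
+    #   data: [DONE]
+    #
+    # "<think>"/"</think>" sentinel tokens toggle the dim-italic styling above.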
+    # ===== Handle shutdown
+
+    def handle_shutdown(self):
+        os.system("pkill -9 -f piped_mlx")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", "-m", type=str, help="HF Repository of target model")
+    parser.add_argument("--local-nodes", "-n", type=int, help="Number of local worker nodes")
+    args = parser.parse_args()
+
+    #workers = args.local_nodes
+    #model = args.model
+
+    repl = REPL()
+    repl.loop()
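+# Example invocation (illustrative; the module path for this REPL depends on
+# where the file lives in the package):
+#
+#   uv run python path/to/repl.py --model <hf-repo-id> --local-nodes 2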