From 4da88b8374dd9d6baa9edb012e5f0abd47480191 Mon Sep 17 00:00:00 2001 From: erhant Date: Wed, 5 Nov 2025 16:37:33 +0300 Subject: [PATCH 001/172] initial parallelization stuff --- src/dnet/ring/api/node.py | 198 +++++++++++++++++++++++++++------- src/dnet/ring/shard/models.py | 21 +++- src/dnet/ring/shard/node.py | 40 ++++--- 3 files changed, 197 insertions(+), 62 deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 270cb6eb..8e6ec365 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -42,6 +42,7 @@ from ...utils.logger import logger from ...utils.banner import print_startup_banner +from ...utils.latency import calculate_median_latency_seconds from ...utils.model import ( ModelMetadata, get_model_metadata, @@ -74,6 +75,8 @@ UnloadModelResponse, ) from ..shard.models import ( + MeasureLatencyRequest, + MeasureLatencyResponse, ShardProfileRequest, ShardLoadModelRequest, ShardLoadModelResponse, @@ -996,10 +999,10 @@ async def _collect_shard_profiles( # Find Thunderbolt connections all_thunderbolts = discover_all_thunderbolt_connections(shards) - # Call each shard's /profile endpoint - # FIXME: do this in parallel - shard_profiles: Dict[str, DeviceProfile] = {} async with httpx.AsyncClient() as client: + # Step 1: Health check all shards in parallel + logger.info("Starting health checks for all shards...") + health_tasks, shard_list = [], [] for shard_name, shard_props in shards.items(): if shard_props.is_manager: logger.warning( @@ -1007,52 +1010,169 @@ async def _collect_shard_profiles( ) continue + shard_list.append((shard_name, shard_props)) server_port, server_ip = shard_props.server_port, shard_props.local_ip + health_url = f"http://{server_ip}:{server_port}/health" + health_tasks.append(client.get(health_url, timeout=5.0)) - try: - shard_url = f"http://{server_ip}:{server_port}/profile" - logger.info( - "Calling /profile endpoint for shard %s at %s", - shard_name, - shard_url, - ) + health_results = await asyncio.gather(*health_tasks, return_exceptions=True) - response = await client.post( - shard_url, - json=ShardProfileRequest( - repo_id=repo_id, - thunderbolts=all_thunderbolts.get(shard_name, {}), - payload_sizes=payload_sizes, - max_batch_exp=max_batch_exp, - devices=shards, - ).model_dump(), - timeout=1000.0, - ) + # Filter healthy shards + healthy_shards = [] + for (shard_name, shard_props), health_result in zip(shard_list, health_results): + if isinstance(health_result, Exception): + logger.warning("Health check failed for %s: %s", shard_name, health_result) + continue + if health_result.status_code == 200: + healthy_shards.append((shard_name, shard_props)) + logger.info("Health check passed for %s", shard_name) + else: + logger.warning("Health check failed for %s: status %s", shard_name, health_result.status_code) + + logger.info("Healthy shards: %d/%d", len(healthy_shards), len(shard_list)) + + if not healthy_shards: + logger.error("No healthy shards found!") + return {}, all_thunderbolts + + # Step 2: Measure latencies on all healthy shards in parallel + logger.info("Measuring latencies for all healthy shards...") + latency_tasks = [] + for shard_name, shard_props in healthy_shards: + server_port, server_ip = shard_props.server_port, shard_props.local_ip + latency_url = f"http://{server_ip}:{server_port}/measure_latency" + latency_request = MeasureLatencyRequest( + devices=shards, + thunderbolts=all_thunderbolts.get(shard_name, {}), + payload_sizes=payload_sizes, + ) + latency_tasks.append( + client.post(latency_url, 
json=latency_request.model_dump(), timeout=1000.0) + ) - if response.status_code == 200: - profile_data = ShardProfileResponse.model_validate( - response.json() - ) - profile = load_device_profile_from_dict(profile_data.profile) - logger.info( - "Successfully collected profile from %s", shard_name - ) + latency_results = await asyncio.gather(*latency_tasks, return_exceptions=True) - # Mark head device (same local IP as API) - if shard_props.local_ip == this_device.local_ip: - profile.is_head = True + # Store latency data for each shard + shard_latencies = {} + final_healthy_shards = [] - # FIXME: DeviceProfileInfo to DeviceProfile should be better - shard_profiles[shard_name] = profile + for (shard_name, shard_props), latency_result in zip(healthy_shards, latency_results): + if isinstance(latency_result, Exception): + logger.warning( + "Latency measurement failed for %s: %s", shard_name, latency_result + ) + continue + else: + if latency_result.status_code == 200: + latency_data = MeasureLatencyResponse.model_validate(latency_result.json()) + shard_latencies[shard_name] = latency_data.latency + final_healthy_shards.append((shard_name, shard_props)) + logger.info("Latency measurement succeeded for %s", shard_name) else: - logger.error( - "Failed to get profile from %s: %s", + logger.warning( + "Latency measurement failed for %s: status %s", shard_name, - response.status_code, + latency_result.status_code, ) - except Exception as e: - logger.exception("Error calling /profile for %s: %s", shard_name, e) + logger.info("Latencies collected from %d shards", len(shard_latencies)) + + if not final_healthy_shards: + logger.error("No shards with successful latency measurements!") + return {}, all_thunderbolts + + # Step 3: Group healthy shards by local_ip (same device) + shards_by_device: Dict[str, List[Tuple[str, DnetDeviceProperties]]] = {} + for shard_name, shard_props in final_healthy_shards: + local_ip = shard_props.local_ip + if local_ip not in shards_by_device: + shards_by_device[local_ip] = [] + shards_by_device[local_ip].append((shard_name, shard_props)) + + logger.info("Grouped shards into %d devices", len(shards_by_device)) + + # Step 4: Profile devices (parallel per device, sequential per shard within device) + async def profile_device_shards( + device_ip: str, device_shards: List[Tuple[str, DnetDeviceProperties]] + ) -> List[Tuple[str, DeviceProfile]]: + """Profile all shards on a single device sequentially.""" + profiles = [] + + for shard_name, shard_props in device_shards: + try: + server_port, server_ip = shard_props.server_port, shard_props.local_ip + profile_url = f"http://{server_ip}:{server_port}/profile" + + logger.info( + "Calling /profile endpoint for shard %s at %s", + shard_name, + profile_url, + ) + + response = await client.post( + profile_url, + json=ShardProfileRequest( + repo_id=repo_id, + thunderbolts=all_thunderbolts.get(shard_name, {}), + payload_sizes=payload_sizes, + max_batch_exp=max_batch_exp, + devices=shards, + ).model_dump(), + timeout=1000.0, + ) + + if response.status_code == 200: + profile_data = ShardProfileResponse.model_validate(response.json()) + profile = load_device_profile_from_dict(profile_data.profile) + + # Mark head device (same local IP as API) + if shard_props.local_ip == this_device.local_ip: + profile.is_head = True + + profiles.append((shard_name, profile)) + logger.info("Successfully collected profile from %s", shard_name) + else: + logger.error( + "Failed to get profile from %s: %s", + shard_name, + response.status_code, + ) + + except 
Exception as e: + logger.exception("Error calling /profile for %s: %s", shard_name, e) + + return profiles + + # Run profiling for all devices in parallel + device_tasks = [ + profile_device_shards(device_ip, device_shards) + for device_ip, device_shards in shards_by_device.items() + ] + device_results = await asyncio.gather(*device_tasks, return_exceptions=True) + + # Step 5: Merge latency data into device profiles + shard_profiles: Dict[str, DeviceProfile] = {} + + for device_result in device_results: + if isinstance(device_result, Exception): + logger.error("Device profiling failed: %s", device_result) + continue + + for shard_name, profile in device_result: + # Set t_comm using median latency + if shard_name in shard_latencies: + median_latency = calculate_median_latency_seconds(shard_latencies[shard_name]) + if median_latency is not None: + profile.t_comm = float(median_latency) + logger.info( + f"Set t_comm for {shard_name} to median latency: {profile.t_comm:.6f}s" + ) + else: + logger.warning( + f"No valid latency measurements for {shard_name}, keeping default t_comm" + ) + + shard_profiles[shard_name] = profile logger.info("Collected profiles from %d shards", len(shard_profiles)) return shard_profiles, all_thunderbolts diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index cbd9610e..bb26c4d5 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -69,9 +69,28 @@ class ShardProfileRequest(BaseModel): class ShardProfileResponse(BaseModel): - """Response from device profiling and latency measurement.""" + """Response from device profiling.""" profile: Dict[str, Any] = Field(..., description="Device profile information") + + +class MeasureLatencyRequest(BaseModel): + """Request to measure latency to other devices.""" + + devices: Dict[str, DnetDeviceProperties] = Field( + ..., description="Device information mapping" + ) + thunderbolts: Dict[str, ThunderboltConnection] = Field( + default={}, description="Thunderbolt connection information from this device" + ) + payload_sizes: List[int] = Field( + default=[1024], description="Payload sizes to test for latency measurement" + ) + + +class MeasureLatencyResponse(BaseModel): + """Response from latency measurement.""" + latency: LatencyResults = Field(..., description="Latency measurement results") diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 5893748a..2ef8c1e8 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -19,7 +19,6 @@ from dnet_p2p import DnetP2P, DnetDeviceProperties -from dnet.utils.latency import calculate_median_latency_seconds from dnet.utils.serialization import tensor_to_bytes from .servicer import ShardServicer @@ -27,6 +26,8 @@ from .models import ( HealthResponse, + MeasureLatencyRequest, + MeasureLatencyResponse, ShardLoadModelRequest, ShardLoadModelResponse, ShardProfileRequest, @@ -1415,35 +1416,30 @@ async def health() -> HealthResponse: async def profile(req: ShardProfileRequest) -> ShardProfileResponse: logger.info("Received /profile request") try: - # Measure latencies - latency_results = await self._measure_latency_to_devices( - req.devices, req.thunderbolts, req.payload_sizes - ) - # Profile device using dperf device_profile = await self._profile_device( req.repo_id, req.max_batch_exp ) - # Overwrite `t_comm` with median latency (subprocess returns a dict) - median_latency = calculate_median_latency_seconds(latency_results) - if median_latency is not None: - device_profile["t_comm"] = 
float(median_latency)
-                logger.info(
-                    f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s"
-                )
-            else:
-                logger.warning(
-                    "No valid latency measurements, keeping default t_comm"
-                )
+            return ShardProfileResponse(profile=device_profile)
+        except Exception as e:
+            logger.error(f"Error in /profile endpoint: {e}")
+            raise
 
-            # Return the dict payload directly
-            return ShardProfileResponse(
-                profile=device_profile,
-                latency=latency_results,
+        @self.app.post("/measure_latency")
+        async def measure_latency(
+            req: MeasureLatencyRequest,
+        ) -> MeasureLatencyResponse:
+            logger.info("Received /measure_latency request")
+            try:
+                # Measure latencies to other devices
+                latency_results = await self._measure_latency_to_devices(
+                    req.devices, req.thunderbolts, req.payload_sizes
                 )
+
+                return MeasureLatencyResponse(latency=latency_results)
             except Exception as e:
-                logger.error(f"Error in /profile endpoint: {e}")
+                logger.error(f"Error in /measure_latency endpoint: {e}")
                 raise
 
         @self.app.post("/load_model")

From de9fdb68d84193b1063df016f4088a2faca0be8c Mon Sep 17 00:00:00 2001
From: erhant
Date: Fri, 7 Nov 2025 17:13:54 +0300
Subject: [PATCH 002/172] add type (smol commit)

---
 src/dnet/ring/api/node.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index 8e6ec365..b0f09d14 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -1002,7 +1002,8 @@ async def _collect_shard_profiles(
         async with httpx.AsyncClient() as client:
             # Step 1: Health check all shards in parallel
             logger.info("Starting health checks for all shards...")
-            health_tasks, shard_list = [], []
+            health_tasks: list[Coroutine[Any, Any, httpx.Response]] = []  # requires: from typing import Any, Coroutine
+            shard_list: list[tuple[str, DnetDeviceProperties]] = []
             for shard_name, shard_props in shards.items():
                 if shard_props.is_manager:
                     logger.warning(

From 481c76625719aaa7b9e8b1ef26842ff904b3c479 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Wed, 15 Oct 2025 18:54:49 -0700
Subject: [PATCH 003/172] repl sketch

---
 src/repl.py | 295 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 295 insertions(+)
 create mode 100644 src/repl.py

diff --git a/src/repl.py b/src/repl.py
new file mode 100644
index 00000000..87b1fe77
--- /dev/null
+++ b/src/repl.py
@@ -0,0 +1,295 @@

import os
import sys
import cmd
import json
import glob
import argparse
import subprocess
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Optional

from src.ring.api import run as run_api_node
from src.ring.shard import run as run_shard_node
from src.util import (
    ModelMetadata,
    NodeAddress,
    logger,
    get_model_metadata,
    load_api_layer_weights,
    get_safetensor_details,
    create_generate_step_for_ring_with_grpc,
)

# Handle restricted repos
from importlib import import_module
import huggingface_hub as hb
from huggingface_hub import snapshot_download, try_to_load_from_cache
try:
    hf_errors = import_module("huggingface_hub.errors")
except ModuleNotFoundError:
    hf_errors = import_module("huggingface_hub.utils")
GatedRepoError = getattr(hf_errors, "GatedRepoError", Exception)
HfHubHTTPError = getattr(hf_errors, "HfHubHTTPError", Exception)

from src.ring.api_node import RingApiNode

def dprint(msg):
    sys.stdout.write(msg)
    sys.stdout.flush()


@dataclass
class REPLState:
    model: str = "NULL"
    model_info: Optional[ModelMetadata] = None
    num_local_nodes: int = 1
    running_port = 50501
    running_httpport = 8091
    api_addr_host: str = "10.0.0.2"  # TODO: Don't hardcode
    api_addr_port: int = 0
    grpc_listen_port: int = 0
    window_size = 2  # Number of layers per node per visit (also number resident in cache)

class REPL(cmd.Cmd):

    PS1 = "dnet > "
    WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Enter '.help' for usage hints.\n\n"
    def __init__(self, model="NULL", nodes=1):
        super().__init__()
        self.state = REPLState()
        self.state.model = model

        self.state.api_addr_port = self.state.running_port
        self.state.grpc_listen_port = self.state.running_port + 1
        self.state.running_port += 2
        self.discovery = None

        # TODO: Maybe have a 'start search' 'stop search' cmds to manage discovery

        self.api = None
        #self.config_api_node()
        #self.start_api_discovery()

        assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10."
        self.state.num_local_nodes = nodes

    def loop(self):
        self.greeting()
        while True:

            #if self.state.model == "NULL":
            #    self.prompt_model()
            #    continue

            dprint(self.PS1)
            cmd = sys.stdin.readline().strip()

            if cmd == "":
                self.print_state()
            elif cmd in [".exit", "exit", "quit", "q"]:
                self.handle_terminate_signal()
            elif cmd in [".help", "help", "h"]:
                self.print_help()
            elif cmd.startswith((".model", "model", "m")):
                parts = cmd.split(" ")
                path = self._handle_model_pull(parts[1])
                if path:
                    self.state.model = path

    def greeting(self):
        sys.stdout.write(self.WELCOME)

    def print_help(self):
        def _print_hf(cmd, desc, examples=[""]):
            pcmd = " " + cmd.ljust(30, '.')
            dprint(f"{pcmd} {desc}\n")
            for e in examples:
                pex = e.rjust(len(e)+35)+"\n" if e != "" else ""
                dprint(f"{pex}")

        dprint("Command Options:\n")
        _print_hf("nodes [VALUE]", "Set the number of local worker nodes")
        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository",
                  ["Examples > model meta-llama/Meta-Llama-3-8B"])
        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource.",
                  ["Examples > limit memory 12000 (MB)",
                   "         > limit CPU_CORE_COUNT 4",
                   "         > limit GPU_SM 128"])
        _print_hf("log [LEVEL]", "Set the logging level.")
        dprint("\n  Building a topology:\n")
        _print_hf("search [ON/OFF]", "Toggle mDNS worker node search across the local network.")
        _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
        _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
        dprint("\n  Building a schedule:\n")
        _print_hf("sched create", "Automatic search for best schedule given the active topology and the loaded model.")
        _print_hf("sched assign [LAYER] [NODE]", "Assign the layer with index [LAYER] to [NODE].",
                  ["Example > sched assign 10 benny_234"])
        _print_hf("schedule assign [START-END] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
                  ["Example > sched assign 0-12 benny_234"])
        dprint("\n  Benchmarking and profiling:\n")
        _print_hf("profile [REPO]", "Estimate the total FLOPS of the model from [REPO]")
        _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
        _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
        _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
        _print_hf("bench [NODE0] [NODE1]", "Benchmark the network latency between [NODE0] and [NODE1]")
        _print_hf("bench ", "Benchmark the system using base library kernels")
        dprint("\n")

    def print_state(self):
        dprint("Network state:\n")
        dprint(f"{'Model'.ljust(20)}: {self.state.model}\n")
        dprint(f"{'Local workers'.ljust(20)}: {self.state.num_local_nodes}\n")


    # ===== Handle Model input and pull from server

    def prompt_model(self):
        while True:
            dprint("Target model > ")
            model = sys.stdin.readline().strip()
            try:
                path = self._handle_model_pull(model)
                return path
                #self.model_info = ModelMetadata()
            except Exception as e:
                dprint(f"Unable to load model {model}. Target needs to be a valid HF repository. Try again: {e}\n")

    # Read HF access token
    def _resolve_hf_token(self):
        dprint("Enter the HuggingFace access token > ")
        tok = sys.stdin.readline().strip()
        return tok

    # Require a HF access token for restricted repositories
    # Ask user for HF access token until they have a valid one
    def _handle_model_pull(self, repo_path):
        try:
            path = try_to_load_from_cache(repo_path)
            if path is None:
                dprint(f"Model {repo_path} not found in local cache\n")
                path = get_model_path(repo_path)
            self.state.model = repo_path
            return path
        except GatedRepoError:
            # GatedRepoError subclasses HfHubHTTPError, so it must be caught first
            dprint("Restricted model.\n")
            tok = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
            while True:
                tok = self._resolve_hf_token()
                try:
                    ret = snapshot_download(repo_id=repo_path, token=tok)
                    return ret
                except GatedRepoError as e:
                    print(e)
                    continue
                except Exception as e:
                    raise RuntimeError("Unknown error during HF snapshot_download") from e
        except HfHubHTTPError:
            dprint(f"Repository {repo_path} not found in Hugging Face registry.\n")
            return None
        except Exception as e:
            raise RuntimeError(f"Unable to pull model {repo_path} locally") from e

    def _parse_model_metadata(self, model_path):
        if isinstance(model_path, tuple):
            model_path = model_path[0] if model_path else None
        if model_path is None:
            raise ValueError(f"Could not resolve model path {model_path}")
        path = Path(model_path) if not isinstance(model_path, Path) else model_path

        with open(path / "config.json", "r") as f:
            config = json.load(f)

        weight_files = glob.glob(str(path / "*.safetensors"))
        weight_info: Dict[int, Dict[str, Any]] = defaultdict(dict)
        embed_tokens, lm_head, norm = {}, {}, {}
        for weight in weight_files:
            details = get_safetensor_details(weight)
            for key, val in details.items():
                if m := EMBED_TOKENS_RE.match(key):
                    embed_tokens[m.group(1)] = val
                elif m := LM_HEAD_RE.match(key):
                    lm_head[m.group(1)] = val
                elif m := NORM_RE.match(key):
                    norm[m.group(1)] = val
                elif m := LAYERS_RE.match(key):
                    layer_idx, suffix = m.groups()
                    weight_info[int(layer_idx)][suffix] = val
                else:
                    raise RuntimeError(f"Unexpected key {key}")
        num_layers = max(weight_info.keys()) + 1
        if not (set(weight_info.keys()) == set(range(num_layers))):
            raise RuntimeError("Inconsistent weights")
        return ModelMetadata(path, weight_info, embed_tokens, lm_head, norm, config)

    # ===== Handle termination signals

    def handle_terminate_signal(self):
        # Handle worker/api shutdown
        dprint("No workers to shut down. Terminating.\n")
        sys.exit()

    # ===== Handle Shard worker servers

    # TODO: Redirect output logs to different files
    def handle_start_worker(self):
        bin = os.path.join(".venv", "bin", "piped_mlx_ring_shard")
        cmd = ["uv", "run", bin]
        cmd += ["-m", self.state.model]
        cmd += ["-p", str(self.state.running_port)]
        cmd += ["-httpport", str(self.state.running_httpport)]
        cmd += ["-l", "[0]"]
        cmd += ["--prefetch-window", "2"]
        proc = subprocess.Popen(cmd)

        self.state.running_port += 1  # increment the running port
        self.state.running_httpport += 1

    # ===== Handle API server

    def handle_device_discovery(self):
        from socket import gethostname
        from secrets import token_hex
        from dnet_p2p import DnetP2P

        hostname = gethostname()
        instance = f"api-{token_hex(4)}-{hostname}"
        lib = DnetP2P("lib/dnet-p2p/lib")

        """
        self.discovery = lib.create_instance(
            instance, hostname,
            self.state.p2p_addr.host, self.state.p2p_addr_port,
            self.state.grpc_listen_port, is_manager=True
        )
        self.discovery.start()
        """

    def config_api_node(self):
        api_address = NodeAddress(self.state.api_addr_host, self.state.api_addr_port)
        self.api = RingApiNode(api_address, shard_address.format(), model_metadata)

    def start_api_discovery(self):
        if self.api:
            self.api._start_discovery()

    # Calls dsolver and optimizes topology
    async def build_topology(self):
        if self.api:
            topo = await self.api.topology()
            return topo

    # ===== Handle shutdown

    def handle_shutdown(self):
        os.system("pkill -9 -f piped_mlx")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, help="HF Repository of target model")
    parser.add_argument("--local-nodes", "-n", type=int, help="Number of local worker nodes")
    args = parser.parse_args()

    #workers = args.workers
    #model = args.model

    repl = REPL()
    repl.loop()

From 2e099c2be2aaade6a7e1cd958ef2cc2656dae320 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Thu, 16 Oct 2025 20:43:57 -0700
Subject: [PATCH 004/172] manage api server, discover and print nodes table

---
 src/repl.py | 384 ++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 299 insertions(+), 85 deletions(-)

diff --git a/src/repl.py b/src/repl.py
index 87b1fe77..0b95a19f 100644
--- a/src/repl.py
+++ b/src/repl.py
@@ -5,17 +5,22 @@
 import argparse
 import subprocess
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Optional, List, Any, Dict
+
+import asyncio
+import inspect
+import threading
+import concurrent.futures

-from src.ring.api import run as run_api_node
-from src.ring.shard import run as run_shard_node
-from src.util import (
-    ModelMetadata,
-    NodeAddress,
-    logger,
-    get_model_metadata,
+from dnet.ring.api import RingApiNode
+from dnet.ring.shard import RingShardNode
+from dnet.utils.network import NodeAddress
+from dnet.utils.logger import logger
+from dnet.utils.model import (
+    ModelMetadata,
+    get_model_metadata,
     load_api_layer_weights,
     get_safetensor_details,
-    create_generate_step_for_ring_with_grpc,
 )

 # Handle restricted repos
 from importlib import import_module
 import huggingface_hub as hb
 from huggingface_hub import snapshot_download, try_to_load_from_cache
 try:
     hf_errors = import_module("huggingface_hub.errors")
 except ModuleNotFoundError:
     hf_errors = import_module("huggingface_hub.utils")
 GatedRepoError = getattr(hf_errors, "GatedRepoError", Exception)
 HfHubHTTPError = getattr(hf_errors, "HfHubHTTPError", Exception)

-from src.ring.api_node import RingApiNode

 def dprint(msg):
     sys.stdout.write(msg)
     sys.stdout.flush()

 @dataclass
 class REPLState:
     model: str = "NULL"
     model_info: Optional[ModelMetadata] = None
     num_local_nodes: int = 1
     running_port = 50501
     running_httpport = 8091
-    api_addr_host: str = "10.0.0.2"  # TODO: Don't hardcode
-    api_addr_port: int = 0
-    grpc_listen_port: int = 0
+    api_http_port: int = 8080
+    api_grpc_port: int = 50500
     window_size = 2  # Number of layers per node per visit (also number resident in cache)

class REPL(cmd.Cmd):
    PS1 = "dnet > "
    WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Type 'help' for usage hints.\n\n"

    def __init__(self, model="NULL", nodes=1):
        assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10."

        super().__init__()
        self.state = REPLState()
        self.state.model = model
        self.state.running_port += 2
        self.state.num_local_nodes = nodes

        self._node: Optional[RingApiNode] = None
        self._api_thread: Optional[threading.Thread] = None
        self._api_ready = threading.Event()
        self._api_running = threading.Event()
        self._api_loop: Optional[asyncio.AbstractEventLoop] = None
        self._api_shutdown_e: Optional[asyncio.Event] = None
        self._api_exc: Optional[BaseException] = None

        self._api_searching = threading.Event()  # Track mDNS searching

    def loop(self):  # Main tty loop
        sys.stdout.write(self.WELCOME)
        while True:
            dprint(self.PS1)
            cmd = sys.stdin.readline().strip()

            if cmd == "":
                self.print_state()
            elif cmd in [".exit", "exit", "quit", "q"]:
                self.handle_terminate_signal()
            elif cmd in [".help", "help", "h"]:
                self.print_help()

            elif cmd.startswith(("api", ".api")):
                self.do_api(cmd.split(" "))
                continue
            elif cmd.startswith("search"):
                self.do_search(cmd.split(" "))
                continue
            elif cmd.startswith("nodes"):
                self.print_mdns_nodes()
                continue
            elif cmd.startswith(("topo", ".topo")):
                self.do_topo(cmd.split(" "))
                continue
            elif cmd.startswith((".model", "model", "m")):
                parts = cmd.split(" ")
                path = self._handle_model_pull(parts[1])
                if path:
                    self.state.model = path

    def do_api(self, cmd: List[str]) -> None:
        if len(cmd) < 2:
            dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
            return
        if cmd[1] in ["start", "run"]:
            http_port, grpc_port = None, None
            try:
                http_port = int(cmd[2])
                grpc_port = int(cmd[3])
            except (IndexError, ValueError):
                pass
            self.start_api(
                http_port or self.state.api_http_port,
                grpc_port or self.state.api_grpc_port
            )
        elif cmd[1] == "stop":
            self.stop_api()
        elif cmd[1] == "status":
            dprint("Running\n" if self._api_running.is_set() else "Stopped.\n")
        elif cmd[1] == "log":
            dprint("Log print is not yet supported.\n")
        else:
            dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
            return

    def do_search(self, cmd: List[str]) -> None:
        if len(cmd) != 2:
            dprint("mDNS search is " + ("ON\n\n" if self._api_searching.is_set() else "OFF\n\n"))
            return
        if cmd[1] == "on":
            if self._api_searching.is_set():
                return
            if not self._api_running.is_set():
                dprint("Starting API Server thread.\n")
                self.start_api()
            self.api_call("_start_discovery", timeout=10)
            self._api_searching.set()
            dprint("Starting mDNS search for worker nodes.\n")
        elif cmd[1] == "off":
            dprint("Stop discovery not yet implemented in the API node.\n")
            pass
        else:
            dprint("Invalid search command. Start searching with 'search on'.\n")
            return

    def do_topo(self, cmd: List[str]) -> None:
        if len(cmd) < 2:
            dprint("Invalid topology command. Type 'help' for a list of valid commands.\n")
            return
        if cmd[1] == "search":
            pass
        elif cmd[1] == "auto":
            pass
        elif cmd[1] == "setup":
            pass
        elif cmd[1] == "add":
            pass
        elif cmd[1] in ["remove", "rm"]:
            pass
        return

    # TODO: standardize ANSI escape codes for easy use
    def print_help(self):
        def _print_hf(cmd, desc, examples=[""]):
            pcmd = " " + cmd.ljust(30, '.')
            sys.stdout.write(f"{pcmd} {desc}\n")
            for e in examples:
                pex = e.rjust(len(e)+35)+"\n" if e != "" else ""
                sys.stdout.write(f"{pex}")

        sys.stdout.write("\033[1m\nAvailable commands:\n\033[0m")
        dprint("\033[1m\n  Common:\n\033[0m")
        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository",
                  ["Examples > model meta-llama/Meta-Llama-3-8B"])
        _print_hf("nodes list             ", "List mDNS discovered nodes.")
        _print_hf("log [LEVEL]", "Set the logging level.")
        dprint("\033[1m\n  API Server Control:\n\033[0m")
        _print_hf("api start [http_port=8080] [grpc_port=50500]", "Start the API server in a separate thread. Use provided ports if given.")
        _print_hf("api stop               ", "Signal clean shutdown of the API server.")
        _print_hf("api status             ", "Prints the status of the API server.")
        _print_hf("api log                ", "Print latest logs to the current terminal.")
        dprint("\033[1m\n  Building a topology:\n\033[0m")
        _print_hf("search                 ", "Returns the current state of mDNS search.")
        _print_hf("search [on/off]        ", "Toggle mDNS search across the local network.")
        _print_hf("nodes list             ", "List all nodes in the current topology (including local ones).")
        _print_hf("nodes all              ", "List all nodes (including local ones).")
        _print_hf("nodes                  ", "List mDNS discovered nodes.")
        _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
        _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
        sys.stdout.write("\033[1m\n  Building a schedule:\n\033[0m")
        _print_hf("sched create", "Automatic search for best schedule given the active topology and the loaded model.")
        _print_hf("sched assign [LAYER] [NODE]", "Assign the layer with index [LAYER] to [NODE].",
                  ["Example > sched assign 10 benny_234"])
        _print_hf("schedule assign [START-END] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
                  ["Example > sched assign 0-12 benny_234"])
        sys.stdout.write("\033[1m\n  Benchmarking and profiling:\n\033[0m")
        _print_hf("profile [REPO]", "Estimate the total FLOPS of the model from [REPO]")
        _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
        _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
        _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
        _print_hf("bench [NODE0] [NODE1]", "Benchmark the network latency between [NODE0] and [NODE1]")
        _print_hf("bench ", "Benchmark the system using base library kernels")
        sys.stdout.write("\033[1m\n  System control:\n\033[0m")
        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource.",
                  ["Examples > limit memory 12000 (MB)",
                   "         > limit CPU_CORE_COUNT 4",
                   "         > limit GPU_SM 128"])
        sys.stdout.write("\n")
        sys.stdout.flush()

    def print_state(self):
        dprint("Network state:\n")

    # ===== Handle termination signals

    def handle_terminate_signal(self):
        # Handle worker/api shutdown
        if self._api_running.is_set():
            self.stop_api()
        else:
            dprint("No workers to shut down. Terminating.\n")
        sys.exit()

    # ===== Handle Shard worker servers

    # ===== Handle API server

    async def _api_main(self) -> None:  # main thread loop
        self._api_loop = asyncio.get_running_loop()
        self._api_shutdown_e = asyncio.Event()
        self._node = RingApiNode(
            http_port=self.state.api_http_port,
            grpc_port=self.state.api_grpc_port
        )

        try:
            await self._node.start(shutdown_trigger=self._api_shutdown_e.wait)
            self._api_searching.set()
            self._api_running.set()
            self._api_ready.set()
            await self._api_shutdown_e.wait()
        except Exception as e:
            self._api_exc = e
            self._api_running.set()
            self._api_ready.set()
        finally:
            try:
                await self._node.shutdown()
            except Exception:
                pass
            self._api_running.clear()

    def _api_running_loop(self):
        try:
            asyncio.run(self._api_main())
        except BaseException as e:
            self._api_exc = e
            self._api_ready.set()
            self._api_running.clear()

    def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10):
        if self._api_thread and self._api_thread.is_alive(): return
        self._api_exc = None
        self._api_ready.clear()
        self._api_running.clear()
        self._api_thread = threading.Thread(target=self._api_running_loop, name="api_server", daemon=True)
        self._api_thread.start()
        if not self._api_ready.wait(timeout):
            raise RuntimeError("API Server Timeout.")
        if self._api_exc is not None:
            raise RuntimeError(f"API Server failed to start: {self._api_exc}")

    def stop_api(self, timeout: float = 5.0) -> None:
        if not self._api_thread: return
        if self._api_loop and self._api_shutdown_e:
            self._api_loop.call_soon_threadsafe(self._api_shutdown_e.set)
        if self._api_loop and self._node:
            f = asyncio.run_coroutine_threadsafe(self._node.shutdown(), self._api_loop)
            try:
                f.result(timeout=timeout)
            except Exception:
                pass
        self._api_thread.join(timeout=timeout)
        self._api_thread = None
        self._api_running.clear()
        self._api_ready.clear()

    def api_call(  # Call an API function from the REPL thread
        self,
        method: str,
        *args: Any,
        timeout: float=30.0,
        **kwargs: Any
    ) -> Any:
        if not self._api_loop or not self._node:
            raise RuntimeError("API Thread not set up correctly.")

        target = getattr(self._node, method, None)
        if target is None:
            raise AttributeError(f"RingApiNode has no method {method}")

        # method is async
        if inspect.iscoroutinefunction(target):
            coroutine = target(*args, **kwargs)
            f = asyncio.run_coroutine_threadsafe(coroutine, self._api_loop)
            return f.result(timeout)

        # method is sync
        f = concurrent.futures.Future()

        # TODO: this is a mess lol
        def runner():
            try:
                ret = target(*args, **kwargs)
                if inspect.isawaitable(ret):
                    async def _await_then_set():
                        try:
                            val = await ret
                            f.set_result(val)
                        except BaseException as e:
                            f.set_exception(e)

                    asyncio.create_task(_await_then_set())
                else:
                    f.set_result(ret)
            except BaseException as e:
                f.set_exception(e)
        self._api_loop.call_soon_threadsafe(runner)
        return f.result(timeout)

    def _print_nodes_table(self, rows: List[Any]) -> None:
        headers = ["name", "role", "addr", "http", "grpc", "status", "head"]
        limits = {"name": 36, "addr": 15}
        w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers}
        line = " ".join(h.ljust(w[h]) for h in headers)
        sys.stdout.write("\n")
        sys.stdout.write(line + "\n")
        sys.stdout.write(" ".join("-"*w[h] for h in headers))
        sys.stdout.write("\n")
        for r in rows:
            name = str(r["name"])
            addr = str(r["addr"])
            if len(name) > w["name"]: name = name[:w["name"]-1] + "..."
            if len(addr) > w["addr"]: addr = addr[:w["addr"]-1] + "..."
            vals = {
                "name": name,
                "role": r["role"],
                "addr": addr,
                "http": r["http"],
                "grpc": r["grpc"],
                "status": "yes" if r["status"] else "no",
                "head": "head" if r["head"] else "no",
            }
            sys.stdout.write(" ".join(str(vals[h]).ljust(w[h]) for h in headers))
            sys.stdout.write("\n")
        sys.stdout.write("\n")
        sys.stdout.flush()


    # Print a table of discovered nodes
    def print_mdns_nodes(self) -> None:
        try:
            shards = self.api_call("_get_shards_from_discovery", timeout=10)
            if not shards:
                dprint("No worker nodes discovered. Is the API searching?\n")
                return

            rows = []
            for name, props in shards.items():
                addr = getattr(props, "local_ip", getattr(props, "host", ""))
                http = getattr(props, "server_port", 0)
                grpc = getattr(props, "shard_port", 0)
                busy = bool(getattr(props, "is_busy", False))
                head = addr == "127.0.0.1"  # TODO: FIX
                rows.append({
                    "name": name,
                    "role": "worker",
                    "addr": addr,
                    "http": http,
                    "grpc": grpc,
                    "status": busy,
                    "head": head,
                })
            self._print_nodes_table(rows)
        except Exception as e:
            dprint(f"Could not list nodes: {e}\n")

    # ===== Handle shutdown

    def handle_shutdown(self):
        os.system("pkill -9 -f piped_mlx")

From c7e5e7b71818d4fae8d9a906b3cb115e5adce344 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sat, 18 Oct 2025 23:29:29 -0700
Subject: [PATCH 005/172] trace to file (+ trace frames)

---
 src/dnet/perf/__init__.py      |   0
 src/dnet/perf/trace.py         | 250 +++++++++++++++++++++++++
 src/dnet/ring/shard/compute.py | 330 ++++++++++++++++++---------------
 src/dnet/ring/shard/node.py    | 160 +++++++++-------
 4 files changed, 518 insertions(+), 222 deletions(-)
 create mode 100644 src/dnet/perf/__init__.py
 create mode 100644 src/dnet/perf/trace.py

diff --git a/src/dnet/perf/__init__.py b/src/dnet/perf/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
new file mode 100644
index 00000000..ad7738d8
--- /dev/null
+++ b/src/dnet/perf/trace.py
@@ -0,0 +1,250 @@
+"""
+Object-oriented tracing utilities for dnet.
+
+This module provides a Tracer class configured explicitly from the REPL (or code),
+without relying on environment variables or module-level globals. It supports:
+
+- Boundary frames via tracer.frame(scope, name, attrs)
+- Deep sys.setprofile callgraph via tracer.callgraph(...)
+- Aggregated call stats via tracer.profile_block(...)
+
+All events are written as JSON Lines to a file (TraceConfig.file), suitable
+for simple REPL visualization and easy sharing.
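+
+A minimal usage sketch (illustrative only; argument values are made up, but the
+API matches the classes below):
+
+    tracer = Tracer(TraceConfig(file="logs/dnet-trace.jsonl"))
+    tracer.start()
+    with tracer.frame("compute", "window", {"layers": 4}) as fr:
+        fr.event("bound", count=4)  # instant event nested under the frame
+    tracer.mark("token.done", token_id=42)
+    tracer.stop()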
+""" + +from __future__ import annotations + +import os +import sys +import time +import json +import threading +import contextvars +import cProfile +import pstats +import io +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, List +from contextlib import contextmanager + +from dnet.utils.logger import logger + +@dataclass +class TraceConfig: + file: str = "logs/dnet-trace.jsonl" + streaming: bool = True + include_prefixes: Tuple[str, ...] = ("src/dnet/",) + include_c_calls: bool = False + budget: int = 0 # 0 means unlimited + enabled: bool = True + record_pid_tid: bool = True + + +class Tracer: + def __init__(self, cfg: TraceConfig): + self.cfg = cfg + self._lock = threading.Lock() + self._fh: Optional[io.TextIOBase] = None + self._events: List[Dict[str, Any]] = [] + self._req_id: str = None + self._active = False + + def start(self, *, reset: bool = True) -> None: + self._active = bool(self.cfg.enabled) + if not self._active: + logger.info("Initialized tracer.") + return + if self.cfg.file: + d = os.path.dirname(self.cfg.file) or "." + os.makedirs(d, exist_ok=True) + if reset and os.path.exists(self.cfg.file): + try: + os.remove(self.cfg.file) + except Exception: + logger.warning(f"Unable to remove existing trace file {self.cfg.file}") + if self.cfg.streaming: + with self._lock: + self._fh = open(self.cfg.file, "a", encoding="utf-8") + logger.info(f"Streaming trace to {self.cfg.file}.") + + def stop(self, *, flush_events: bool = True) -> None: + if flush_events: + self.flush() + self._active = False + with self._lock: + if self._fh: + try: + self._fh.flush() + self._fh.close() + except Exception: + logger.warning(f"Unable to flush to file {self.cfg.file}") + self._fh = None + + def set_request_id(self, rid: Optional[str]) -> None: + self._req_id = rid + + def get_request_id(self) -> Optional[str]: + return self._req_id + + # Flush file to disk + def flush(self, *, clear: bool = False) -> None: + if not self._active: + return + with self._lock: + if not self.cfg.streaming and self._events: + with open(self.cfg.file, "a", encoding="utf-8") as f: + for ev in self._events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + if clear: + self._events.clear() + + def snapshot(self, path: str) -> None: + with self._lock: + with open(path, "w", encoding="utf-8") as f: + for ev in self._events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + + # Emit a new frame event + def _emit(self, ev: Dict[str, Any]) -> None: + if not self._active: + return + ev.setdefault("ts_us", time.time_ns() // 1000) + if self._req_id is not None: + ev.setdefault("req_id", self._req_id) + if self.cfg.record_pid_tid: + try: + ev.setdefault("pid", os.getpid()) + ev.setdefault("tid", threading.get_ident()) + except Exception: + logger.warning("Unable to get PID and TID for tracer frame.") + with self._lock: + if self.cfg.streaming and self._fh: + self._fh.write(json.dumps(ev, ensure_ascii=False) + "\n") + self._fh.flush() + else: + self._events.append(ev) + + # Frames + class _NoopFrame: + def __enter__(self): + return self + def __exit__(self, *a): + return False + def event(self, *a, **k): + pass + def set(self, *a, **k): + pass + + class _Frame: + __slots__ = ("t", "name", "attrs", "_t0") + def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]): + self.t = tracer + self.name = name + self.attrs = dict(attrs or {}) + self._t0 = 0.0 + def __enter__(self): + self._t0 = time.perf_counter() + self.t._emit({"type": "B", "name": self.name, "args": 
dict(self.attrs)}) + return self + def __exit__(self, ex_type, ex, tb): + dt_ms = (time.perf_counter() - self._t0) * 1000.0 + self.t._emit({"type": "E", "name": self.name, "args": {"ms": round(dt_ms, 3), "exc": bool(ex)}}) + return False + def event(self, name: str, **attrs): + self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": attrs}) + def set(self, key: str, val: Any): + self.attrs[key] = val + + def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): + if not self._active: + return Tracer._NoopFrame() + return Tracer._Frame(self, f"{scope}.{name}", attrs) + + def mark(self, name: str, **attrs: Any) -> None: + self._emit({"type": "I", "name": name, "args": attrs}) + + # Helpers + @contextmanager + def profile_block(self, outfile: Optional[str] = None, sort: str = "cumtime", limit: int = 40): + pr = cProfile.Profile() + pr.enable() + try: + yield pr + finally: + pr.disable() + s = io.StringIO() + pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sort).print_stats(limit) + out = s.getvalue() + if outfile: + d = os.path.dirname(outfile) or "." + os.makedirs(d, exist_ok=True) + with open(outfile, "w", encoding="utf-8") as f: + f.write(out) + else: + self._emit({"type": "PROFILE", "name": "cprofile", "args": {"sort": sort, "limit": limit, "report": out}}) + + @contextmanager + def callgraph( + self, + include_prefixes: Optional[Tuple[str, ...]] = None, + budget_events: Optional[int] = None, + include_c_calls: Optional[bool] = None, + apply_to_new_threads: bool = False, + ): + """ + Interpreter-level tracing (sys.setprofile) for all Python calls/returns + within the scope. Heavy overhead; best for deep debugging runs. + """ + prefixes = include_prefixes if include_prefixes is not None else self.cfg.include_prefixes + budget = (budget_events if budget_events is not None else self.cfg.budget) or 0 + inc_c = include_c_calls if include_c_calls is not None else self.cfg.include_c_calls + + emitted = 0 + stack: list[Tuple[str, float]] = [] + + def prof(frame, event, arg): + nonlocal emitted + if budget and emitted >= budget: + return + if event in ("call", "return"): + code = frame.f_code + filename = code.co_filename or "" + if prefixes and not any(filename.startswith(p) for p in prefixes): + return + name = code.co_name + key = f"{filename}:{code.co_firstlineno}:{name}" + if event == "call": + stack.append((key, time.perf_counter())) + self._emit({"type": "B", "name": f"py.{name}", "args": {"file": filename, "line": code.co_firstlineno}}) + emitted += 1 + else: + if stack and stack[-1][0] == key: + _, t0 = stack.pop() + dt_ms = (time.perf_counter() - t0) * 1000.0 + self._emit({"type": "E", "name": f"py.{name}", "args": {"ms": round(dt_ms, 3)}}) + emitted += 1 + elif inc_c and event in ("c_call", "c_return"): + func = getattr(arg, "__name__", None) + mod = getattr(arg, "__module__", None) + if not func: + return + if event == "c_call": + self._emit({"type": "B", "name": f"c.{mod}.{func}", "args": {}}) + emitted += 1 + else: + self._emit({"type": "E", "name": f"c.{mod}.{func}", "args": {}}) + emitted += 1 + + prev = sys.getprofile() + sys.setprofile(prof) + prev_thread = None + if apply_to_new_threads: + prev_thread = threading.getprofile() + threading.setprofile(prof) + try: + yield + finally: + sys.setprofile(prev) + if apply_to_new_threads: + threading.setprofile(prev_thread) + diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 2c5626cd..b9226e3b 100644 --- a/src/dnet/ring/shard/compute.py +++ 
b/src/dnet/ring/shard/compute.py @@ -79,52 +79,53 @@ def _process_activation(self, activation_msg: ActivationMessage): try: # per-nonce kvcache for concurrent requests - kv = self._get_or_make_kv(activation_msg.nonce) + with self.tracer.frame("compute.thread", "kvcache.init"): + kv = self._get_or_make_kv(activation_msg.nonce) # Get input activation from pool - input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) - if input_buffer is None: - logger.error("Failed to get input buffer %s", activation_msg.pool_id) - return + with self.tracer.frame("compute.thread", "activations.load"): + input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) + if input_buffer is None: + logger.error("Failed to get input buffer %s", activation_msg.pool_id) + return # Prepare input activation - if activation_msg.dtype == "tokens": - # tokens were staged as int32 in the pool; embed locally on start shard - input_size = int(np.prod(activation_msg.shape)) - tok_view = input_buffer[:input_size].reshape(activation_msg.shape) - # Convert robustly to MLX int32 and embed (batch=1) - toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) - x = self.model.embed(toks[None]) - if x.dtype != self._wire_mx_dtype: - x = x.astype(self._wire_mx_dtype) - else: - # Prepare input activation using MLX view operations only - input_size = int(np.prod(activation_msg.shape)) - x = input_buffer[:input_size].reshape(activation_msg.shape) - # Ensure expected dtype without re-materializing when not needed - try: - if str(x.dtype) != activation_msg.dtype: - x = x.astype(mlx_dtype_map[activation_msg.dtype]) - except Exception: - pass + with self.tracer.frame("compute.thread", "activations.process"): + if activation_msg.dtype == "tokens": + # tokens were staged as int32 in the pool; embed locally on start shard + input_size = int(np.prod(activation_msg.shape)) + tok_view = input_buffer[:input_size].reshape(activation_msg.shape) + # Convert robustly to MLX int32 and embed (batch=1) + toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) + x = self.model.embed(toks[None]) + if x.dtype != self._wire_mx_dtype: + x = x.astype(self._wire_mx_dtype) + else: + # Prepare input activation using MLX view operations only + input_size = int(np.prod(activation_msg.shape)) + x = input_buffer[:input_size].reshape(activation_msg.shape) + # Ensure expected dtype without re-materializing when not needed + try: + if str(x.dtype) != activation_msg.dtype: + x = x.astype(mlx_dtype_map[activation_msg.dtype]) + except Exception: + pass # Compute windows until boundary (stay local as long as possible) current_layer = activation_msg.layer_id + 1 last_layer = current_layer - 1 while True: - start_time = time.perf_counter() processed = 0 did_early_swap = False - # Determine contiguous local window starting at current_layer - window_layers: List[int] = [] - _tmp_layer = current_layer - while processed < self.window_size and ( - _tmp_layer in self._assigned_set - ): - window_layers.append(_tmp_layer) - _tmp_layer += 1 - processed += 1 + with self.tracer.frame("compute.thread", "weights.prepare"): + # Determine contiguous local window starting at current_layer + window_layers: List[int] = [] + _tmp_layer = current_layer + while processed < self.window_size and (_tmp_layer in self._assigned_set): + window_layers.append(_tmp_layer) + _tmp_layer += 1 + processed += 1 if self._mode == "offload" and window_layers: prep = self._prepared_by_nonce.get(activation_msg.nonce) @@ -198,10 +199,23 @@ def _process_activation(self, 
activation_msg: ActivationMessage): t_w_ms, ) - bind_ms = 0.0 - if to_bind: - # Block prefetch-touch during binding and serialize MLX ops + # Opportunistically schedule prefetch for the next window to overlap with compute try: + next_win_pre = self._next_local_layers( + (window_layers[-1] if window_layers else (activation_msg.layer_id)), + self.window_size, + ) + for nl in next_win_pre: + self._prefetch_to_ram(nl) + self._enqueue_weight_prefetch(nl) + except Exception: + pass + + # Execute the window + with self.tracer.frame("compute.thread", "execute"): + self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) + + try: # Prevent prefetch touching during encode/compute to minimize UMA pressure self._compute_busy.set() except Exception: pass @@ -248,23 +262,35 @@ def _process_activation(self, activation_msg: ActivationMessage): window_layers, (t_comp_done - t_comp) * 1000.0, ) + """ for lid in window_layers: self.weight_cache.decrease_reference(lid) - try: - # Sliding-fit delta swap: maintain a single resident set by evicting - # only what's needed to fit the next window into the budget. Prefer - # keeping the tail of the previous window so we don't thrash weights - # that are likely to be reused. - if self._mode == "sliding_fit": - if int(self._resident_windows) <= 1: - if did_early_swap: - # Early delta-swap already trimmed resident set for this window - pass - elif not self._recent_windows: - # First window in token: seed resident set - self._recent_windows.append(list(window_layers)) + with self.tracer.frame("compute.thread", "execute.evict_and_unload"): + try: + # Sliding-fit delta swap: maintain a single resident set by evicting + # only what's needed to fit the next window into the budget. Prefer + # keeping the tail of the previous window so we don't thrash weights + # that are likely to be reused. 
+ if self._mode == "sliding_fit": + if int(self._resident_windows) <= 1: + if did_early_swap: + # Early delta-swap already trimmed resident set for this window + pass + elif not self._recent_windows: + # First window in token: seed resident set + self._recent_windows.append(list(window_layers)) + else: + prev = self._recent_windows.pop(0) + self._delta_swap_eviction(window_layers, prev, activation_msg, early=False) + budget = max(1, int(self.window_size or 1)) + curr = list(window_layers) + prev_only = [x for x in prev if x not in curr] + keep_quota = max(0, budget - len(curr)) + keep_tail = prev_only[-keep_quota:] if keep_quota > 0 else [] + combined = list(keep_tail) + curr + self._recent_windows.append(combined) else: prev = self._recent_windows.pop(0) self._delta_swap_eviction( @@ -280,34 +306,19 @@ def _process_activation(self, activation_msg: ActivationMessage): combined = list(keep_tail) + curr self._recent_windows.append(combined) else: - # resident_windows>1 not expected in sliding_fit; fall back to seeding + # Original eviction policy for other modes self._recent_windows.append(list(window_layers)) - else: - # Original eviction policy for other modes - self._recent_windows.append(list(window_layers)) - if int(self._resident_windows) <= 1: - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: + if int(self._resident_windows) <= 1: + old = self._recent_windows.pop(0) try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) except Exception: pass else: @@ -341,25 +352,8 @@ def _process_activation(self, activation_msg: ActivationMessage): ) except Exception: pass - except Exception: - pass - - computation_time = time.perf_counter() - start_time - self._prof.info( - "[PROFILE][COMPUTE] node=%s nonce=%s window_layers=%s total_ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - computation_time * 1000.0, - ) - self._prof.info( - "Completed layers up to %s in %.3fs, nonce: %s, result: %s %s", - last_layer, - computation_time, - activation_msg.nonce, - x.shape, - x.dtype, - ) + except Exception: + pass # If next layer is still local, continue without staging/tx nxt = last_layer + 1 @@ -368,33 +362,64 @@ def _process_activation(self, activation_msg: ActivationMessage): continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - t_stage = time.perf_counter() - x_cast = ( - x - if x.dtype == self._wire_mx_dtype - else x.astype(self._wire_mx_dtype) - ) - try: - self._compute_busy.clear() - except Exception: - pass - - if self._profile: + with self.tracer.frame("compute.thread", "execute.enqueue_prefetch"): + x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) try: - logger.info( - "[PROFILE][STAGE-DIRECT] node=%s nonce=%s layer_tail=%s stage_ms=%.3f shape=%s dtype=%s", - self.node_id, - activation_msg.nonce, - last_layer, - (time.perf_counter() 
- t_stage) * 1000.0, - tuple(x_cast.shape), - str(self._wire_mx_dtype), - ) + self._compute_busy.clear() + except Exception: + pass + try: + for lid in list(self._prefetch_pending): + self._prefetch_pending.discard(lid) + self._enqueue_weight_prefetch(lid) except Exception: pass - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "mdns.send"): + nxt = last_layer + 1 + if nxt >= self.model_metadata.num_layers: # End of model + try: + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + is_final=True, + token_id=token_id, + ) + else: + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, + ) try: with self._mlx_lock: y = self.model.normalize(x_cast) @@ -464,41 +489,42 @@ def _process_activation(self, activation_msg: ActivationMessage): self.input_pool.release(activation_msg.pool_id) # Optional unload/evict after stage - if self._mode != "sliding_fit": - if self._defer_unload: - try: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) - except Exception: - pass + with self.tracer.frame("compute.thread", "cleanup"): + if self._mode != "sliding_fit": + if self._defer_unload: + try: + while len(self._recent_windows) > max( + 1, int(self._resident_windows) + ): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) + except Exception: + pass + if self._profile: + logger.info( + "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", + self.node_id, + activation_msg.nonce, + old, + evicted_cnt, + self._resident_windows, + ) + except Exception: + pass - if self._resident_windows <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) 
# type: ignore[attr-defined] + if self._resident_windows <= 1: + try: + evicted = self.weight_cache.evict_layers(window_layers) + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(window_layers) # type: ignore[attr-defined] if self._profile: logger.info( "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 82a60d40..e4b73a78 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -62,6 +62,8 @@ from .comms import CommsMixin from ..weight_cache import WeightCache +from dnet.perf.trace import TraceConfig, Tracer + class RingShardNode(ComputeMixin, PrefetchMixin, CommsMixin): """Single shard node in the distributed inference ring with dynamic model loading.""" @@ -200,6 +202,19 @@ def __init__( if self._profile: logger.info("[PROFILE] enabled on shard node %s", self.node_id) + # Debug tracing + cfg = TraceConfig( + file="./trace.json", + streaming=True, + include_prefixes = ("src/dnet/"), + include_c_calls = False, + budget = 10000, + enabled = True, + record_pid_tid = True, + ) + self.tracer = Tracer(cfg) + self.tracer.start() + # Per-nonce KV caches (concurrent requests) self._kv_by_nonce: Dict[str, list] = {} self._kv_last_seen: Dict[str, float] = {} @@ -218,11 +233,8 @@ def __init__( ) async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse: - """Load model with specified layers.""" - try: - start_time = time.perf_counter() - - # Check if already loaded with same configuration + """Load model with specified layers. """ + try: # Check if already loaded with same configuration if ( self.model is not None and self.model_path == req.model_path @@ -240,22 +252,22 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # If model loaded with different config, unload first - if self.model is not None and ( - self.model_path != req.model_path or self.assigned_layers != req.layers + if (self.model is not None + and (self.model_path != req.model_path or self.assigned_layers != req.layers) ): - logger.info( - "Node %s: Unloading current model to load new configuration", - self.node_id, - ) - await self.unload_model() + logger.info("Node %s: Unloading current model to load new configuration", self.node_id) + with self.tracer.frame("memory", "model.unload"): + await self.unload_model() # Load model metadata - self.model_metadata = get_model_metadata(req.model_path) + with self.tracer.frame("memory", "model.load_metadata"): + self.model_metadata = get_model_metadata(req.model_path) self.assigned_layers = req.layers self._assigned_sorted = sorted(self.assigned_layers) self._assigned_set = set(self._assigned_sorted) self.model_path = req.model_path + # Decide mode dynamically from assignment + requested window requested_w = int(max(1, int(req.window_size))) local_count = max(1, len(self.assigned_layers)) @@ -351,45 +363,49 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Initialize weight cache - self.weight_cache = WeightCache( - self.assigned_layers, - self.model_metadata, - window_size=self.window_size, - prefetch_threads=self._prefetch_threads, - resident_windows=self._resident_windows, - use_mxload_fastpath=self.config.mxload_fastpath, - prefetch_mode=self.config.prefetch_mode, - ) + with self.tracer.frame("memory", "weight_cache.init"): + self.weight_cache = WeightCache( + self.assigned_layers, + self.model_metadata, + window_size=self.window_size, + 
prefetch_threads=self._prefetch_threads, + resident_windows=self._resident_windows, + use_mxload_fastpath=self.config.mxload_fastpath, + prefetch_mode=self.config.prefetch_mode, + ) # Load the model - self.model = get_ring_model( - self.model_metadata.model_type, - self.model_metadata.model_config, - assigned_layers=self.assigned_layers, - is_api_layer=False, - ) - try: - applied = bool( - self.model.apply_quantization_from_config( # type: ignore[attr-defined] - self.model_metadata.model_config, - model_metadata=self.model_metadata, - ) - ) - logger.info( - "[QUANT] applied=%s for model=%s", - applied, + with self.tracer.frame("memory", "model.load"): + self.model = get_ring_model( self.model_metadata.model_type, + self.model_metadata.model_config, + assigned_layers=self.assigned_layers, + is_api_layer=False, ) + try: + applied = bool( + self.model.apply_quantization_from_config( # type: ignore[attr-defined] + self.model_metadata.model_config, + model_metadata=self.model_metadata, + ) + ) + logger.info( + "[QUANT] applied=%s for model=%s", + applied, + self.model_metadata.model_type, + ) - except RuntimeError as e: - logger.warning("[QUANT] apply failed: %s", e) - self.model.eval() - self.cache = make_cache( - self.model, - kv_mode=self.config.kv_cache.mode, - kv_bits=self.config.kv_cache.bits, - kv_group=self.config.kv_cache.group_size, - ) + except RuntimeError as e: + logger.warning("[QUANT] apply failed: %s", e) + self.model.eval() + + with self.tracer.frame("memory", "make_cache"): + self.cache = make_cache( + self.model, + kv_mode=self.config.kv_cache.mode, + kv_bits=self.config.kv_cache.bits, + kv_group=self.config.kv_cache.group_size, + ) try: has_start = 0 in self.assigned_layers @@ -424,7 +440,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.api_callback_address = req.api_callback_address if self.next_node: - await self._connect_next_node() + with self.tracer.frame("network", "connect.next_node"): + await self._connect_next_node() else: logger.warning("Node %s: No next node configured", self.node_id) @@ -491,7 +508,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse success=True, message="Model loaded successfully", layers_loaded=req.layers, - load_time_ms=load_time_ms, + load_time_ms=0.0, ) except Exception as e: @@ -537,23 +554,24 @@ async def unload_model(self) -> ShardUnloadModelResponse: self._assigned_set = set() # Clear memory pools - if self.weight_cache: - # Stop any in-flight prefetch and close layer manager resources - try: - self.weight_cache.cancel_all_prefetch() - except Exception: - pass - # Clear all cached weights - for layer_id in list(self._bound_versions.keys()): + with self.tracer.frame("memory", "cache.evict"): + if self.weight_cache: + # Stop any in-flight prefetch and close layer manager resources try: - self.weight_cache.evict_layer(layer_id) + self.weight_cache.cancel_all_prefetch() except Exception: pass - try: - self.weight_cache.layer_manager.close() - except Exception: - pass - self.weight_cache = None + # Clear all cached weights + for layer_id in list(self._bound_versions.keys()): + try: + self.weight_cache.evict_layer(layer_id) + except Exception: + pass + try: + self.weight_cache.layer_manager.close() + except Exception: + pass + self.weight_cache = None self.input_pool = None self.output_pool = None @@ -587,12 +605,13 @@ async def reset_cache(self) -> None: return try: - self.cache = make_cache( - self.model, # type: ignore[arg-type] - kv_mode=self.config.kv_cache.mode, 
- kv_bits=self.config.kv_cache.bits, - kv_group=self.config.kv_cache.group_size, - ) + with self.tracer.frame("memory", "cache.reset"): + self.cache = make_cache( + self.model, # type: ignore[arg-type] + kv_mode=self.config.kv_cache.mode, + kv_bits=self.config.kv_cache.bits, + kv_group=self.config.kv_cache.group_size, + ) logger.info("Node %s: Cache reset successfully", self.node_id) except Exception as e: logger.error("Node %s: Error resetting cache: %s", self.node_id, e) @@ -1094,7 +1113,8 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - self._process_activation(activation_msg) + with self.tracer.frame("compute", "forward"): + self._process_activation(activation_msg) except Empty: continue From e7bfee6b0b29c04b906d49fb4e34bb07f742811c Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 00:30:55 -0700 Subject: [PATCH 006/172] aggregate trace buffers back to api --- src/dnet/perf/__init__.py | 2 + src/dnet/perf/trace.py | 222 +++++++----- src/dnet/perf/utils/__init__.py | 2 + src/dnet/perf/utils/aggregator.py | 152 ++++++++ src/dnet/ring/api/models.py | 28 ++ src/dnet/ring/api/node.py | 34 ++ src/dnet/ring/shard/compute.py | 183 +++++----- src/dnet/ring/shard/models.py | 30 ++ src/dnet/ring/shard/startup.py | 563 ++++++++++++++++++++++++++++++ 9 files changed, 1042 insertions(+), 174 deletions(-) create mode 100644 src/dnet/perf/utils/__init__.py create mode 100644 src/dnet/perf/utils/aggregator.py create mode 100644 src/dnet/ring/shard/startup.py diff --git a/src/dnet/perf/__init__.py b/src/dnet/perf/__init__.py index e69de29b..330893e8 100644 --- a/src/dnet/perf/__init__.py +++ b/src/dnet/perf/__init__.py @@ -0,0 +1,2 @@ + +from .trace import TraceConfig, Tracer diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index ad7738d8..92119d79 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -1,32 +1,22 @@ -""" -Object-oriented tracing utilities for dnet. - -This module provides a Tracer class configured explicitly from the REPL (or code), -without relying on environment variables or module-level globals. It supports: - -- Boundary frames via tracer.frame(scope, name, attrs) -- Deep sys.setprofile callgraph via tracer.callgraph(...) -- Aggregated call stats via tracer.profile_block(...) - -All events are written as JSON Lines to a file (TraceConfig.file), suitable -for simple REPL visualization and easy sharing. 
-""" from __future__ import annotations import os +import io import sys import time import json -import threading -import contextvars -import cProfile import pstats -import io +import cProfile +import threading +import queue + from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, List from contextlib import contextmanager +import httpx + from dnet.utils.logger import logger @dataclass @@ -37,35 +27,126 @@ class TraceConfig: include_c_calls: bool = False budget: int = 0 # 0 means unlimited enabled: bool = True + node_id: Optional[str] = None record_pid_tid: bool = True + aggregate: bool = False + aggregate_url: Optional[str] = None + agg_max_events: int = 1000 + +class _NoopFrame: + def __enter__(self): + return self + def __exit__(self, *a): + return False + def event(self, *a, **k): + pass + def set(self, *a, **k): + pass +class _Frame: + __slots__ = ("t", "name", "attrs", "_t0") + def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]): + self.t = tracer + self.name = name + self.attrs = dict(attrs or {}) + self._t0 = 0.0 + def __enter__(self): + self._t0 = time.perf_counter() + self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) + return self + def __exit__(self, ex_type, ex, tb): + dt_ms = (time.perf_counter() - self._t0) * 1000.0 + self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex)}) + self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) + return False + def event(self, name: str, **attrs): + out = dict(attrs or {}) + out.setdefault("t_rel_ms", (time.perf_counter() - self._t0) * 1000.0) + self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": out}) + def set(self, key: str, val: Any): + self.attrs[key] = val class Tracer: - def __init__(self, cfg: TraceConfig): - self.cfg = cfg + def __init__(self, config: TraceConfig): + self.config = config self._lock = threading.Lock() self._fh: Optional[io.TextIOBase] = None self._events: List[Dict[str, Any]] = [] self._req_id: str = None self._active = False + self._agg_enabled: bool = False + self._agg_max_events: int = int(self.config.agg_max_events or 1000) + self._agg_q: queue.Queue[dict] = queue.Queue(maxsize=256) + self._agg_thread: Optional[threading.Thread] = None + + if self.config.aggregate: + self.start_aggregator() + + # Aggregator worker thread + def start_aggregator(self) -> None: + self._agg_enabled = True + self._agg_max_events = max(10, int(self.config.agg_max_events or 1000)) # 10 min, 1000 default + if not self._agg_thread or not self._agg_thread.is_alive(): + self._agg_thread = threading.Thread(target=self._agg_exec, name="trace-agg", daemon=True) + self._agg_thread.start() + + def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None: + self._agg_enabled = False + if flush and self._events: + try: + self._agg_q.put_nowait({ + "run_id": (self._req_id or "run"), + "node_id": (self.config.node_id or "node"), + "events": list(self._events), }) + except queue.Full: + logger.warning(f"Trace aggragator queue is full.") + self._events.clear() + if self._agg_thread and self._agg_thread.is_alive(): + self._agg_thread.join(timeout) + self._agg_thread = None + + def _agg_exec(self) -> None: + url = self.config.aggregate_url or "" + assert url != "" + client = httpx.Client(timeout=5.0) + try: + while self._agg_enabled and not self._agg_q.empty(): + try: + batch = self._agg_q.get(timeout=0.2) + except queue.Empty: + continue + try: + client.post(url, json=batch) + except Exception: + logger.warning(f"Unable to POST 
trace aggregation data to {url}")
+                finally:
+                    self._agg_q.task_done()
+        finally:
+            try:
+                client.close()
+            except Exception:
+                logger.warning("Unable to close httpx client.")
+
     def start(self, *, reset: bool = True) -> None:
-        self._active = bool(self.cfg.enabled)
+        self._active = bool(self.config.enabled)
         if not self._active:
             logger.info("Initialized tracer.")
             return
 
-        if self.cfg.file:
-            d = os.path.dirname(self.cfg.file) or "."
-            os.makedirs(d, exist_ok=True)
-            if reset and os.path.exists(self.cfg.file):
+        if self.config.file:
+            d = os.path.dirname(self.config.file) or "."
+            os.makedirs(d, exist_ok=True)
+            if reset and os.path.exists(self.config.file):
                 try:
-                    os.remove(self.cfg.file)
+                    os.remove(self.config.file)
                 except Exception:
-                    logger.warning(f"Unable to remove existing trace file {self.cfg.file}")
-        if self.cfg.streaming:
+                    logger.warning(f"Unable to remove existing trace file {self.config.file}")
+        if self.config.streaming:
             with self._lock:
-                self._fh = open(self.cfg.file, "a", encoding="utf-8")
-            logger.info(f"Streaming trace to {self.cfg.file}.")
+                self._fh = open(self.config.file, "a", encoding="utf-8")
+            logger.info(f"Streaming trace to {self.config.file}.")
+        if self.config.aggregate and self.config.aggregate_url and self.config.node_id:
+            self.start_aggregator()
 
     def stop(self, *, flush_events: bool = True) -> None:
         if flush_events:
@@ -77,89 +158,65 @@ def stop(self, *, flush_events: bool = True) -> None:
                 self._fh.flush()
                 self._fh.close()
             except Exception:
-                logger.warning(f"Unable to flush to file {self.cfg.file}")
+                logger.warning(f"Unable to flush to file {self.config.file}")
             self._fh = None
 
-    def set_request_id(self, rid: Optional[str]) -> None:
-        self._req_id = rid
-
-    def get_request_id(self) -> Optional[str]:
-        return self._req_id
-
     # Flush file to disk
     def flush(self, *, clear: bool = False) -> None:
-        if not self._active:
-            return
+        if not self._active: return
        with self._lock:
-            if not self.cfg.streaming and self._events:
-                with open(self.cfg.file, "a", encoding="utf-8") as f:
+            if not self.config.streaming and self._events:
+                with open(self.config.file, "a", encoding="utf-8") as f:
                     for ev in self._events:
                         f.write(json.dumps(ev, ensure_ascii=False) + "\n")
             if clear:
                 self._events.clear()
-
+    
+    # Quick dump of in-memory events to a file
     def snapshot(self, path: str) -> None:
         with self._lock:
            with open(path, "w", encoding="utf-8") as f:
                for ev in self._events:
                    f.write(json.dumps(ev, ensure_ascii=False) + "\n")
 
-    # Emit a new frame event
+    # Emit a new frame
     def _emit(self, ev: Dict[str, Any]) -> None:
-        if not self._active:
-            return
-        ev.setdefault("ts_us", time.time_ns() // 1000)
+        if not self._active: return
+        ev.setdefault("ts", time.perf_counter())
         if self._req_id is not None:
             ev.setdefault("req_id", self._req_id)
-        if self.cfg.record_pid_tid:
+        if self.config.record_pid_tid:
             try:
                 ev.setdefault("pid", os.getpid())
                 ev.setdefault("tid", threading.get_ident())
             except Exception:
                 logger.warning("Unable to get PID and TID for tracer frame.")
+
         with self._lock:
-            if self.cfg.streaming and self._fh:
+            if self.config.streaming and self._fh:
                 self._fh.write(json.dumps(ev, ensure_ascii=False) + "\n")
                 self._fh.flush()
             else:
                 self._events.append(ev)
 
-    # Frames
-    class _NoopFrame:
-        def __enter__(self):
-            return self
-        def __exit__(self, *a):
-            return False
-        def event(self, *a, **k):
-            pass
-        def set(self, *a, **k):
-            pass
-
-    class _Frame:
-        __slots__ = ("t", "name", "attrs", "_t0")
-        def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]):
-            self.t = tracer
-            self.name
= name
-            self.attrs = dict(attrs or {})
-            self._t0 = 0.0
-        def __enter__(self):
-            self._t0 = time.perf_counter()
-            self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)})
-            return self
-        def __exit__(self, ex_type, ex, tb):
-            dt_ms = (time.perf_counter() - self._t0) * 1000.0
-            self.t._emit({"type": "E", "name": self.name, "args": {"ms": round(dt_ms, 3), "exc": bool(ex)}})
-            return False
-        def event(self, name: str, **attrs):
-            self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": attrs})
-        def set(self, key: str, val: Any):
-            self.attrs[key] = val
+        if self._agg_enabled:
+            if len(self._events) < self._agg_max_events: return
+            batch = { "run_id": self._agg_run_id,
+                   "node_id": self._agg_node_id or self.config.node_id,
+                   "events": list(self._events)}
+            try:
+                self._agg_q.put_nowait(batch)
+            except queue.Full:
+                logger.warning(f"Aggregator queue is full. Dropping {len(batch['events'])} frames.")
+            self._events.clear()
 
+    # Frames
     def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None):
         if not self._active:
-            return Tracer._NoopFrame()
-        return Tracer._Frame(self, f"{scope}.{name}", attrs)
+            return _NoopFrame()
+        return _Frame(self, f"{scope}.{name}", attrs)
 
+    # Mark an event outside of a frame
     def mark(self, name: str, **attrs: Any) -> None:
         self._emit({"type": "I", "name": name, "args": attrs})
 
@@ -195,9 +252,9 @@ def callgraph(
         Interpreter-level tracing (sys.setprofile) for all Python calls/returns
         within the scope. Heavy overhead; best for deep debugging runs.
         """
-        prefixes = include_prefixes if include_prefixes is not None else self.cfg.include_prefixes
-        budget = (budget_events if budget_events is not None else self.cfg.budget) or 0
-        inc_c = include_c_calls if include_c_calls is not None else self.cfg.include_c_calls
+        prefixes = include_prefixes if include_prefixes is not None else self.config.include_prefixes
+        budget = (budget_events if budget_events is not None else self.config.budget) or 0
+        inc_c = include_c_calls if include_c_calls is not None else self.config.include_c_calls
 
         emitted = 0
         stack: list[Tuple[str, float]] = []
@@ -247,4 +304,3 @@ def prof(frame, event, arg):
             sys.setprofile(prev)
             if apply_to_new_threads:
                 threading.setprofile(prev_thread)
-
diff --git a/src/dnet/perf/utils/__init__.py b/src/dnet/perf/utils/__init__.py
new file mode 100644
index 00000000..9c310fd5
--- /dev/null
+++ b/src/dnet/perf/utils/__init__.py
@@ -0,0 +1,2 @@
+
+from .aggregator import TraceAggregator
diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py
new file mode 100644
index 00000000..f9bbdef6
--- /dev/null
+++ b/src/dnet/perf/utils/aggregator.py
@@ -0,0 +1,152 @@
+
+from __future__ import annotations
+
+import threading
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Tuple, Optional, DefaultDict
+from collections import defaultdict, deque
+
+from dnet.utils.logger import logger
+
+Key = Tuple[str, Optional[int], Optional[int], str]  # (node_id, pid, tid, req_id)
+
+@dataclass
+class _OpenFrame:
+    name: str
+    t0: int
+    child: int = 0
+    children: List[Dict[str, Any]] = field(default_factory=list)
+
+# Sort known frames and compute averages by key
+@dataclass
+class RunAggregator:
+    sums_by_name: Dict[str, float] = field(default_factory=dict)
+    counts_by_name: Dict[str, int] = field(default_factory=dict)
+    last_batch_seq: Dict[str, int] = field(default_factory=dict)
+
+    stacks: Dict[Key, List[_OpenFrame]] = field(default_factory=dict)
+    drops: int = 0
+    roots_by_req: DefaultDict[str,
List[Dict[str, Any]]] = field(default_factory=lambda: defaultdict(list)) + + def _key(self, node_id: str, pid: Optional[int], tid: Optional[int], req_id: Optional[str]) -> Key: + return (node_id, pid, tid, req_id or "") + + def _push(self, key: Key, f: _OpenFrame) -> None: + self.stacks.setdefault(key, []).append(f) + + def _pop(self, key: Key) -> Optional[_OpenFrame]: + st = self.stacks.get(key) + if not st: return None + return st.pop() + + def _peek(self, key: Key) -> Optional[_OpenFrame]: + st = self.stacks.get(key) + return st[-1] if st else None + + def _acc_annotate(self, name: str, self_ms: float) -> None: + self.sums_by_name[name] = self.sums_by_name.get(name, 0.0) + self_ms + self.counts_by_name[name] = self.counts_by_name.get(name, 0) + 1 + + def ingest_event(self, node_id: str, ev: Dict[str, Any]) -> None: + if not ev.get("name"): + logger.error(f"Received trace frame without name {ev.get('ts')}") + return + # Normalize ts to microseconds (accept float seconds or int microseconds) + ts_raw = ev.get("ts") + ts_us = 0 + try: + if isinstance(ts_raw, float): + ts_us = int(ts_raw * 1_000_000) + elif isinstance(ts_raw, int): + ts_us = ts_raw + else: + ts_us = int(ts_raw or 0) + except Exception: + ts_us = 0 + req_id = ev.get("req_id") or "" + key = self._key(node_id, ev.get("pid"), ev.get("tid"), req_id) + if ev.get("type") == "B": + self._push(key, _OpenFrame(name=ev.get("name"), t0=ts_us)) + elif ev.get("type") == "E": + fr = self._pop(key) + if not fr: return + dur_us = max(0, ts_us - fr.t0) + self_us = max(0, dur_us - fr.child) + self_ms = self_us / 1000.0 + self._acc_annotate(fr.name, self_ms) + parent = self._peek(key) + completed = { + "name": fr.name, + "ts": fr.t0, + "dur_ms": dur_us / 1000.0, + "self_ms": self_ms, + "children": fr.children, + "pid": ev.get("pid"), + "tid": ev.get("tid"), + "req_id": req_id, + "node_id": node_id, + } + if parent: + parent.child += dur_us + parent.children.append(completed) + else: + self.roots_by_req[req_id or ""].append(completed) + else: + # TODO :Process other events + pass + + +class TraceAggregator: + def __init__(self) -> None: + self._req: Dict[str, RunAggregator] = {} + self._lock = threading.Lock() + + def enqueue(self, batch: Dict[str, Any]) -> None: + run_id = batch.get("run_id") + node_id = batch.get("node_id") + if not run_id or not node_id: + return + events = batch.get("events") or [] + batch_seq = int(batch.get("batch_seq") or 0) + with self._lock: + agg = self._req.setdefault(run_id, RunAggregator()) + last = agg.last_batch_seq.get(node_id) + if (last is not None) and (batch_seq != last + 1): + agg.drops += abs(batch_seq - (last + 1)) + agg.last_batch_seq[node_id] = batch_seq + for ev in events: + try: + agg.ingest_event(node_id, ev) + except Exception: + continue + + def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]: + with self._lock: + agg = self._req.get(run_id) + if not agg: + return [] + if not mapping: + rows = [ + {"name": k, "self_ms": v, "total_ms": v, "count": repeats or agg.counts_by_name.get(k, 0), "max_ms": None} + for k, v in agg.sums_by_name.items() + ] + else: + sums: Dict[str, float] = {} + counts: Dict[str, int] = {} + for raw, val in agg.sums_by_name.items(): + disp = mapping.get(raw, raw) + sums[disp] = sums.get(disp, 0.0) + val + counts[disp] = counts.get(disp, 0) + agg.counts_by_name.get(raw, 0) + rows = [ + {"name": k, "self_ms": v, "total_ms": v, "count": repeats or counts.get(k, 0), "max_ms": None} + for k, v in sums.items() + ] + 
rows.sort(key=lambda r: r["self_ms"], reverse=True)
+        return rows
+
+    def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]:
+        with self._lock:
+            agg = self._req.get(run_id)
+            if not agg:
+                return []
+            return list(agg.roots_by_req.get(req_id or "", []))
diff --git a/src/dnet/ring/api/models.py b/src/dnet/ring/api/models.py
index 7199142a..fc84162f 100644
--- a/src/dnet/ring/api/models.py
+++ b/src/dnet/ring/api/models.py
@@ -8,6 +8,7 @@
 
 from ..common import LayerAssignment
 
+from dnet.perf.trace import _Frame
 
 class RoleMapping(BaseModel):
     """Role mapping for chat formatting."""
@@ -403,3 +404,30 @@ class UnloadModelResponse(BaseModel):
     message: Optional[str] = Field(
         default=None, description="Overall status or error message"
     )
+
+# Tracer ingest
+
+class TraceEvent(BaseModel):
+    type: Literal["B", "E", "I"] = Field(..., description="Event type/phase")
+    name: str = Field(..., description="Span/mark name")
+    ts: int = Field(..., description="Timestamp in microseconds")
+    args: Dict[str, Any] = Field(default_factory=dict)
+    req_id: Optional[str] = None
+    pid: Optional[int] = None
+    tid: Optional[int] = None
+
+class TraceIngestBatch(BaseModel):
+    run_id: str = Field(..., description="Bench run identifier")
+    node_id: str = Field(..., description="Shard/service identity")
+    batch_seq: int = Field(..., description="Monotonic batch sequence per node")
+    events: List[TraceEvent] = Field(default_factory=list)
+    dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
+    max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
+    last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
+    schema_version: int = Field(default=1)
+
+class TraceIngestResponse(BaseModel):
+    ok: bool = True
+    accepted: int = 0
+    batch_seq: Optional[int] = None
+    message: Optional[str] = None
diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index a9793f9d..ce355f12 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -76,6 +76,8 @@
     ShardLoadModelRequest,
     ShardLoadModelResponse,
     ShardProfileResponse,
+    TraceIngestBatch,
+    TraceIngestResponse,
 )
 from ..data_types import StopCondition
 from .servicer import ShardApiServicer
@@ -368,6 +370,34 @@ async def completions(req: CompletionRequestModel):  # type: ignore
             )
             return await self._handle_text_completion(req)
 
+        # Ingest trace buffers and forward to REPL
+        @self.app.post("/trace/ingest")
+        async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse:  # type: ignore
+            logger.debug(f"Received trace buffer.")
+            try:
+                if self._trace_ingest_cb is not None:
+                    self._trace_ingest_cb(batch.model_dump())
+                    return TraceIngestResponse(ok=True, accepted=len(batch.events), batch_seq=batch.batch_seq)
+
+                try:
+                    run_dir = Path("logs/trace/ingest") / batch.run_id
+                    run_dir.mkdir(parents=True, exist_ok=True)
+                    fpath = run_dir / f"{batch.node_id}.jsonl"
+                    with fpath.open("a", encoding="utf-8") as f:
+                        f.write(batch.model_dump_json() + "\n")
+                except Exception:
+                    logger.warning(f"Unable to write trace ingest buffer to temp file in {run_dir}")
+                return TraceIngestResponse(
+                    ok=True,
+                    accepted=len(batch.events),
+                    batch_seq=batch.batch_seq,
+                    message="no aggregator; appended"
+                )
+            except Exception as e:
+                logger.warning(f"Unable to ingest trace buffer: {e}")
+                return TraceIngestResponse(ok=False, accepted=0, message=str(e))
+
+
     async def _handle_prepare_topology(
         self, req: PrepareTopologyRequest
     ) -> TopologyInfo:
@@ -1581,3 +1611,7 @@ async def shutdown(self)
-> None: logger.warning("Discovery service was not running") logger.info("API server shutdown complete") + + # REPL helper to install a trace ingestion callback + def set_trace_ingest_callback(self, cb: Optional[Callable[[Dict[str, Any]], None]]) -> None: + self._trace_ingest_cb = cb diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index b9226e3b..d25705fe 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -90,26 +90,26 @@ def _process_activation(self, activation_msg: ActivationMessage): return # Prepare input activation - with self.tracer.frame("compute.thread", "activations.process"): - if activation_msg.dtype == "tokens": - # tokens were staged as int32 in the pool; embed locally on start shard - input_size = int(np.prod(activation_msg.shape)) - tok_view = input_buffer[:input_size].reshape(activation_msg.shape) - # Convert robustly to MLX int32 and embed (batch=1) + with self.tracer.frame("compute.thread", "activations.process") as f: + if activation_msg.dtype == "tokens": # embed locally on start shard + f.event("embed_tokens") + numel = int(np.prod(activation_msg.shape)) + tok_view = input_buffer[:numel].reshape(activation_msg.shape) toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) x = self.model.embed(toks[None]) if x.dtype != self._wire_mx_dtype: x = x.astype(self._wire_mx_dtype) - else: - # Prepare input activation using MLX view operations only - input_size = int(np.prod(activation_msg.shape)) - x = input_buffer[:input_size].reshape(activation_msg.shape) - # Ensure expected dtype without re-materializing when not needed - try: + + else: # Prepare input activation using MLX view operations only + f.set("activation_dtype", activation_msg.dtype) + numel = int(np.prod(activation_msg.shape)) + x = input_buffer[:numel].reshape(activation_msg.shape) + + try: # Ensure expected dtype if str(x.dtype) != activation_msg.dtype: x = x.astype(mlx_dtype_map[activation_msg.dtype]) except Exception: - pass + logger.warning(f"Unable to update activation dtype") # Compute windows until boundary (stay local as long as possible) current_layer = activation_msg.layer_id + 1 @@ -118,7 +118,8 @@ def _process_activation(self, activation_msg: ActivationMessage): processed = 0 did_early_swap = False - with self.tracer.frame("compute.thread", "weights.prepare"): + with self.tracer.frame("compute.thread", "weights.prepare") as f: + # Determine contiguous local window starting at current_layer window_layers: List[int] = [] _tmp_layer = current_layer @@ -127,89 +128,89 @@ def _process_activation(self, activation_msg: ActivationMessage): _tmp_layer += 1 processed += 1 - if self._mode == "offload" and window_layers: - prep = self._prepared_by_nonce.get(activation_msg.nonce) - if prep is not None: - layers, fut = prep - if layers == window_layers and fut is not None: - try: - fut.result(timeout=30) - except Exception: - pass + if self._mode == "offload" and window_layers: + prep = self._prepared_by_nonce.get(activation_msg.nonce) + if prep is not None: + layers, fut = prep + if layers == window_layers and fut is not None: + try: + fut.result(timeout=30) + except Exception: + pass - # In sliding_fit with a single resident window, proactively evict only the - # non-needed head from the current resident set before loading new weights. - # This prevents LRU from evicting the useful tail during materialization. 
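# A minimal sketch of the head-eviction rule the comments above describe,
# assuming `resident` and `needed` are plain lists of layer ids. The helper
# name is illustrative only; the actual code calls _delta_swap_eviction below.
def _head_eviction_candidates(resident: list[int], needed: list[int]) -> list[int]:
    needed_set = set(needed)
    # Evict only resident layers that fall outside the upcoming window; layers
    # shared by consecutive windows stay materialized, so LRU pressure cannot
    # reclaim the still-useful tail while new weights load.
    return [lid for lid in resident if lid not in needed_set]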
- if ( - self._mode == "sliding_fit" - and int(self._resident_windows) <= 1 - and window_layers - ): - try: - resident = [] + # In sliding_fit with a single resident window, proactively evict only the + # non-needed head from the current resident set before loading new weights. + # This prevents LRU from evicting the useful tail during materialization. + if ( + self._mode == "sliding_fit" + and int(self._resident_windows) <= 1 + and window_layers + ): try: - resident = self.weight_cache.get_resident_layers() # type: ignore[union-attr] - except Exception: resident = [] - evicted_cnt = self._delta_swap_eviction( - window_layers, resident, activation_msg, early=True - ) - if evicted_cnt > 0: - did_early_swap = True - except Exception: - pass + try: + resident = self.weight_cache.get_resident_layers() # type: ignore[union-attr] + except Exception: + resident = [] + evicted_cnt = self._delta_swap_eviction( + window_layers, resident, activation_msg, early=True + ) + if evicted_cnt > 0: + did_early_swap = True + except Exception: + pass - # Ensure weights for the window are resident and bind only if arrays changed - # if model fits and we've already bound these layers, skip the scan entirely. - fast_fit = ( - self._mode == "fit" - and len(self._assigned_sorted) <= self.window_size - ) - skip_scan = fast_fit and all( - (wl in self._bound_versions) for wl in window_layers - ) - to_bind: Dict[str, mx.array] = {} - if not skip_scan: - t_w_ready = time.perf_counter() - for wl in window_layers: - weights = self.weight_cache.get_weight(wl) - if weights is None: - logger.error("Failed to load weights for layer %s", wl) - self.input_pool.release(activation_msg.pool_id) - return + # Ensure weights for the window are resident and bind only if arrays changed + # if model fits and we've already bound these layers, skip the scan entirely. 
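# A minimal sketch of the "skip the scan" fast path described above, assuming
# `weights` maps parameter names to mx.array objects. The helper name is
# hypothetical; the hunk below inlines the same logic instead of calling it.
def _needs_rebind(bound_versions: dict, layer_id: int, weights: dict) -> bool:
    try:
        # The identity of the first array serves as a cheap version fingerprint:
        # it changes whenever the weight cache rematerializes the layer.
        version = id(next(iter(weights.values())))
    except StopIteration:
        version = -1
    return bound_versions.get(layer_id) != version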
+ fast_fit = ( + self._mode == "fit" + and len(self._assigned_sorted) <= self.window_size + ) + skip_scan = fast_fit and all( + (wl in self._bound_versions) for wl in window_layers + ) + to_bind: Dict[str, mx.array] = {} + if not skip_scan: + t_w_ready = time.perf_counter() + for wl in window_layers: + weights = self.weight_cache.get_weight(wl) + if weights is None: + logger.error("Failed to load weights for layer %s", wl) + self.input_pool.release(activation_msg.pool_id) + return + try: + # Use identity of first array as a cheap version/fingerprint + first_arr = next(iter(weights.values())) + version = id(first_arr) + except StopIteration: + version = -1 + if self._bound_versions.get(wl) != version: + for k, v in weights.items(): + to_bind[k] = v + self._bound_versions[wl] = version + if self._profile: + t_w_ms = (time.perf_counter() - t_w_ready) * 1000.0 + # Only log when non-trivial or binding happened to reduce overhead/noise + if to_bind or t_w_ms > 0.5: + logger.info( + "[PROFILE][WAIT-WEIGHTS] node=%s nonce=%s layers=%s ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + t_w_ms, + ) + + # Opportunistically schedule prefetch for the next window to overlap with compute try: - # Use identity of first array as a cheap version/fingerprint - first_arr = next(iter(weights.values())) - version = id(first_arr) - except StopIteration: - version = -1 - if self._bound_versions.get(wl) != version: - for k, v in weights.items(): - to_bind[k] = v - self._bound_versions[wl] = version - if self._profile: - t_w_ms = (time.perf_counter() - t_w_ready) * 1000.0 - # Only log when non-trivial or binding happened to reduce overhead/noise - if to_bind or t_w_ms > 0.5: - logger.info( - "[PROFILE][WAIT-WEIGHTS] node=%s nonce=%s layers=%s ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - t_w_ms, + next_win_pre = self._next_local_layers( + (window_layers[-1] if window_layers else (activation_msg.layer_id)), + self.window_size, ) - - # Opportunistically schedule prefetch for the next window to overlap with compute - try: - next_win_pre = self._next_local_layers( - (window_layers[-1] if window_layers else (activation_msg.layer_id)), - self.window_size, - ) - for nl in next_win_pre: - self._prefetch_to_ram(nl) - self._enqueue_weight_prefetch(nl) - except Exception: - pass + for nl in next_win_pre: + self._prefetch_to_ram(nl) + self._enqueue_weight_prefetch(nl) + except Exception: + pass # Execute the window with self.tracer.frame("compute.thread", "execute"): diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index cbd9610e..5e339abd 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -90,3 +90,33 @@ class HealthResponse(BaseModel): grpc_port: int = Field(..., description="gRPC server port") http_port: int = Field(..., description="HTTP server port") instance: Optional[str] = Field(default=None, description="Shard name") + + +# Tracer ingest + +class TraceEvent(BaseModel): + type: Literal["B", "E", "I"] = Field(..., description="Event type/phase") + name: str = Field(..., description="Span/mark name") + ts: int = Field(..., description="Timestamp in microseconds") + args: Dict[str, Any] = Field(default_factory=dict) + req_id: Optional[str] = None + pid: Optional[int] = None + tid: Optional[int] = None + + +class TraceIngestBatch(BaseModel): + run_id: str = Field(..., description="Bench run identifier") + node_id: str = Field(..., description="Shard/service identity") + batch_seq: int = Field(..., description="Monotonic batch 
sequence per node") + events: List[TraceEvent] = Field(default_factory=list) + dropped: Optional[int] = Field(default=0, description="Events dropped on sender") + max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch") + last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run") + schema_version: int = Field(default=1) + + +class TraceIngestResponse(BaseModel): + ok: bool = True + accepted: int = 0 + batch_seq: Optional[int] = None + message: Optional[str] = None diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py new file mode 100644 index 00000000..ccfb9f44 --- /dev/null +++ b/src/dnet/ring/shard/startup.py @@ -0,0 +1,563 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Dict, List, Mapping +import threading +from socket import gethostname +from secrets import token_hex + +import mlx.core as mx +from fastapi import Request +from fastapi.responses import JSONResponse +from grpc import aio as aio_grpc + +from hypercorn import Config +import hypercorn.asyncio as aio_hypercorn +from dnet_p2p.thunderbolt import ThunderboltConnection +from dnet_p2p import ( + DnetDeviceProperties, + discover_thunderbolt_connection, +) + +from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server +from .servicer import ShardServicer +from ...utils.logger import logger +from ...utils.serialization import tensor_to_bytes +from ...utils.latency import ( + DeviceLatencyResult, + LatencyMeasurement, + LatencyResults, + calculate_median_latency_seconds, +) +from .models import ( + HealthResponse, + ShardLoadModelRequest, + ShardLoadModelResponse, + ShardProfileRequest, + ShardProfileResponse, + ShardUnloadModelResponse, +) +from ...protos import dnet_ring_pb2 + + +class StartupMixin: + async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): + self.running = True + try: # Capture the main event loop for cross-thread scheduling + self._loop = asyncio.get_running_loop() + except Exception: + self._loop = None + await self._start_grpc_server() + await self._start_http_server(shutdown_trigger) + await asyncio.sleep(0.2) + + self.background_tasks = [ + asyncio.create_task(self._ingress_worker()), + asyncio.create_task(self._prefetch_worker()), + asyncio.create_task(self._send_worker()), + ] + # Start idle sweeper to close silent streams + try: + if getattr(self, "_streaming_enabled", False) and hasattr( + self, "_stream_sweeper" + ): + self.background_tasks.append( + asyncio.create_task(self._stream_sweeper()) + ) + except Exception: + pass + + self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) + self.compute_thread.start() + + self._start_discovery() + logger.info( + "Shard node %s started on gRPC port %s HTTP port %s", + self.node_id, + self.grpc_port, + self.http_port, + ) + + def _start_discovery(self) -> None: + """Start mDNS discovery service.""" + hostname = gethostname() + # TODO: optionally take shard name from CLI + instance = f"shard-{token_hex(4)}-{hostname}" + self.discovery.create_instance( + instance, + hostname, + "0.0.0.0", # Binds to all addresses + self.http_port, # HTTP port + self.grpc_port, # gRPC port + is_manager=False, # Shard is never a manager + ) + self.discovery.start() + logger.info( + "Discovery service started for shard node %s with name %s", + self.node_id, + self.discovery.fullname(), + ) + + async def _start_grpc_server(self) -> None: + """Start gRPC server.""" + self.server = aio_grpc.server() + + # 
Add the ring servicer; shard acts as client for ShardApiService (to API) + servicer = ShardServicer(self) # type: ignore # FIXME: !!! + add_DnetRingServiceServicer_to_server(servicer, self.server) + + listen_addr = f"[::]:{self.grpc_port}" + self.server.add_insecure_port(listen_addr) + await self.server.start() + logger.info( + "Shard node %s gRPC server started on %s", self.node_id, listen_addr + ) + try: + await asyncio.get_running_loop().run_in_executor( + self.executor, self._warmup_serialization + ) + logger.info("Warmup serialization completed") + except Exception as e: + logger.warning("Warmup serialization failed: %s", e) + + def _warmup_serialization(self): + try: + dummy = mx.random.normal((1024, 1024), dtype=mx.float32) + dummy16 = dummy.astype(self._wire_mx_dtype) + _ = tensor_to_bytes(dummy16) + except Exception: + pass + + def _warmup_shard(self): + logger.info( + "[WARMUP] Starting shard warmup with window size %s", self.window_size + ) + batch_size, seq_len = 1, 1 + hidden_size = self.model_metadata.model_config.get("hidden_size", 2560) + x = mx.zeros((batch_size, seq_len, hidden_size), dtype=mx.bfloat16) + start_time = time.perf_counter() + try: + default_n = max(1, int(getattr(self, "_resident_windows", 1))) + except Exception: + default_n = 1 + try: + max_windows = max( + 1, + int( + getattr(self, "config", None).warmup_windows + if getattr(self, "config", None) + else default_n + ), + ) + except Exception: + max_windows = default_n + windows: list[list[int]] = [] + for window_start in range(0, len(self._assigned_sorted), self.window_size): + window_end = min( + window_start + self.window_size, len(self._assigned_sorted) + ) + windows.append(self._assigned_sorted[window_start:window_end]) + for wi, window_layers in enumerate(windows[:max_windows]): + weights_to_bind = {} + for layer_id in window_layers: + weights = self.weight_cache.get_weight(layer_id) + if weights: + for k, v in weights.items(): + weights_to_bind[k] = v + if weights_to_bind: + self.model.load_weights(list(weights_to_bind.items()), strict=False) + try: + for layer_id in window_layers: + x = self.model.apply_single_layer(layer_id, x, cache=None) + _s = mx.sum(x) + mx.eval(_s) + except Exception: + pass + try: + for lid in window_layers: + self.weight_cache.decrease_reference(lid) + except Exception: + pass + if not self._warmup_keep_flag: + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(window_layers) # type: ignore[attr-defined] + except Exception: + pass + try: + self.weight_cache.evict_layers(window_layers) + except Exception: + pass + total_time = (time.perf_counter() - start_time) * 1000 + self._warmup_completed = True + logger.info( + "[WARMUP] Shard warmup completed in %.2fms; windows=%s kept=%s", + total_time, + min(len(windows), max_windows), + int(self._warmup_keep_flag), + ) + + async def _start_http_server(self, shutdown_trigger: Any) -> None: + """Start HTTP server. 
+ + Args: + shutdown_trigger: Shutdown trigger function + """ + await self._setup_routes() + + # Start HTTP server in background + config = Config.from_mapping( + bind=f"0.0.0.0:{self.http_port}", + log_level="info", + log_config=None, + use_reloader=False, + h2c=False, + ) + + # Start the server as a background task + self.http_server = asyncio.create_task( + aio_hypercorn.serve(self.app, config, shutdown_trigger=shutdown_trigger) # type: ignore + ) + logger.info( + "Shard node %s HTTP server started on port %s", self.node_id, self.http_port + ) + + async def _setup_routes(self) -> None: + """Setup HTTP routes.""" + + @self.app.get("/health") + async def health() -> HealthResponse: + try: + instance = self.discovery.instance_name() + except Exception: + instance = None + return HealthResponse( + status="ok", + node_id=self.node_id, + running=self.running, + model_loaded=self._check_model_loaded(), + model_path=self.model_path, + assigned_layers=self.assigned_layers, + queue_size=self.activation_recv_queue.qsize(), + grpc_port=self.grpc_port, + http_port=self.http_port, + instance=instance, + ) + + @self.app.post("/profile") + async def profile(req: ShardProfileRequest) -> ShardProfileResponse: + logger.info("Received /profile request") + try: + # Measure latencies + latency_results = await self._measure_latency_to_devices( + req.devices, req.thunderbolts, req.payload_sizes + ) + + # Profile device using dperf + device_profile = await self._profile_device( + req.repo_id, req.max_batch_exp + ) + + # Overwrite `t_comm` with median latency (subprocess returns a dict) + median_latency = calculate_median_latency_seconds(latency_results) + if median_latency is not None: + device_profile["t_comm"] = float(median_latency) + logger.info( + f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" + ) + else: + logger.warning( + "No valid latency measurements, keeping default t_comm" + ) + + # Return the dict payload directly + return ShardProfileResponse( + profile=device_profile, + latency=latency_results, + ) + except Exception as e: + logger.error(f"Error in /profile endpoint: {e}") + raise + + @self.app.post("/load_model") + async def load_model_endpoint( + req: ShardLoadModelRequest, + ) -> ShardLoadModelResponse: + """Load model with specified layers.""" + try: + logger.info( + f"HTTP /load_model: model={req.model_path}, layers={req.layers}, " + f"next_node={req.next_node or 'none'}, window_size={req.window_size}, " + f"total_layers={req.total_layers}, api_callback={req.api_callback_address or 'none'}" + ) + result = await self.load_model(req) + return result + + except Exception as e: + logger.error(f"Error in /load_model endpoint: {e}") + return ShardLoadModelResponse( + success=False, + message=f"Error: {str(e)}", + layers_loaded=[], + load_time_ms=0.0, + ) + + @self.app.post("/unload_model") + async def unload_model_endpoint() -> ShardUnloadModelResponse: + """Unload current model.""" + try: + logger.info("HTTP /unload_model") + result = await self.unload_model() + return result + + except Exception as e: + logger.error(f"Error in /unload_model endpoint: {e}") + return ShardUnloadModelResponse( + success=False, + message=f"Error: {str(e)}", + ) + + @self.app.post("/warm") + async def warm(request: Request) -> JSONResponse: + try: + body = await request.json() + start = int(body.get("start", -1)) + window = int(body.get("window", self.window_size)) + if start < 0: + return JSONResponse( + status_code=400, content={"error": "missing/invalid start"} + ) + start_idx = 0 + for i, lyr in 
enumerate(self._assigned_sorted): + if lyr >= start: + start_idx = i + break + else: + return JSONResponse(content={"prefetched": []}) + window_layers = self._assigned_sorted[ + start_idx : start_idx + max(1, window) + ] + for wl in window_layers: + self._prefetch_to_ram(wl) + self._enqueue_weight_prefetch(wl) + return JSONResponse(content={"prefetched": window_layers}) + except Exception as e: + logger.error("/warm failed: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: + """Profile device using dperf in a subprocess and return a dict. + + Args: + repo_id: Hugging Face repository ID + max_batch_exp: Maximum batch size exponent (2^max_batch_exp) + + Returns: + Device profile information as a plain dict + """ + from ...utils.profile_subproc import profile_device_via_subprocess + + profile_dict = profile_device_via_subprocess( + repo_id, max_batch_exp=max_batch_exp, debug=0 + ) + logger.info("Device profiling completed for node %s", self.node_id) + return profile_dict + + async def _connect_next_node(self) -> bool: + """Connect to next node in ring. + + Returns: + True if connected or no next node, False on failure + """ + if not self.next_node: + logger.info(f"Shard node {self.node_id} is the final shard (no next node)") + return True + + if self.next_node_channel: + logger.debug(f"Shard node {self.node_id} already connected to next node.") + return True + + try: + # use thunderbolt here if available + this_properties = self.discovery.get_own_properties() + thunderbolt_conn = discover_thunderbolt_connection( + this_properties, + self.next_node, + ) + next_ip = ( + thunderbolt_conn.ip_addr + if thunderbolt_conn + else self.next_node.local_ip + ) + address = f"{next_ip}:{self.next_node.shard_port}" + logger.info( + f"Shard node {this_properties.instance} connecting to next node {self.next_node.instance} at {address}" + ) + + self.next_node_channel = aio_grpc.insecure_channel(address) + from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub + + self.next_node_stub = DnetRingServiceStub(self.next_node_channel) + return True + except Exception as e: + logger.warning( + f"Shard node {self.node_id} failed to connect to next node {address}: {e}" + ) + self.next_node_channel = None + self.next_node_stub = None + return False + + async def _reconnect_next_node(self) -> bool: + try: + if self.next_node_channel: + await self.next_node_channel.close() + except Exception: + pass + self.next_node_channel = None + self.next_node_stub = None + return await self._connect_next_node() + + async def _health_check(self): + try: + health_request = dnet_ring_pb2.HealthRequest(requester_id=str(self.node_id)) + response = await self.next_node_stub.HealthCheck(health_request) # type: ignore + logger.info( + "Shard node %s successfully pinged: %s, healthy: %s", + self.node_id, + response.node_id, + response.healthy, + ) + return True + except Exception as e: + logger.warning( + "Shard node %s failed to ping next node %s: %s", + self.node_id, + self.next_node_address, + e, + ) + return False + + async def _measure_latency_to_devices( + self, + devices: Mapping[str, DnetDeviceProperties], + thunderbolts: Mapping[str, ThunderboltConnection], + payload_sizes: List[int], + ) -> LatencyResults: + """Measure latency to all devices except self. 
+ + Args: + devices: Device information mapping + thunderbolts: Thunderbolt connection information + payload_sizes: List of payload sizes to test + + Returns: + Latency measurement results + """ + latency_results_dict: Dict[str, DeviceLatencyResult] = {} + + for service_name, device_info in devices.items(): + # Skip measuring latency to ourselves + if service_name.startswith(self.discovery.instance_name()): + logger.debug("Skipping latency measurement to self: %s", service_name) + continue + + # Skip measuring latency to API (manager) devices + if device_info.is_manager: + logger.debug( + "Skipping latency measurement to manager/API: %s", service_name + ) + continue + + try: + shard_port = device_info.shard_port + + # Check for Thunderbolt connection + if service_name in thunderbolts: + tb_data = thunderbolts[service_name] + service_ip = tb_data.ip_addr + logger.info( + "Using Thunderbolt for %s at %s, connected to instance %s", + service_name, + service_ip, + tb_data.instance, + ) + else: + # No Thunderbolt, use WiFi + service_ip = device_info.local_ip + + if not shard_port or not service_ip: + logger.warning( + "No shard_port or local_ip for device %s", service_name + ) + continue + + # Connect to target shard's gRPC server + target_address = f"{service_ip}:{shard_port}" + channel = aio_grpc.insecure_channel(target_address) + from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub + + stub = DnetRingServiceStub(channel) + + # Measure latency for each payload size + latency_measurements: List[LatencyMeasurement] = [] + for payload_size in payload_sizes: + # Create dummy payload + dummy_data = b"x" * payload_size + + start_time = time.perf_counter() + timestamp_ms = int(time.time() * 1000) + + request = dnet_ring_pb2.LatencyMeasureRequest( + requester_id=str(self.node_id), + payload_size=payload_size, + dummy_data=dummy_data, + timestamp=timestamp_ms, + ) + + response = await stub.MeasureLatency(request) # type: ignore + end_time = time.perf_counter() + + if response.success: + latency_ms = (end_time - start_time) * 1000 + latency_measurements.append( + LatencyMeasurement( + payload_size=payload_size, + latency_ms=round(latency_ms, 2), + success=True, + error=None, + ) + ) + else: + latency_measurements.append( + LatencyMeasurement( + payload_size=payload_size, + success=False, + error=response.message, + latency_ms=0, + ) + ) + + # Store results + result = DeviceLatencyResult( + target_node_id=response.node_id if response.success else None, + measurements=latency_measurements, + success=True, + error=None, + ) + latency_results_dict[service_name] = result + + # Close channel + await channel.close() + + except Exception as e: + logger.error("Error measuring latency to %s: %s", service_name, e) + result = DeviceLatencyResult( + target_node_id=None, + success=False, + error=str(e), + measurements=[], + ) + latency_results_dict[service_name] = result + + return LatencyResults(results=latency_results_dict) From 01ff68734606431a12dc060479550ea75da693d6 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 03:16:43 -0700 Subject: [PATCH 007/172] Receive correct data format, dump to temp log file without REPL registered callback --- src/dnet/perf/trace.py | 27 ++++++++++++++++++++------- src/dnet/ring/api/models.py | 13 ++++++------- src/dnet/ring/api/node.py | 7 ++++--- src/dnet/ring/shard/models.py | 17 +++++++---------- src/dnet/ring/shard/node.py | 8 +++++--- src/dnet/ring/shard/startup.py | 19 ++++++++----------- 6 files changed, 50 insertions(+), 41 deletions(-) diff --git 
a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
index 92119d79..78a92a65 100644
--- a/src/dnet/perf/trace.py
+++ b/src/dnet/perf/trace.py
@@ -31,7 +31,7 @@ class TraceConfig:
     record_pid_tid: bool = True
     aggregate: bool = False
     aggregate_url: Optional[str] = None
-    agg_max_events: int = 1000
+    agg_max_events: int = 300
 
 class _NoopFrame:
     def __enter__(self):
@@ -107,17 +107,21 @@ def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None:
         self._agg_thread = None
 
     def _agg_exec(self) -> None:
-        url = self.config.aggregate_url or ""
-        assert url != ""
+        assert self.config.aggregate_url != ""
+        url = "http://" + self.config.aggregate_url + "/trace/ingest"
         client = httpx.Client(timeout=5.0)
         try:
-            while self._agg_enabled and not self._agg_q.empty():
+            logger.debug(f"Aggregation worker thread {self._agg_enabled}, {self._agg_q.empty()}")
+            while self._agg_enabled or not self._agg_q.empty():
                 try:
                     batch = self._agg_q.get(timeout=0.2)
                 except queue.Empty:
                     continue
+                logger.info(f"Sending trace buffer to API: {url}")
                 try:
-                    client.post(url, json=batch)
+                    res = client.post(url, json=batch)
+                    if res.status_code != 200:
+                        logger.error(f"Aggregator POST failed {res.status_code}: {res.text}")
                 except Exception:
                     logger.warning(f"Unable to POST trace aggregation data to {url}")
                 finally:
@@ -128,6 +132,14 @@ def _agg_exec(self) -> None:
             except Exception:
                 logger.warning("Unable to close httpx client.")
 
+    # We don't have the API addr at init time
+    def update_api_addr(self, addr):
+        self.config.aggregate_url = addr
+        logger.debug(f"Updated API Address: {self.config.aggregate_url}")
+
+    def update_config(self, config):
+        pass
+
     def start(self, *, reset: bool = True) -> None:
         self._active = bool(self.config.enabled)
         if not self._active:
@@ -201,8 +213,9 @@ def _emit(self, ev: Dict[str, Any]) -> None:
 
         if self._agg_enabled:
             if len(self._events) < self._agg_max_events: return
-            batch = { "run_id": self._agg_run_id,
-                   "node_id": self._agg_node_id or self.config.node_id,
+            logger.debug(f"Queuing tracer frame batch of {len(self._events)}")
+            batch = { "run_id": (self._req_id or "NONE"),
+                   "node_id": (self.config.node_id or "UNKNOWN_NODE"),
                  "events": list(self._events)}
             try:
                 self._agg_q.put_nowait(batch)
             except queue.Full:
diff --git a/src/dnet/ring/api/models.py b/src/dnet/ring/api/models.py
index fc84162f..c9f4b45b 100644
--- a/src/dnet/ring/api/models.py
+++ b/src/dnet/ring/api/models.py
@@ -408,9 +408,9 @@ class UnloadModelResponse(BaseModel):
 # Tracer ingest
 
 class TraceEvent(BaseModel):
-    type: Literal["B", "E", "I"] = Field(..., description="Event type/phase")
+    type: str = Field(..., description="Event type/phase")
     name: str = Field(..., description="Span/mark name")
-    ts: int = Field(..., description="Timestamp in microseconds")
+    ts: float = Field(..., description="Timestamp in seconds (perf_counter)")
     args: Dict[str, Any] = Field(default_factory=dict)
     req_id: Optional[str] = None
     pid: Optional[int] = None
@@ -419,12 +419,11 @@ class TraceEvent(BaseModel):
 class TraceIngestBatch(BaseModel):
     run_id: str = Field(..., description="Bench run identifier")
     node_id: str = Field(..., description="Shard/service identity")
-    batch_seq: int = Field(..., description="Monotonic batch sequence per node")
     events: List[TraceEvent] = Field(default_factory=list)
-    dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
-    max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
-    last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
-    schema_version: int
= Field(default=1) + #dropped: Optional[int] = Field(default=0, description="Events dropped on sender") + #max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch") + #last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run") + #schema_version: int = Field(default=1) class TraceIngestResponse(BaseModel): ok: bool = True diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index ce355f12..2ef371d5 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -373,14 +373,15 @@ async def completions(req: CompletionRequestModel): # type: ignore # Ingest trace buffers and forward to REPL @self.app.post("/trace/ingest") async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: ignore - logger.debug(f"Received trace buffer.") try: if self._trace_ingest_cb is not None: + logger.debug("Forwarding trace batch to REPL.") self._trace_ingest_cb(batch.model_dump()) - return TraceIngestResponse(ok=True, accepted=len(batch.events), batch_seq=batch.batch_seq) + return TraceIngestResponse(ok=True, accepted=len(batch.events)) try: run_dir = Path("logs/trace/ingest") / batch.run_id + logger.debug(f"callback not available. Dumping trace buffer to file {run_dir}") run_dir.mkdir(parents=True, exist_ok=True) fpath = run_dir / f"{batch.node_id}.jsonl" with fpath.open("a", encoding="utf-8") as f: @@ -390,7 +391,6 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: return TraceIngestResponse( ok=True, accepted=len(batch.events), - batch_seq=batch.batch_seq, message="no aggregator; appended" ) except Exception as e: @@ -1057,6 +1057,7 @@ async def _collect_shard_profiles( payload_sizes=payload_sizes, max_batch_exp=max_batch_exp, devices=shards, + api_address=f"{this_device.local_ip}:{this_device.server_port}" ).model_dump(), timeout=1000.0, ) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index 5e339abd..e8b46c68 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,6 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" + api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) @@ -95,28 +96,24 @@ class HealthResponse(BaseModel): # Tracer ingest class TraceEvent(BaseModel): - type: Literal["B", "E", "I"] = Field(..., description="Event type/phase") + type: str = Field(..., description="Event type/phase") name: str = Field(..., description="Span/mark name") - ts: int = Field(..., description="Timestamp in microseconds") + ts: float = Field(..., description="Timestamp in microseconds") args: Dict[str, Any] = Field(default_factory=dict) req_id: Optional[str] = None pid: Optional[int] = None tid: Optional[int] = None - class TraceIngestBatch(BaseModel): run_id: str = Field(..., description="Bench run identifier") node_id: str = Field(..., description="Shard/service identity") - batch_seq: int = Field(..., description="Monotonic batch sequence per node") events: List[TraceEvent] = Field(default_factory=list) - dropped: Optional[int] = Field(default=0, description="Events dropped on sender") - max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch") - last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run") - schema_version: int = Field(default=1) - + #dropped: Optional[int] = 
Field(default=0, description="Events dropped on sender") + #max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch") + #last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run") + #schema_version: int = Field(default=1) class TraceIngestResponse(BaseModel): ok: bool = True accepted: int = 0 - batch_seq: Optional[int] = None message: Optional[str] = None diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index e4b73a78..2f9d6a7a 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -62,7 +62,7 @@ from .comms import CommsMixin from ..weight_cache import WeightCache -from dnet.perf.trace import TraceConfig, Tracer +from dnet.perf import TraceConfig, Tracer class RingShardNode(ComputeMixin, PrefetchMixin, CommsMixin): @@ -205,12 +205,14 @@ def __init__( # Debug tracing cfg = TraceConfig( file="./trace.json", - streaming=True, + streaming=False, include_prefixes = ("src/dnet/"), include_c_calls = False, budget = 10000, enabled = True, record_pid_tid = True, + aggregate=False, + aggregate_url=None, # FIXME: This is set when we get a /profile req ) self.tracer = Tracer(cfg) self.tracer.start() @@ -233,7 +235,7 @@ def __init__( ) async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse: - """Load model with specified layers. """ + """Load model with specified layers""" try: # Check if already loaded with same configuration if ( self.model is not None diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py index ccfb9f44..7aaf8e73 100644 --- a/src/dnet/ring/shard/startup.py +++ b/src/dnet/ring/shard/startup.py @@ -249,15 +249,14 @@ async def health() -> HealthResponse: async def profile(req: ShardProfileRequest) -> ShardProfileResponse: logger.info("Received /profile request") try: - # Measure latencies - latency_results = await self._measure_latency_to_devices( - req.devices, req.thunderbolts, req.payload_sizes - ) + # Since this is the first request we get from API grab the address and store it + # TODO: Have a handshake request before this one where we share addresses and state + self.api_address = req.api_address + self.tracer.update_api_addr(self.api_address) + self.tracer.start_aggregator() - # Profile device using dperf - device_profile = await self._profile_device( - req.repo_id, req.max_batch_exp - ) + latency_results = await self._measure_latency_to_devices( req.devices, req.thunderbolts, req.payload_sizes) + device_profile = await self._profile_device( req.repo_id, req.max_batch_exp) # Overwrite `t_comm` with median latency (subprocess returns a dict) median_latency = calculate_median_latency_seconds(latency_results) @@ -267,9 +266,7 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse: f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" ) else: - logger.warning( - "No valid latency measurements, keeping default t_comm" - ) + logger.warning( "No valid latency measurements, keeping default t_comm") # Return the dict payload directly return ShardProfileResponse( From 5ae6ed2e3655cdd4ea7d4061acf01cd63f8ddfd5 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 04:08:11 -0700 Subject: [PATCH 008/172] register repl callback --- src/dnet/ring/api/node.py | 1 + src/repl.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 2ef371d5..5bdf4416 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -1615,4 +1615,5 
@@ async def shutdown(self) -> None: # REPL helper to install a trace ingestion callback def set_trace_ingest_callback(self, cb: Optional[Callable[[Dict[str, Any]], None]]) -> None: + logger.debug(f"Registered tracer ingest callback.") self._trace_ingest_cb = cb diff --git a/src/repl.py b/src/repl.py index 0b95a19f..f6483c05 100644 --- a/src/repl.py +++ b/src/repl.py @@ -120,6 +120,7 @@ def do_api(self, cmd: List[str]) -> None: http_port or self.state.api_http_port, grpc_port or self.state.api_grpc_port ) + self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0) elif cmd[1] == "stop": self.stop_api() elif cmd[1] == "status": @@ -331,6 +332,11 @@ def handle_start_worker(self): # ===== Handle API server + # Tracer frames ingest callback + def __trace_cb(self, data): + dprint(str(data)) + pass + async def _api_main(self) -> None: # main thread loop self._api_loop = asyncio.get_running_loop() self._api_shutdown_e = asyncio.Event() @@ -374,7 +380,7 @@ def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10): if not self._api_ready.wait(timeout): raise RuntimeError("API Server Timeout.") if self._api_exc is not None: - raise RuntimeError(f"API Server failed to start: {e}") + raise RuntimeError(f"API Server failed to start") def stop_api(self, timeout: float = 5.0) -> None: if not self._api_thread: return @@ -459,6 +465,7 @@ def _print_nodes_table(self, rows: List[Any]) -> None: "head": "head" if r["head"] else "no", } sys.stdout.write(" ".join(str(vals[h]).ljust(w[h]) for h in headers)) + sys.stdout.write("\n") sys.stdout.write("\n\n") sys.stdout.flush() From b145c881207b285100a3f15d9eb43d8c90909f19 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 14:09:09 -0700 Subject: [PATCH 009/172] add llama3 model script --- src/dnet/ring/model/__init__.py | 5 + src/dnet/ring/model/llama3.py | 181 ++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 src/dnet/ring/model/llama3.py diff --git a/src/dnet/ring/model/__init__.py b/src/dnet/ring/model/__init__.py index 4f432a7e..57ccdc32 100644 --- a/src/dnet/ring/model/__init__.py +++ b/src/dnet/ring/model/__init__.py @@ -8,6 +8,11 @@ from .llama import LlamaRingModel from .gpt_oss import GptOssRingModel from .qwen3 import Qwen3RingModel +from .llama3 import Llama3RingModel +from .llama4 import Llama4RingModel +from .gpt_oss import GptOssRingModel +from .glm import GLMRingModel +from .glm4 import GLM4RingModel def get_ring_model( diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py new file mode 100644 index 00000000..0583ac9c --- /dev/null +++ b/src/dnet/ring/model/llama3.py @@ -0,0 +1,181 @@ +from typing import Any, Dict, List, Optional, Tuple + +import mlx.nn as nn +import mlx.core as mx +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.llama import ModelArgs, TransformerBlock + +from .base import BaseRingModel + +import logging +logger = logging.getLogger(__name__) + + +class Llama3RingModel(BaseRingModel): + model_type = "llama" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = [], + is_api_layer: bool = False + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError(f"API Service doesn't handle layers") + + self.config = ModelArgs.from_dict(model_config) + self.is_api_layer = is_api_layer + + self._converted_to_quantized = False + self.runtime_cache: Optional[List[Any]] = None + + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + 
self.norm = nn.RMSNorm(self.config.hidden_size, self.config.rms_norm_eps) + self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + self.layers: List[nn.Module] = [] + self.abs_to_local: Dict[int, int] = {} + + for i, l in enumerate(sorted(assigned_layers or [])): + self.layers.append(TransformerBlock(self.config)) + self.abs_to_local[l] = i + logger.debug(f"Created {len(self.layers)} Transformer layers") + logger.debug(f"abs_to_local mapping: {self.abs_to_local}") + + @property + def decoding_layers(self): + return self.layers + + @property + def head_dim(self) -> Tuple[int, int]: + return self.config.head_dim + + @property + def n_kv_heads(self) -> int: + return self.config.num_key_value_heads + + @property + def num_layers(self) -> int: + return len(self.layers) + + def set_runtime_cache(self, cache: Optional[List[Any]]) -> None: + self._runtime_cache = cache + + def class_predicate(p, m): + return hasattr(m, "to_quantized") + + def embed(self, x: mx.array): + return self.embed_tokens(x) if self.is_api_layer else x + + def normalize(self, x: mx.array): + return self.norm(x) if self.is_api_layer else x + + def lm_project(self, x: mx.array): + return self.lm_head(x) if self.is_api_layer else x + + def quantize_layers(self): + self.quantization = None + if hasattr(self.config, "quantization"): + self.quantization = getattr(self.config, "quantization") + elif hasattr(self.config, "quantization_config"): + self.quantization = getattr(self.config, "quantization_config") + + if self.quantization is not None: + bits = int(self.quantization.get("bits", 8)) + group = int(self.quantization.get("group_size", 64)) + if self.is_api_layer: + try: + from mlx.nn.layers.quantized import QuantizedEmbedding + self.embed_tokens = QuantizedEmbedding(self.config.vocab_size, + self.config.hidden_size, + group_size=group, bits=bits) + + logger.debug(f"API Service initialized to QuantizedEmbedding:" + f"{self.config.vocab_size}, hidden={self.config.hidden_size}" + f"group_size={group}, bits={bits}") + except Exception as e: + logger.warning(f"Unable to initialize QuantizedEmbedding: {e}") + + else: + try: + nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate) + logger.debug(f"Quantized the model: bits={bits}, group_size={group}") + self._converted_to_quantized = True + except: + self._converted_to_quantized = False + + def forward( + self, + x: mx.array, + cache: Optional[List[Any]] = None + ): + mask = create_attention_mask(x, cache) + if cache is None: + cache = [None] * len(self.layers) + + for i, l in enumerate(self.layers): + x = l(x, mask, cache[i] if i < len(cache) else None) + + return x + + # TODO: Original implementation is slidin window. 
Bench to see if it's faster or just do sparse + def apply_single_layer( + self, + layer_idx: int, + x: mx.array, + cache: Optional[List[Any]] = None + ): + if layer_idx not in self.abs_to_local: + raise RuntimeError(f"Attempted execution of foreign layer {layer_idx}") + mask = create_attention_mask(x, cache) + local_idx = self.abs_to_local[layer_idx] + logger.debug(f"apply_single_layer: layer:{layer_idx}, local_idx:{local_idx}, input_shape:{x.shape}") + + layer = self.layers[local_idx] + ret = self.layers[local_idx](x, mask, cache[local_idx] if local_idx < len(cache) else None) + logger.debug(f"Executed layer:{layer_idx} with output shape: {ret.shape}") + return ret + + def load_weights(self, weights, strict=False): + weight_keys = [k for k, _ in weights] + has_scales = any(".scales" in k for k in weight_keys) + has_biases = any(".biases" in k for k in weight_keys) + + if has_scales and has_biases: + if not self._converted_to_quantized: + self.quantize_layers() + + shard_weights = {} + for k, v in weights: + if k.startswith("model.layers.") or k.startswith("layers."): + p = k.split(".") + idx_pos = 2 if p[0] == "model" else 1 + try: + idx = int(p[idx_pos]) + except Exception as e: + logger.warning(f"Unable to read weight positions: {e}") + continue + if idx not in self.abs_to_local: + continue + local_idx = self.abs_to_local[idx] + p[idx_pos] = str(local_idx) + if p[0] == "model": + p = p[1:] + new_key = ".".join(p) + logger.debug(f"Mapping weight {k} -> {new_key}") + shard_weights[new_key] = v + + elif self.is_api_layer: + if (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")): + shard_weights[k] = v + + if shard_weights: + try: + super().load_weights(list(shard_weights.items()), strict=strict) + logger.debug(f"Loaded {len(shard_weights.keys())} weights into model") + except Exception as e: + logger.error(f"Failed to load weights: {e}") + logger.error(f"Weight keys: {list(shard_weights.keys())}") + raise From 6e6cfd35fe0eb9e59450323c4804e739f1c45eb8 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 14:10:32 -0700 Subject: [PATCH 010/172] comment out unavailable models --- src/dnet/ring/model/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dnet/ring/model/__init__.py b/src/dnet/ring/model/__init__.py index 57ccdc32..93d31dd8 100644 --- a/src/dnet/ring/model/__init__.py +++ b/src/dnet/ring/model/__init__.py @@ -9,10 +9,10 @@ from .gpt_oss import GptOssRingModel from .qwen3 import Qwen3RingModel from .llama3 import Llama3RingModel -from .llama4 import Llama4RingModel -from .gpt_oss import GptOssRingModel -from .glm import GLMRingModel -from .glm4 import GLM4RingModel +#from .llama4 import Llama4RingModel +#from .gpt_oss import GptOssRingModel +#from .glm import GLMRingModel +#from .glm4 import GLM4RingModel def get_ring_model( From e4312e29d9281755bbf1d31ec2cd7af3b03bc741 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 15:53:07 -0700 Subject: [PATCH 011/172] embed correctly --- src/dnet/ring/model/llama3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index 0583ac9c..cdec5669 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -67,7 +67,7 @@ def class_predicate(p, m): return hasattr(m, "to_quantized") def embed(self, x: mx.array): - return self.embed_tokens(x) if self.is_api_layer else x + return self.embed_tokens(x) def normalize(self, x: mx.array): return self.norm(x) if 
self.is_api_layer else x From e7a8ee8644e0749d80ecb68f2772073f2b5ead03 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 16:09:09 -0700 Subject: [PATCH 012/172] don't filter weights based on is_api_layer --- src/dnet/ring/model/llama3.py | 45 ++++++++++++++++------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index cdec5669..231f02b5 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -70,10 +70,10 @@ def embed(self, x: mx.array): return self.embed_tokens(x) def normalize(self, x: mx.array): - return self.norm(x) if self.is_api_layer else x + return self.norm(x) def lm_project(self, x: mx.array): - return self.lm_head(x) if self.is_api_layer else x + return self.lm_head(x) def quantize_layers(self): self.quantization = None @@ -85,26 +85,24 @@ def quantize_layers(self): if self.quantization is not None: bits = int(self.quantization.get("bits", 8)) group = int(self.quantization.get("group_size", 64)) - if self.is_api_layer: - try: - from mlx.nn.layers.quantized import QuantizedEmbedding - self.embed_tokens = QuantizedEmbedding(self.config.vocab_size, - self.config.hidden_size, - group_size=group, bits=bits) - - logger.debug(f"API Service initialized to QuantizedEmbedding:" - f"{self.config.vocab_size}, hidden={self.config.hidden_size}" - f"group_size={group}, bits={bits}") - except Exception as e: - logger.warning(f"Unable to initialize QuantizedEmbedding: {e}") + try: + from mlx.nn.layers.quantized import QuantizedEmbedding + self.embed_tokens = QuantizedEmbedding(self.config.vocab_size, + self.config.hidden_size, + group_size=group, bits=bits) + + logger.debug(f"API Service initialized to QuantizedEmbedding:" + f"{self.config.vocab_size}, hidden={self.config.hidden_size}" + f"group_size={group}, bits={bits}") + except Exception as e: + logger.warning(f"Unable to initialize QuantizedEmbedding: {e}") - else: - try: - nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate) - logger.debug(f"Quantized the model: bits={bits}, group_size={group}") - self._converted_to_quantized = True - except: - self._converted_to_quantized = False + try: + nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate) + logger.debug(f"Quantized the model: bits={bits}, group_size={group}") + self._converted_to_quantized = True + except: + self._converted_to_quantized = False def forward( self, @@ -167,9 +165,8 @@ def load_weights(self, weights, strict=False): logger.debug(f"Mapping weight {k} -> {new_key}") shard_weights[new_key] = v - elif self.is_api_layer: - if (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")): - shard_weights[k] = v + elif (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")): + shard_weights[k] = v if shard_weights: try: From 9103284af36c7365493b5715ec7bdafe6052f213 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 03:23:41 -0700 Subject: [PATCH 013/172] tracer config request, aggregate and separate api logger for less repl noise --- src/dnet/perf/trace.py | 9 +- src/dnet/perf/utils/__init__.py | 3 +- src/dnet/perf/utils/aggregator.py | 1 + src/dnet/ring/api/api_logging.py | 51 +++++++++ src/dnet/ring/api/node.py | 58 ++++++++++- src/dnet/ring/api/servicer.py | 3 +- src/dnet/ring/api/utils.py | 3 +- src/dnet/ring/shard/models.py | 20 +++- src/dnet/ring/shard/startup.py | 37 +++++-- src/repl.py | 167 
++++++++++++++++++++++++++----
 10 files changed, 310 insertions(+), 42 deletions(-)
 create mode 100644 src/dnet/ring/api/api_logging.py

diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
index 78a92a65..e4a92cfb 100644
--- a/src/dnet/perf/trace.py
+++ b/src/dnet/perf/trace.py
@@ -108,7 +108,6 @@ def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None:
 
     def _agg_exec(self) -> None:
         assert self.config.aggregate_url != ""
-        url = "http://" + self.config.aggregate_url + "/trace/ingest"
         client = httpx.Client(timeout=5.0)
         try:
             logger.debug(f"Aggregation worker thread {self._agg_enabled}, {self._agg_q.empty()}")
@@ -117,13 +116,13 @@ def _agg_exec(self) -> None:
                     batch = self._agg_q.get(timeout=0.2)
                 except queue.Empty:
                     continue
-                logger.info(f"Sending trace buffer to API : {url}")
+                logger.info(f"Sending trace buffer to API : {self.config.aggregate_url}")
                 try:
-                    res = client.post(url, json=batch)
+                    res = client.post(self.config.aggregate_url, json=batch)
                     if res.status_code != 200:
                         logger.error(f"Aggregator POST failed {res.status_code}: {res.text}")
-                except Exception:
-                    logger.warning(f"Unable to POST trace aggregation data to {url}")
+                except Exception as e:
+                    logger.warning(f"Unable to POST trace aggregation data to {self.config.aggregate_url}: {e}")
                 finally:
                     self._agg_q.task_done()
     finally:
diff --git a/src/dnet/perf/utils/__init__.py b/src/dnet/perf/utils/__init__.py
index 9c310fd5..7228627d 100644
--- a/src/dnet/perf/utils/__init__.py
+++ b/src/dnet/perf/utils/__init__.py
@@ -1,2 +1 @@
-
-from aggregator import TraceAggregator
+from .aggregator import TraceAggregator
diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py
index f9bbdef6..2736a25b 100644
--- a/src/dnet/perf/utils/aggregator.py
+++ b/src/dnet/perf/utils/aggregator.py
@@ -104,6 +104,7 @@ def __init__(self) -> None:
     def enqueue(self, batch: Dict[str, Any]) -> None:
         run_id = batch.get("run_id")
         node_id = batch.get("node_id")
+        logger.debug(f"Enqueuing trace buffer from {run_id}, {node_id}")
        if not run_id or not node_id:
            return
        events = batch.get("events") or []
diff --git a/src/dnet/ring/api/api_logging.py b/src/dnet/ring/api/api_logging.py
new file mode 100644
index 00000000..d90c526c
--- /dev/null
+++ b/src/dnet/ring/api/api_logging.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import os
+import logging
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+
+
+_CONFIGURED_FLAG = "_dnet_api_logger_configured"
+
+
+def get_api_logger() -> logging.Logger:
+    """Return a process-local logger for the API server.
+
+    - Does not propagate to the root logger (so it won't spam the REPL TTY).
+    - Writes to logs/api.log with rotation.
+    - Level is controlled by DNET_API_LOG (default: INFO). 
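+    - Safe to call repeatedly; handlers are attached only once per process.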
+ """ + log = logging.getLogger("dnet.api") + if getattr(log, _CONFIGURED_FLAG, False): + return log + + # Level from env, fallback INFO + level_name = (os.getenv("DNET_API_LOG", "INFO") or "INFO").strip().upper() + level = getattr(logging, level_name, logging.INFO) + log.setLevel(level) + + # Do not bubble to root (console) + log.propagate = False + + # Ensure logs directory + try: + Path("logs").mkdir(parents=True, exist_ok=True) + except Exception: + pass + + # Attach a rotating file handler + try: + fh = RotatingFileHandler("logs/api.log", maxBytes=10_000_000, backupCount=5) + fmt = logging.Formatter( + "%(asctime)s %(levelname)s [%(threadName)s] %(name)s: %(message)s" + ) + fh.setFormatter(fmt) + log.addHandler(fh) + except Exception: + # As a last resort, attach a NullHandler to avoid 'No handler' warnings + log.addHandler(logging.NullHandler()) + + setattr(log, _CONFIGURED_FLAG, True) + return log + diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 5bdf4416..08bc1cbd 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -39,6 +39,7 @@ add_ShardApiServiceServicer_to_server, ) +from .api_logging import get_api_logger from ...utils.logger import logger from ...utils.banner import print_startup_banner from ...utils.model import ( @@ -78,6 +79,8 @@ ShardProfileResponse, TraceIngestBatch, TraceIngestResponse, + TraceConfigRequest, + TraceConfigResponse, ) from ..data_types import StopCondition from .servicer import ShardApiServicer @@ -101,6 +104,9 @@ async def azip(*async_iterables): break +logger = get_api_logger() + + class RingApiNode: """API node for distributed inference ring with dynamic topology.""" @@ -155,6 +161,24 @@ async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: shutdown_trigger: Shutdown trigger function """ self.running = True + # Reduce third‑party library noise in this process (keeps REPL TTY clean) + try: + import logging as _logging + for name in ( + "grpc", + "grpc._cython", + "asyncio", + "hpack", + "h2", + "hypercorn", + "hypercorn.error", + "hypercorn.access", + ): + lg = _logging.getLogger(name) + lg.setLevel(_logging.CRITICAL) + lg.propagate = False + except Exception: + pass await self._start_grpc_server() await self._start_discovery() @@ -203,7 +227,7 @@ async def _start_http_server(self, shutdown_trigger: Any) -> None: config = Config.from_mapping( bind=f"0.0.0.0:{self.http_port}", - log_level="info", + log_level="error", # keep HTTP server quiet on console log_config=None, use_reloader=False, h2c=True, @@ -397,6 +421,38 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: logger.warning(f"Unable to ingest trace buffer: {e}") return TraceIngestResponse(ok=False, accepted=0, message=str(e)) + async def _forward_trace_config(self, cfg: Any) -> bool: + shards = self._get_shards_from_discovery() + this = self.discovery.get_own_properties() + api_endpoint = f"http://{this.local_ip}:{this.server_port}/trace/ingest" + payload = TraceConfigRequest( + file=cfg.file, + streaming=cfg.streaming, + include_prefixes=cfg.include_prefixes, + include_c_calls=cfg.include_c_calls, + budget=cfg.budget, + enabled=cfg.enabled, + node_id=None, + record_pid_tid=cfg.record_pid_tid, + aggregate=cfg.aggregate, + aggregate_url=api_endpoint, + agg_max_events=cfg.agg_max_events + ) + + async with httpx.AsyncClient(timeout=5.0) as client: + for name, props in shards.items(): + url = f"http://{props.local_ip}:{props.server_port}/trace" + logger.debug(f"Forwarding trace config to {url}") + 
payload.node_id = name
+                try:
+                    res = await client.post(url, json=payload.model_dump())
+                    if res.status_code != 200:
+                        logger.warning(f"Failed to POST tracer config to node {name}.")
+                except Exception as e:
+                    logger.warning(f"Failed to POST tracer config: {e}")
+                    return False
+        return True
+
     async def _handle_prepare_topology(
         self, req: PrepareTopologyRequest
diff --git a/src/dnet/ring/api/servicer.py b/src/dnet/ring/api/servicer.py
index 31b53d35..a1024822 100644
--- a/src/dnet/ring/api/servicer.py
+++ b/src/dnet/ring/api/servicer.py
@@ -6,7 +6,8 @@
 from ...protos import shard_api_comm_pb2 as pb2
 from ...protos import shard_api_comm_pb2_grpc as pb2_grpc
-from ...utils.logger import logger
+from .api_logging import get_api_logger
+logger = get_api_logger()
 
 if TYPE_CHECKING:
     pass
diff --git a/src/dnet/ring/api/utils.py b/src/dnet/ring/api/utils.py
index ad0fee36..536e16fb 100644
--- a/src/dnet/ring/api/utils.py
+++ b/src/dnet/ring/api/utils.py
@@ -16,7 +16,8 @@
 from .models import ChatParams
 
-
+from .api_logging import get_api_logger
+logger = get_api_logger()
 
 def create_generate_step_for_ring_with_grpc(
     stub: DnetRingServiceStub,
diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py
index e8b46c68..f54d7954 100644
--- a/src/dnet/ring/shard/models.py
+++ b/src/dnet/ring/shard/models.py
@@ -1,6 +1,6 @@
 """Shard models for dnet ring topology endpoints."""
 
-from typing import Any, Dict, List, Optional, Literal
+from typing import Any, Dict, List, Optional, Literal, Tuple
 from pydantic import BaseModel, Field
 from dnet_p2p import DnetDeviceProperties, ThunderboltConnection
@@ -93,8 +93,24 @@ class HealthResponse(BaseModel):
     instance: Optional[str] = Field(default=None, description="Shard name")
 
 
-# Tracer ingest
+# Tracer
+class TraceConfigRequest(BaseModel):
+    file: str = Field(..., description="File for trace streaming")
+    streaming: bool = Field(..., description="Toggle for trace streaming to file")
+    include_prefixes: List[str] = Field(default=["src/dnet/"], description="Only trace code under these path prefixes")
+    include_c_calls: bool = Field(default=False, description="Also trace calls into C extensions")
+    budget: int = Field(default=0, description="Max amount of traces in memory")
+    enabled: bool = Field(default=False, description="Start or stop tracing")
+    node_id: Optional[str] = Field(default="NONE", description="Shard identity stamped on trace batches")
+    record_pid_tid: bool = Field(default=True, description="Record process/thread IDs on events")
+    aggregate: bool = Field(default=True, description="Forward trace batches to the API aggregator")
+    aggregate_url: Optional[str] = Field(default=None, description="Ingest endpoint for aggregated batches")
+    agg_max_events: int = Field(default=300, description="Threshold for sending frames to API")
+
+class TraceConfigResponse(BaseModel):
+    ok: bool = True
+
 class TraceEvent(BaseModel):
     type: str = Field(..., description="Event type/phase")
     name: str = Field(..., description="Span/mark name")
diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py
index 7aaf8e73..6790be64 100644
--- a/src/dnet/ring/shard/startup.py
+++ b/src/dnet/ring/shard/startup.py
@@ -20,6 +20,8 @@
     discover_thunderbolt_connection,
 )
 
+from dnet.perf.trace import TraceConfig
+
 from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server
 from .servicer import ShardServicer
 from ...utils.logger import logger
@@ -37,6 +39,8 @@
     ShardProfileRequest,
     ShardProfileResponse,
     ShardUnloadModelResponse,
+    TraceConfigRequest,
+    TraceConfigResponse,
 )
 from ...protos import dnet_ring_pb2
 
@@ -247,14 +251,7 @@ async def health() -> HealthResponse:
 
         @self.app.post("/profile")
         async def profile(req: ShardProfileRequest) -> ShardProfileResponse:
-            
logger.info("Received /profile request") try: - # Since this is the first request we get from API grab the address and store it - # TODO: Have a handshake request before this one where we share addresses and state - self.api_address = req.api_address - self.tracer.update_api_addr(self.api_address) - self.tracer.start_aggregator() - latency_results = await self._measure_latency_to_devices( req.devices, req.thunderbolts, req.payload_sizes) device_profile = await self._profile_device( req.repo_id, req.max_batch_exp) @@ -277,6 +274,32 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse: logger.error(f"Error in /profile endpoint: {e}") raise + @self.app.post("/trace") + async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: + try: + cfg = TraceConfig( + file=req.file, + streaming=req.streaming, + include_prefixes=req.include_prefixes, + include_c_calls=req.include_c_calls, + budget=req.budget, + enabled=req.enabled, + node_id=req.node_id, + record_pid_tid=req.record_pid_tid, + aggregate=req.aggregate, + aggregate_url=req.aggregate_url, + agg_max_events=req.agg_max_events, + ) + self.tracer.config = cfg + logger.info("Updated tracer config.") + self.api_address = cfg.aggregate_url + self.tracer.start_aggregator() + logger.debug(cfg) + return TraceConfigResponse(ok=True) + except Exception as e: + logger.error(f"Unable to setup tracing on shard: {e}") + return TraceConfigResponse(ok=False) + @self.app.post("/load_model") async def load_model_endpoint( req: ShardLoadModelRequest, diff --git a/src/repl.py b/src/repl.py index f6483c05..592424a6 100644 --- a/src/repl.py +++ b/src/repl.py @@ -1,11 +1,13 @@ import os import sys +import logging import cmd +import time import argparse import subprocess from dataclasses import dataclass -from typing import Optional, List, Any +from typing import Optional, List, Any, Dict import asyncio import inspect @@ -15,7 +17,8 @@ from dnet.ring.api import RingApiNode from dnet.ring.shard import RingShardNode from dnet.utils.network import NodeAddress -from dnet.utils.logger import logger +#from dnet.utils.logger import logger +from dnet.ring.api.api_logging import get_api_logger from dnet.utils.model import ( ModelMetadata, get_model_metadata, @@ -23,6 +26,12 @@ get_safetensor_details, ) +logger = get_api_logger() + +from dnet.perf.trace import TraceConfig, Tracer +from dnet.perf.utils import TraceAggregator +#from dnet.perf.bench import + # Handle restricted repos from importlib import import_module import huggingface_hub as hb @@ -57,22 +66,39 @@ class REPL(cmd.Cmd): def __init__(self, model="NULL", nodes=1): assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10." 
-        super().__init__()
+
+        # State
         self.state = REPLState()
         self.state.model = model
         self.state.running_port += 2
         self.state.num_local_nodes = nodes
 
+        # API Thread
         self._node: Optional[RingApiNode] = None
         self._api_thread: Optional[threading.Thread] = None
         self._api_ready = threading.Event()
         self._api_running = threading.Event()
+        self._api_searching = threading.Event() # Track mDNS searching
         self._api_loop: Optional[asyncio.AbstractEventLoop] = None
         self._api_shutdown_e: Optional[asyncio.Event] = None
         self._api_exc: Optional[BaseException] = None
-        self._api_searching = threading.Event() # Track mDNS searching
 
+        # Tracing
+        self._trace_cfg = TraceConfig(
+            enabled=False,
+            streaming=False,
+            budget=3000,
+            aggregate=True,
+            agg_max_events=50,
+        )
+
+        self._tracing = threading.Event()
+        self._tracer = None
+        self._trace_file = f"trace-{time.strftime('%Y%m%d-%H%M%S')}"
+        self._trace_cursor = 0 # keep track of registered events in buffer
+        self._trace_agg = TraceAggregator()
+
     def loop(self): # Main tty loop
         sys.stdout.write(self.WELCOME)
@@ -82,7 +108,7 @@ def loop(self): # Main tty loop
 
             if cmd == "":
                 self.print_state()
-            elif cmd in [".exit", "exit", "quit", "q"]:
+            elif cmd in [".exit", "exit", "quit"]:
                 self.handle_terminate_signal()
             elif cmd in [".help", "help", "h"]:
                 self.print_help()
@@ -96,10 +122,13 @@ def loop(self): # Main tty loop
             elif cmd.startswith("nodes"):
                 self.print_mdns_nodes()
                 continue
+            elif cmd.startswith(("trace", ".trace")):
+                self.do_trace(cmd.split(" "))
+                continue
             elif cmd.startswith(("topo", ".topo")):
                 self.do_topo(cmd.split(" "))
                 continue
-            elif cmd.startswith((".model", "model", "m")):
+            elif cmd.startswith((".model", "model", "m ")):
                 cmd.split(" ")
                 path = self._handle_model_pull(cmd[1])
                 if path:
@@ -121,6 +150,7 @@ def do_api(self, cmd: List[str]) -> None:
                 grpc_port or self.state.api_grpc_port
             )
             self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0)
+
         elif cmd[1] == "stop":
             self.stop_api()
         elif cmd[1] == "status":
@@ -182,12 +212,12 @@ def _print_hf(cmd, desc, examples=[""]):
         ["Examples > model meta-llama/Meta-Llama-3-8B"])
         _print_hf("nodes list ", "List mDNS discovered nodes.")
         _print_hf("log [LEVEL]", "Set the logging level.")
-        dprint("\033[1m\n API Server Control:\n\033[0m")
+        dprint("\033[1m\n Controlling the API Server:\n\033[0m")
         _print_hf("api start [http_port=8080] [grpc_port=50500]", "Start the API server in a separate thread. 
Use provided ports if given.")
         _print_hf("api stop ", "Signal clean shutdown of the API server.")
         _print_hf("api status ", "Prints the status of the API server.")
         _print_hf("api log ", "Print latest logs to the current terminal.")
-        dprint("\033[1m\n Building a topology:\n\033[0m")
+        dprint("\033[1m\n Topology construction:\n\033[0m")
         _print_hf("search ", "Returns the current state of mDNS search.")
         _print_hf("search [on/off] ", "Toggle mDNS search across the local network.")
         _print_hf("nodes list ", "List all nodes in the current topology (including local ones).")
@@ -195,15 +225,19 @@ def _print_hf(cmd, desc, examples=[""]):
         _print_hf("nodes ", "List mDNS discovered nodes.")
         _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
         _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
-        _print_hf("topo remove [NODE]", "Add [NODE] to the topology.")
-        sys.stdout.write("\033[1m\n Building a schedule:\n\033[0m")
-        _print_hf("sched create", "Automatic search for best schedule given the active topology and the loaded model.")
-        _print_hf("sched assign [LAYER] [NODE]", "Assign the layer with index [LAYER] to [NODE].",
-                  ["Example > sched assign 10 benny_234"])
-        _print_hf("schedule assign [START-END] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
-                  ["Example > sched assign 0-12 benny_234"])
-        sys.stdout.write("\033[1m\n Benchmarking and profiling:\n\033[0m")
-        _print_hf("profile [REPO]", "Estimate the total FLOPS of the model from [REPO]")
+        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
+        sys.stdout.write("\033[1m\n Scheduling:\n\033[0m")
+        _print_hf("sched auto ", "Automatic search for best schedule given the active topology and the loaded model.")
+        _print_hf("sched assign [INDEX] [NODE]", "Assign layer [INDEX] or the range [START-END] to [NODE].",
+                  ["Example: > sched assign 10 benny_234",
+                   " > sched assign 0-12 benny_234"])
+        sys.stdout.write("\033[1m\n Benchmarking, Tracing and Profiling:\n\033[0m")
+        _print_hf("trace [ON|OFF][PATH][SYSTEM] ", "Trace [SYSTEM] and output to file at [PATH].")
+        _print_hf("trace status ", "See status of the trace, e.g. number of frames captured.")
+        _print_hf("trace focus [SUBSYSTEM] ", "Focus the trace on [SUBSYSTEM]. 
Do 'trace focus' for a list of available subsystems.")
+        _print_hf("trace stream [ON|OFF] ", "Stream the trace spans to current terminal.")
+        _print_hf("trace set [BUDGET] ", "Set the maximum amount of recorded events.")
+        _print_hf("profile [REPO] ", "Estimate the total FLOPS of the model from [REPO]")
         _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
         _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
         _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
@@ -332,11 +366,6 @@ def handle_start_worker(self):
 
     # ===== Handle API server
 
-    # Tracer frames ingest callback
-    def __trace_cb(self, data):
-        dprint(str(data))
-        pass
-
     async def _api_main(self) -> None: # main thread loop
         self._api_loop = asyncio.get_running_loop()
         self._api_shutdown_e = asyncio.Event()
@@ -374,7 +409,33 @@ def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10):
         if not self._api_ready.wait(timeout):
             raise RuntimeError("API Server Timeout.")
         if self._api_exc is not None:
-            raise RuntimeError(f"API Server failed to start")
+            raise RuntimeError(f"API Server failed to start: {self._api_exc}")
+        # Register REPL aggregator callback on the API node
+        try:
+            self.api_call("set_trace_ingest_callback", self._trace_agg.enqueue, timeout=5)
+        except Exception:
+            pass
+
+        # Silence API server logs on the REPL console: drop records emitted from the API thread
+        try:
+            class _DropApiOnConsole(logging.Filter):
+                def filter(self, record: logging.LogRecord) -> bool:
+                    # Only drop records coming from the API thread so other threads keep logging
+                    tname = getattr(record, "threadName", "") or ""
+                    return tname != "api_server"
+
+            root = logging.getLogger()
+            for h in list(root.handlers):
+                if isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stdout, sys.stderr):
+                    if not any(isinstance(f, _DropApiOnConsole) for f in getattr(h, "filters", [])):
+                        h.addFilter(_DropApiOnConsole())
+
+            # Also quiet Hypercorn logs explicitly (HTTP server used by API)
+            logging.getLogger("hypercorn").setLevel(logging.CRITICAL)
+            logging.getLogger("hypercorn.error").setLevel(logging.CRITICAL)
+            logging.getLogger("hypercorn.access").setLevel(logging.CRITICAL)
+        except Exception:
+            pass
 
     def stop_api(self, timeout: float = 5.0) -> None:
         if not self._api_thread: return
@@ -441,6 +496,72 @@ async def _await_then_set():
         self._api_loop.call_soon_threadsafe(runner)
         return f.result(timeout)
 
+    # ------- Trace aggregation helpers
+
+    def do_trace(self, cmd):
+        if len(cmd) < 2:
+            dprint(f"Tracing is currently {'ON' if self._trace_cfg.enabled else 'OFF'}\n")
+        elif cmd[1] in ("on", "ON"):
+            self._trace_cfg.enabled = True
+            if self._api_running.is_set():
+                self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
+            dprint("Tracing is now ON\n")
+        elif cmd[1] in ("off", "OFF"):
+            self._trace_cfg.enabled = False
+            if self._api_running.is_set():
+                self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
+            dprint("Tracing is now OFF\n")
+        elif cmd[1] == "focus":
+            #self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
+            dprint("Subsystems not yet implemented.\n")
+        elif cmd[1] == "stream":
+            if len(cmd) == 2:
+                dprint(f"Trace is {'streaming to file: ' + str(self._trace_cfg.file) if self._trace_cfg.streaming else 'not streaming.'}\n")
+            elif cmd[2] == "on":
+                self._trace_cfg.streaming = True
+                dprint(f"Streaming trace frames to {self._trace_cfg.file}\n")
+            elif cmd[2] == "off":
+                self._trace_cfg.streaming = False
+                dprint("Trace streaming is OFF.\n")
+        elif cmd[1] == "set":
+            if len(cmd) == 2:
+                dprint("Use: trace set [BUDGET], e.g. 2000\n")
+            else:
+                dprint("Not implemented yet\n")
+                # FIXME: Implement
+        elif cmd[1] == "status":
+            dprint(f"Frames: {len(self._trace_agg._req)}\n")
+
+        elif cmd[1] == "annotate":
+            self.print_trace_annotate("NONE")
+
+    # Trace callback registered with API Thread
+    def __trace_cb(self, data):
+        self._trace_agg.enqueue(data)
+
+    def __print_tr(self, symbol, ms, counts):
+        sym = " " + symbol.ljust(40, ' ')
+        pms = f"{ms:.10}".ljust(10, ' ')
+        cns = f"{counts}".ljust(4, ' ')
+        sys.stdout.write(f"{sym} {pms} {cns}\n")
+
+    def print_trace_annotate(
+        self,
+        run_id: str = "run",
+        mapping: Optional[Dict[str, str]] = None,
+        repeats: int = 0,
+    ) -> None:
+        names = " "*17 + "symbol" + " "*21 + "ms" + " "*4 + "counts"
+        dots = " " + "."*41 + " " + "."*10 + " " + "."*4
+        dprint(f"{names}\n{dots}\n\n")
+        sums = self._trace_agg._req[run_id].sums_by_name
+        cnts = self._trace_agg._req[run_id].counts_by_name
+        for n, d in sums.items():
+            self.__print_tr(n, d, cnts[n])
+
+    def get_trace_roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]:
+        return self._trace_agg.roots(run_id, req_id)
+
     def _print_nodes_table(self, rows: List[Any]) -> None:
         headers = ["name", "role", "addr", "http", "grpc", "status", "head"]
         limits = {"name": 36, "addr": 15}

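For reference, the tracer wiring added by this patch can be exercised without the REPL. The sketch below is a minimal, non-authoritative example: the shard and API addresses, ports, and the node name "shard_0" are assumptions for illustration, not part of the patch. It posts a tracer configuration to a shard's /trace endpoint (the same payload shape `_forward_trace_config` builds from TraceConfigRequest), pointing aggregate_url at the API's /trace/ingest sink so the shard's aggregator thread starts POSTing batches there.

    import httpx

    # Illustrative addresses; in the real flow the API node fills these in itself.
    SHARD = "http://127.0.0.1:8081"
    API_INGEST = "http://127.0.0.1:8080/trace/ingest"

    cfg = {
        "file": "./trace.json",
        "streaming": False,
        "include_prefixes": ["src/dnet/"],
        "include_c_calls": False,
        "budget": 3000,
        "enabled": True,
        "node_id": "shard_0",         # illustrative shard identity
        "record_pid_tid": True,
        "aggregate": True,
        "aggregate_url": API_INGEST,  # where the shard's tracer POSTs batches
        "agg_max_events": 50,
    }

    res = httpx.post(f"{SHARD}/trace", json=cfg, timeout=5.0)
    print(res.json())  # expected: {"ok": true} once the shard restarts its aggregator
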
From 132f1cddfc7d28f808b10edd5225bf4b19759346 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Tue, 21 Oct 2025 03:24:48 -0700
Subject: [PATCH 014/172] mlx bug in lm_head, transposes weights for no
 reason, manually compute matmul

---
 src/dnet/ring/model/llama3.py | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py
index 231f02b5..81626d51 100644
--- a/src/dnet/ring/model/llama3.py
+++ b/src/dnet/ring/model/llama3.py
@@ -33,7 +33,9 @@ def __init__(
 
         self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
         self.norm = nn.RMSNorm(self.config.hidden_size, self.config.rms_norm_eps)
-        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
+
+        if not self.config.tie_word_embeddings:
+            self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
 
         self.layers: List[nn.Module] = []
         self.abs_to_local: Dict[int, int] = {}
@@ -41,8 +43,9 @@ def __init__(
         for i, l in enumerate(sorted(assigned_layers or [])):
             self.layers.append(TransformerBlock(self.config))
             self.abs_to_local[l] = i
+
         logger.debug(f"Created {len(self.layers)} Transformer layers")
-        logger.debug(f"abs_to_local mapping: {self.abs_to_local}")
+        #logger.debug(f"abs_to_local mapping: {self.abs_to_local}")
 
     @property
     def decoding_layers(self):
@@ -72,8 +75,14 @@ def embed(self, x: mx.array):
     def normalize(self, x: mx.array):
         return self.norm(x)
 
+    # FIXME: Weird MLX bug, lm_head weights are transposed internally for no reason
     def lm_project(self, x: mx.array):
-        return self.lm_head(x)
+        if self.config.tie_word_embeddings:
+            return self.embed_tokens.as_linear(x)
+        try:
+            return self.lm_head(x)
+        except Exception:
+            return mx.matmul(x, self.lm_head.weight)
 
     def quantize_layers(self):
         self.quantization = None
@@ -127,7 +136,19 @@ def apply_single_layer(
     ):
         if layer_idx not in self.abs_to_local:
             raise RuntimeError(f"Attempted execution of foreign layer {layer_idx}")
-        mask = create_attention_mask(x, cache)
+
+        mask = None
+        sqlen = 
int(x.shape[1]) + if sqlen > 1: + cached = getattr(self, "_cached_mask_len", None) + cached_mask = getattr(self, "_cached_mask", None) + if cached is None or cached != sqlen or not cached_mask: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_len = sqlen + else: + mask = cached_mask + local_idx = self.abs_to_local[layer_idx] logger.debug(f"apply_single_layer: layer:{layer_idx}, local_idx:{local_idx}, input_shape:{x.shape}") @@ -165,7 +186,9 @@ def load_weights(self, weights, strict=False): logger.debug(f"Mapping weight {k} -> {new_key}") shard_weights[new_key] = v - elif (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")): + elif k.startswith("lm_head"): + shard_weights[k] = v + elif (k.startswith("embed_tokens") or k.startswith("norm")): shard_weights[k] = v if shard_weights: From 6f9612bb5695dc5a8139e216846c8077a10b81e7 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 03:32:59 -0700 Subject: [PATCH 015/172] wrap in trace frames --- src/dnet/ring/shard/comms.py | 395 ++++++++++++++++++----------------- 1 file changed, 199 insertions(+), 196 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 9ca00406..8390457d 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -227,76 +227,79 @@ async def _send_activation(self, activation_msg: ActivationMessage): return try: if activation_msg.is_final: - try: - if self._mode == "offload" and self.window_size > 0: - first_window = self._assigned_sorted[: self.window_size] - if first_window: - loop = asyncio.get_running_loop() - fut = loop.run_in_executor( - self.executor, - self._prepare_window_blocking, - list(first_window), + with self.tracer.frame("grpc", "send_activation.final") as f: + try: + if self._mode == "offload" and self.window_size > 0: + first_window = self._assigned_sorted[: self.window_size] + if first_window: + loop = asyncio.get_running_loop() + fut = loop.run_in_executor( + self.executor, + self._prepare_window_blocking, + list(first_window), + ) + self._prepared_by_nonce[activation_msg.nonce] = ( + list(first_window), + fut, + ) + except Exception: + pass + cb = activation_msg.callback_url or "" + parsed = urlparse(cb) if cb else None + t_rpc = time.perf_counter() + if parsed and parsed.scheme == "grpc": + addr = parsed.netloc + if not addr: + logger.error("Invalid gRPC callback URL for token: %s", cb) + return + # Ensure API channel/stub + if (self.api_channel is None) or (addr != self.api_address): + # close existing channel if any + try: + if self.api_channel is not None: + await self.api_channel.close() + except Exception: + pass + + self.api_address = addr + self.api_channel = aio_grpc.insecure_channel( + addr, options=GRPC_AIO_OPTIONS ) - self._prepared_by_nonce[activation_msg.nonce] = ( - list(first_window), - fut, + self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( + self.api_channel ) - except Exception: - pass - cb = activation_msg.callback_url or "" - parsed = urlparse(cb) if cb else None - t_rpc = time.perf_counter() - if parsed and parsed.scheme == "grpc": - addr = parsed.netloc - if not addr: - logger.error("Invalid gRPC callback URL for token: %s", cb) - return - # Ensure API channel/stub - if (self.api_channel is None) or (addr != self.api_address): - # close existing channel if any + f.event("reset_api") + try: - if self.api_channel is not None: - await self.api_channel.close() - except Exception: - pass - - self.api_address = addr - self.api_channel = 
aio_grpc.insecure_channel( - addr, options=GRPC_AIO_OPTIONS - ) - self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( - self.api_channel - ) - try: - req = shard_api_comm_pb2.TokenRequest( - nonce=activation_msg.nonce, - token_id=int(getattr(activation_msg, "token_id", -1)), - timestamp=utc_epoch_now(), - ) - resp = await self.api_stub.SendToken(req) # type: ignore - rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 - if not resp.success: - logger.error( - "API SendToken failed for %s: %s", - activation_msg.nonce, - resp.message, - ) - if self._profile: - logger.info( - "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", - self.node_id, - activation_msg.nonce, - int(getattr(activation_msg, "token_id", -1)), - rpc_ms, + req = shard_api_comm_pb2.TokenRequest( + nonce=activation_msg.nonce, + token_id=int(getattr(activation_msg, "token_id", -1)), + timestamp=utc_epoch_now(), ) - except Exception as e: - logger.exception("Error sending token via gRPC: %s", e) - else: - logger.error( - "No valid gRPC callback for token; cannot deliver nonce=%s", - activation_msg.nonce, - ) - return + resp = await self.api_stub.SendToken(req) # type: ignore + rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 + if not resp.success: + logger.error( + "API SendToken failed for %s: %s", + activation_msg.nonce, + resp.message, + ) + if self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) + except Exception as e: + logger.exception("Error sending token via gRPC: %s", e) + else: + logger.error( + "No valid gRPC callback for token; cannot deliver nonce=%s", + activation_msg.nonce, + ) + return used_pool = False @@ -309,9 +312,6 @@ async def _send_activation(self, activation_msg: ActivationMessage): "Failed to get output buffer %s", activation_msg.pool_id ) return - data_size = int(np.prod(activation_msg.shape)) - shaped = output_buffer[:data_size].reshape(activation_msg.shape) - used_pool = True if self._profile: logger.info( @@ -346,156 +346,159 @@ async def _send_activation(self, activation_msg: ActivationMessage): activation_msg.dtype = self._wire_dtype_str t_cast = time.perf_counter() - if isinstance(shaped, np.ndarray): - data = shaped.tobytes(order="C") - else: - data = tensor_to_bytes(shaped) + with self.tracer.frame("grpc", "send_activations.cast_to_dtype") as f: + + if isinstance(shaped, np.ndarray): # Cast to target dtype + if shaped.dtype != wire_np_dtype: + shaped = shaped.astype(wire_np_dtype, copy=False) + f.event("ndarray.cast") + data = shaped.tobytes(order="C") + else: + if str(shaped.dtype) != self._wire_dtype_str: # MLX array + shaped = shaped.astype(self._wire_mx_dtype) + f.event("mxarray.cast") + data = tensor_to_bytes(shaped) - ser_ms = (time.perf_counter() - t_ser) * 1000.0 - cast_ms = (t_cast - t_ser) * 1000.0 + activation_msg.dtype = self._wire_dtype_str nxt = activation_msg.layer_id + 1 - if (nxt < self.model_metadata.num_layers) and ( - nxt not in self._assigned_set - ): + if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): if self.next_node_stub: - request = activation_msg.to_proto(data) - request.timestamp = utc_epoch_now() - if self._mode == "offload" and self.window_size > 0: - next_window = self._next_local_layers( - activation_msg.layer_id, self.window_size - ) - loop = asyncio.get_running_loop() - if next_window: - fut = loop.run_in_executor( - self.executor, - self._prepare_window_blocking, - 
list(next_window), - ) - self._prepared_by_nonce[activation_msg.nonce] = ( - list(next_window), - fut, + with self.tracer.frame("grpc", "send_activation.next") as f: + request = activation_msg.to_proto(data) + request.timestamp = utc_epoch_now() + if self._mode == "offload" and self.window_size > 0: + next_window = self._next_local_layers( + activation_msg.layer_id, self.window_size ) - else: - first_window = self._assigned_sorted[: self.window_size] - if first_window: + loop = asyncio.get_running_loop() + if next_window: fut = loop.run_in_executor( self.executor, self._prepare_window_blocking, - list(first_window), + list(next_window), ) self._prepared_by_nonce[activation_msg.nonce] = ( - list(first_window), + list(next_window), fut, ) - stream_used = False - ctx = await self._ensure_stream(activation_msg.nonce) - if ( - ctx - and ctx.open - and not ctx.disabled - and hasattr(dnet_ring_pb2, "ActivationFrame") - ): - try: - ctx.last_seq += 1 - frame = dnet_ring_pb2.ActivationFrame( - request=request, - seq=ctx.last_seq, - end_of_request=False, - ) - await ctx.queue.put(frame) - ctx.last_activity_t = asyncio.get_running_loop().time() - stream_used = True - if self._profile: - logger.info( - "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", - activation_msg.nonce, - ctx.last_seq, - ctx.queue.qsize(), + else: + first_window = self._assigned_sorted[: self.window_size] + if first_window: + fut = loop.run_in_executor( + self.executor, + self._prepare_window_blocking, + list(first_window), + ) + self._prepared_by_nonce[activation_msg.nonce] = ( + list(first_window), + fut, + ) + stream_used = False + ctx = await self._ensure_stream(activation_msg.nonce) + if ( + ctx + and ctx.open + and not ctx.disabled + and hasattr(dnet_ring_pb2, "ActivationFrame") + ): + try: + ctx.last_seq += 1 + frame = dnet_ring_pb2.ActivationFrame( + request=request, + seq=ctx.last_seq, + end_of_request=False, ) - except Exception as e: - logger.warning( - "[STREAM] enqueue failed; fallback to unary: %s", e + await ctx.queue.put(frame) + ctx.last_activity_t = asyncio.get_running_loop().time() + stream_used = True + if self._profile: + logger.info( + "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", + activation_msg.nonce, + ctx.last_seq, + ctx.queue.qsize(), + ) + except Exception as e: + logger.warning( + "[STREAM] enqueue failed; fallback to unary: %s", e + ) + ctx.disabled = True + + if not stream_used: + # In fit mode, avoid long unary stalls: use short deadline and min retries + # Streaming should be the norm; unary is a quick safety valve only. + ring_timeout = 3.0 if self._mode == "fit" else 30.0 + ring_retries = ( + 1 + if self._mode == "fit" + else max(1, int(self._send_retries)) ) - ctx.disabled = True - - if not stream_used: - # In fit mode, avoid long unary stalls: use short deadline and min retries - # Streaming should be the norm; unary is a quick safety valve only. 
- ring_timeout = 3.0 if self._mode == "fit" else 30.0 - ring_retries = ( - 1 - if self._mode == "fit" - else max(1, int(self._send_retries)) - ) - # Emit a clear fallback log with reason/context - if self._profile: - if ctx is None: - reason = "no_stream_ctx" - elif not ctx.open: - reason = "stream_closed" - elif ctx.disabled: - reason = "stream_disabled" + # Emit a clear fallback log with reason/context + if self._profile: + if ctx is None: + reason = "no_stream_ctx" + elif not ctx.open: + reason = "stream_closed" + elif ctx.disabled: + reason = "stream_disabled" + else: + reason = "enqueue_failed" + logger.warning( + "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + self.node_id, + activation_msg.nonce, + reason, + self._mode, + ring_timeout, + ring_retries, + ) + t0 = time.perf_counter() + last_exc: Optional[Exception] = None + for attempt in range(1, ring_retries + 1): + try: + # FIXME: use response here? + _ = await self.next_node_stub.SendActivation( + request, timeout=ring_timeout + ) # type: ignore + break + except grpc.aio.AioRpcError as e: # type: ignore + last_exc = e + code = e.code() + if code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.DEADLINE_EXCEEDED, + }: + logger.warning( + "SendActivation attempt %s/%s failed (%s); reconnecting...", + attempt, + ring_retries, + code.name, + ) + await self._reconnect_next_node() + await asyncio.sleep(min(0.25 * attempt, 1.0)) + continue + raise else: - reason = "enqueue_failed" - logger.warning( - "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", self.node_id, activation_msg.nonce, - reason, - self._mode, - ring_timeout, - ring_retries, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, ) - t0 = time.perf_counter() - last_exc: Optional[Exception] = None - for attempt in range(1, ring_retries + 1): - try: - # FIXME: use response here? - _ = await self.next_node_stub.SendActivation( - request, timeout=ring_timeout - ) # type: ignore - break - except grpc.aio.AioRpcError as e: # type: ignore - last_exc = e - code = e.code() - if code in { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.CANCELLED, - grpc.StatusCode.DEADLINE_EXCEEDED, - }: - logger.warning( - "SendActivation attempt %s/%s failed (%s); reconnecting...", - attempt, - ring_retries, - code.name, - ) - await self._reconnect_next_node() - await asyncio.sleep(min(0.25 * attempt, 1.0)) - continue - raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) else: - logger.error( - "Cannot forward activation - no next node configured; end shard should sample inline." - ) + logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") else: logger.error( "Final activation reached send path unexpectedly; sampling should occur on end shard." 
) - # Clear scheduling at request end # Sequential offload: prefetch state is unused From 1b54b8415d22d425469a863f284fa7ead6297d7b Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 11:51:07 -0700 Subject: [PATCH 016/172] trace ingress worker --- src/dnet/ring/shard/node.py | 328 +++++++++++++++++------------------- 1 file changed, 156 insertions(+), 172 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 2f9d6a7a..e21b7ddb 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -618,6 +618,7 @@ async def reset_cache(self) -> None: except Exception as e: logger.error("Node %s: Error resetting cache: %s", self.node_id, e) + async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or forward.""" if self.input_pool is None: @@ -630,142 +631,45 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): t_recv = time.perf_counter() await self._connect_next_node() - try: - activation = request.activation - target_layer = activation.layer_id + 1 + with self.tracer.frame("grpc.receive", "connect_next_node"): + await self._connect_next_node() + with self.tracer.frame("grpc.receive", "process_activation") as f: try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - transport_ms = float(utc_epoch_now() - request.timestamp) - logger.info( - "[PROFILE][RX] node=%s nonce=%s target_layer=%s transport_ms=%.1f payload_kb=%.1f", - self.node_id, - request.nonce, - target_layer, - transport_ms, - (payload_bytes / 1024.0), - ) + activation = request.activation + target_layer = activation.layer_id + 1 # Detect new sequence per node: initialize per-nonce KV if request.nonce != self._active_nonce: self._active_nonce = request.nonce try: - self._get_or_make_kv(request.nonce) + payload_bytes = len(activation.data) except Exception: - pass + payload_bytes = -1 + f.event("process_payload") if target_layer in self._assigned_set: # Allocate input pool and copy payload (with optional decompression) t_alloc = time.perf_counter() if "|" in activation.dtype: - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return - - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs - if activation.dtype == "tokens": + with self.tracer.frame("grpc.receive", "decompress") as fr: try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 + deq = 
decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, ) - shp = (int(len(tokens)),) except Exception as e: logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) - return - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - if self._profile: - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - # Ensure dtype reflects token payload for compute path - activation_msg.dtype = "tokens" - activation_msg.shape = shp - else: - # Safety: byte length must match shape*dtype - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize - ) - actual = len(request.activation.data) - except Exception: - expected = -1 - actual = -1 - if expected != actual: - logger.error( - "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", - request.nonce, - expected, - actual, - activation.dtype, - activation.shape, + "Decompression failed for nonce %s: %s", request.nonce, e ) return + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) if pool_id is None: logger.warning( @@ -775,22 +679,103 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: - data = request.activation.data - input_data = np.frombuffer( - data, dtype=dtype_map[activation.dtype] - ) - buffer[: len(input_data)] = input_data + flat = deq.reshape(-1) + buffer[: flat.size] = flat alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", self.node_id, request.nonce, alloc_copy_ms, ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - if self._profile: - activation_msg.recv_perf_t = t_recv + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + else: + # Special token stream support: dtype='tokens' carries int32 token IDs + if activation.dtype == "tokens": + with self.tracer.frame("grpc.receive", "token_stream") as fr: + try: + tokens = np.frombuffer( + request.activation.data, dtype=np.int32 + ) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", + request.nonce, + e, + ) + return + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), + ) + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) 
+ return + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + if self._profile: + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)", + self.node_id, + request.nonce, + alloc_copy_ms, + ) + activation_msg = ActivationMessage.from_proto(request, pool_id) + # Ensure dtype reflects token payload for compute path + activation_msg.dtype = "tokens" + activation_msg.shape = shp + else: + with self.tracer.frame("grpc.receive", "default") as fr: + # Safety: byte length must match shape*dtype + try: + expected = ( + int(np.prod(activation.shape)) + * np.dtype(dtype_map[activation.dtype]).itemsize + ) + actual = len(request.activation.data) + except Exception: + pass + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mlx_dtype_map[activation.dtype], + shape=cast(tuple[int, ...], activation.shape), + ) + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + data = request.activation.data + input_data = np.frombuffer( + data, dtype=dtype_map[activation.dtype] + ) + buffer[: len(input_data)] = input_data + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", + self.node_id, + request.nonce, + alloc_copy_ms, + ) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) # Queue for processing — non-blocking back-off loop (cancellable) if self._profile: @@ -820,8 +805,6 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): ) await self._forward_activation(request) - except Exception as e: - logger.exception("Error receiving activation: %s", e) async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: """ @@ -845,64 +828,67 @@ async def _ingress_worker(self): finally enqueues for compute or forwards to the next shard. 
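
        In outline, the loop below reduces to the following sketch
        (tracing, pool management, and error handling omitted; all
        names are the ones used in this module):

            while self.running:
                req = await self.ingress_q.get()
                if req.activation.layer_id + 1 in self._assigned_set:
                    loop = asyncio.get_running_loop()
                    msg = await loop.run_in_executor(
                        self.executor,
                        self._prepare_activation_message_blocking,
                        req,
                    )
                    if msg is not None:
                        self.activation_recv_queue.put_nowait(msg)
                else:
                    await self._forward_activation(req)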
""" while self.running: - try: - req = await self.ingress_q.get() - except asyncio.CancelledError: - break - try: - t_recv = time.perf_counter() - await self._connect_next_node() - - activation = req.activation - target_layer = activation.layer_id + 1 + with self.tracer.frame("grpc", "ingress") as f: + with self.tracer.frame("grpc.ingress", "get"): + try: + req = await self.ingress_q.get() + except asyncio.CancelledError: + break try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - transport_ms = float(utc_epoch_now() - req.timestamp) - logger.info( - "[PROFILE][RX] node=%s nonce=%s target_layer=%s transport_ms=%.1f payload_kb=%.1f", - self.node_id, - req.nonce, - target_layer, - transport_ms, - (payload_bytes / 1024.0), - ) + with self.tracer.frame("grpc.ingress", "connect_next_node"): + await self._connect_next_node() + + activation = req.activation + target_layer = activation.layer_id + 1 # Detect new sequence per node: initialize per-nonce KV if req.nonce != self._active_nonce: self._active_nonce = req.nonce try: - self._get_or_make_kv(req.nonce) + payload_bytes = len(activation.data) except Exception: - pass + logger.error(f"Unable to read length of data for {req.nonce}") + payload_bytes = -1 + + f.set("nonce", req.nonce) + f.set("target", target_layer) + f.set("payload_bytes", payload_bytes) + f.event("received") if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) - loop = asyncio.get_running_loop() - try: - activation_msg = await loop.run_in_executor( - self.executor, - self._prepare_activation_message_blocking, - req, - ) - except Exception as e: - logger.error( - "Activation prepare failed for nonce %s: %s", req.nonce, e - ) - continue - if activation_msg is None: - continue - if self._profile: - activation_msg.recv_perf_t = t_recv + with self.tracer.frame("grpc.ingress", "prepare"): + loop = asyncio.get_running_loop() + try: + activation_msg = await loop.run_in_executor( + self.executor, + self._prepare_activation_message_blocking, + req, + ) + except Exception as e: + logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) + continue + if activation_msg is None: + continue + if self._profile: + activation_msg.recv_perf_t = t_recv # Enqueue for compute (cancellable back-off) - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", + with self.tracer.frame("grpc.ingress", "queue") as fr: + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, + ) + break + except Full: + await asyncio.sleep(0) + else: + logger.error( + "Failed to queue activation %s (node stopping)", activation_msg.nonce, ) break @@ -928,8 +914,6 @@ async def _ingress_worker(self): ) await self._forward_activation(req) - except Exception as e: - logger.error("Ingress worker error: %s", e) def _get_or_make_kv(self, nonce: str) -> list: """Return a per-nonce KV cache list for this shard's local layers.""" From f96dd28bc17b91bd63053d8fd19b7e77dcb95de1 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 11:51:44 -0700 Subject: [PATCH 017/172] trace token request stall --- src/dnet/ring/shard/comms.py | 45 +++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 8390457d..dc197637 100644 --- 
a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -270,28 +270,29 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) f.event("reset_api") - try: - req = shard_api_comm_pb2.TokenRequest( - nonce=activation_msg.nonce, - token_id=int(getattr(activation_msg, "token_id", -1)), - timestamp=utc_epoch_now(), - ) - resp = await self.api_stub.SendToken(req) # type: ignore - rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 - if not resp.success: - logger.error( - "API SendToken failed for %s: %s", - activation_msg.nonce, - resp.message, - ) - if self._profile: - logger.info( - "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", - self.node_id, - activation_msg.nonce, - int(getattr(activation_msg, "token_id", -1)), - rpc_ms, + with self.tracer.frame("grpc", "token_request") as fr: + try: + req = shard_api_comm_pb2.TokenRequest( + nonce=activation_msg.nonce, + token_id=int(getattr(activation_msg, "token_id", -1)), + timestamp=utc_epoch_now(), ) + resp = await self.api_stub.SendToken(req) # type: ignore + rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 + if not resp.success: + logger.error( + "API SendToken failed for %s: %s", + activation_msg.nonce, + resp.message, + ) + if self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) except Exception as e: logger.exception("Error sending token via gRPC: %s", e) else: @@ -495,6 +496,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) else: logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") + + # Final layer not annotated with 'is_final' else: logger.error( "Final activation reached send path unexpectedly; sampling should occur on end shard." 
From a3125c3fd2fd048493ad4479d388bf9692022e00 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 11:52:10 -0700 Subject: [PATCH 018/172] stop printing state on empty prompt --- src/repl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/repl.py b/src/repl.py index 592424a6..862d7136 100644 --- a/src/repl.py +++ b/src/repl.py @@ -107,7 +107,8 @@ def loop(self): # Main tty loop cmd = sys.stdin.readline().strip() if cmd == "": - self.print_state() + #self.print_state() + continue elif cmd in [".exit", "exit", "quit"]: self.handle_terminate_signal() elif cmd in [".help", "help", "h"]: From 80c23ef534386f4212db3a74512d686653805453 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 12:53:09 -0700 Subject: [PATCH 019/172] trace startup --- src/dnet/ring/shard/startup.py | 44 ++++++++++++++-------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py index 6790be64..f3b38fbb 100644 --- a/src/dnet/ring/shard/startup.py +++ b/src/dnet/ring/shard/startup.py @@ -48,34 +48,26 @@ class StartupMixin: async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): self.running = True - try: # Capture the main event loop for cross-thread scheduling - self._loop = asyncio.get_running_loop() - except Exception: - self._loop = None - await self._start_grpc_server() - await self._start_http_server(shutdown_trigger) - await asyncio.sleep(0.2) - - self.background_tasks = [ - asyncio.create_task(self._ingress_worker()), - asyncio.create_task(self._prefetch_worker()), - asyncio.create_task(self._send_worker()), - ] - # Start idle sweeper to close silent streams - try: - if getattr(self, "_streaming_enabled", False) and hasattr( - self, "_stream_sweeper" - ): - self.background_tasks.append( - asyncio.create_task(self._stream_sweeper()) - ) - except Exception: - pass - self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) - self.compute_thread.start() + with self.tracer.frame("startup", "workers"): + self.background_tasks = [ + asyncio.create_task(self._ingress_worker()), + asyncio.create_task(self._prefetch_worker()), + asyncio.create_task(self._send_worker()) ] + + try: # Start idle sweeper to close silent streams + if getattr(self, "_streaming_enabled", False) and hasattr(self, "_stream_sweeper"): + self.background_tasks.append( asyncio.create_task(self._stream_sweeper())) + except Exception: + pass + + with self.tracer.frame("startup", "compute"): + self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) + self.compute_thread.start() + + with self.tracer.frame("startup", "discovery"): + self._start_discovery() - self._start_discovery() logger.info( "Shard node %s started on gRPC port %s HTTP port %s", self.node_id, From f706fde2ea9402507f1d25f34009d5441ef83624 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 12:53:43 -0700 Subject: [PATCH 020/172] trace prepare_activation --- src/dnet/ring/shard/node.py | 257 ++++++++++++++++++------------------ 1 file changed, 131 insertions(+), 126 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index e21b7ddb..a9aeba76 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -199,8 +199,6 @@ def __init__( self._sync_per_layer = obs.sync_per_layer self._sync_every_n = obs.sync_every_n self._prof = make_profiler(self._profile) - if self._profile: - logger.info("[PROFILE] enabled on shard node %s", self.node_id) # 
Debug tracing cfg = TraceConfig( @@ -264,6 +262,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse # Load model metadata with self.tracer.frame("memory", "model.load_metadata"): self.model_metadata = get_model_metadata(req.model_path) + self.assigned_layers = req.layers self._assigned_sorted = sorted(self.assigned_layers) self._assigned_set = set(self._assigned_sorted) @@ -441,29 +440,30 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - if self.next_node: - with self.tracer.frame("network", "connect.next_node"): + with self.tracer.frame("network", "connect.next_node"): + if self.next_node: await self._connect_next_node() else: logger.warning("Node %s: No next node configured", self.node_id) # Warmup: compile hot path and stabilize allocators before first request - if req.warmup and self._mode == "fit": - loop = asyncio.get_running_loop() - try: - await loop.run_in_executor(self.executor, self._warmup_shard) - except Exception: - # Fall back to direct call if executor is unavailable - self._warmup_shard() - elif req.warmup and self._mode != "fit": - # Offload/sliding-fit: perform a small, offload-safe warmup for the first window - loop = asyncio.get_running_loop() - try: - await loop.run_in_executor( - self.executor, self._warmup_shard_offload - ) - except Exception: - self._warmup_shard_offload() + with self.tracer.frame("memory", "warmup"): + if req.warmup and self._mode == "fit": + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(self.executor, self._warmup_shard) + except Exception: + # Fall back to direct call if executor is unavailable + self._warmup_shard() + elif req.warmup and self._mode != "fit": + # Offload/sliding-fit: perform a small, offload-safe warmup for the first window + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor( + self.executor, self._warmup_shard_offload + ) + except Exception: + self._warmup_shard_offload() if self._mode == "offload" and not ( self._warmup_completed and self._warmup_keep_flag @@ -484,10 +484,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse if m % self.window_size != 0: logger.warning( "Window size %s does not divide local layer count %s. Rounds per token will vary; consider setting k*w = %s.", - self.window_size, - m, - m, - ) + self.window_size, m, m) else: k = m // self.window_size logger.info( @@ -952,13 +949,14 @@ def _clear_kv(self, nonce: str) -> None: except Exception: pass + def _prepare_activation_message_blocking( self, request: dnet_ring_pb2.ActivationRequest ) -> Optional[ActivationMessage]: + """Blocking heavy prep: allocate pool buffer, copy/decompress payload, build ActivationMessage. + Returns None on failure. """ - Returns None on failure. 
- """ if self.input_pool is None: logger.error( "Node %s: Cannot prepare activation - input pool not initialized", @@ -968,113 +966,117 @@ def _prepare_activation_message_blocking( try: activation = request.activation - if "|" in activation.dtype: - # Compressed path: decompress to MLX array and copy to pool - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return None + if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - return activation_msg - elif activation.dtype == "tokens": - # Tokens path: parse int32 token IDs and stage them - try: - tokens = np.frombuffer(activation.data, dtype=np.int32) - shp = (int(len(tokens)),) - except Exception as e: - logger.error( - "Failed to parse tokens for nonce %s: %s", request.nonce, e - ) - return None - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = "tokens" - activation_msg.shape = shp - return activation_msg - else: - # Dense path: validate size and copy raw bytes view into pool buffer - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize + with self.tracer.frame("grpc.ingress.prepare_activation", "decompress") as f: + try: + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, + ) + except Exception as e: + logger.error( + "Decompression failed for nonce %s: %s", request.nonce, e + ) + return None + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) - actual = len(activation.data) - except Exception: - expected = -1 - actual = -1 - if expected != actual: - logger.error( - "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", - request.nonce, - expected, - actual, - activation.dtype, - activation.shape, + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + flat = deq.reshape(-1) + buffer[: flat.size] = flat + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, 
pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + return activation_msg + + elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them + with self.tracer.frame("grpc.ingress.prepare_activation", "tokens") as f: + try: + tokens = np.frombuffer(activation.data, dtype=np.int32) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", request.nonce, e + ) + return None + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), ) - return None + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = "tokens" + activation_msg.shape = shp + return activation_msg - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, + else: # Dense path: validate size and copy raw bytes view into pool buffer + with self.tracer.frame("grpc.ingress.prepare_activation", "default") as f: + try: + expected = ( + int(np.prod(activation.shape)) + * np.dtype(dtype_map[activation.dtype]).itemsize + ) + actual = len(activation.data) + except Exception: + expected = -1 + actual = -1 + if expected != actual: + logger.error( + "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", + request.nonce, + expected, + actual, + activation.dtype, + activation.shape, + ) + return None + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mlx_dtype_map[activation.dtype], + shape=cast(tuple[int, ...], activation.shape), ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - data = request.activation.data - input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) - buffer[: len(input_data)] = input_data - activation_msg = ActivationMessage.from_proto(request, pool_id) - return activation_msg + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + data = request.activation.data + input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) + buffer[: len(input_data)] = input_data + activation_msg = ActivationMessage.from_proto(request, pool_id) + return activation_msg except Exception as e: logger.error("Activation prep error: %s", e) return None + def _next_local_layers(self, after_layer: int, count: int) -> List[int]: """Get next local layers after specified layer. 
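
        Worked example (hypothetical layer assignment): with sorted local
        layers s = [3, 7, 8, 12], a call with after_layer=7 and count=2
        bisects to the first layer greater than 7 and returns the next
        two local layers, using the stdlib bisect_left:

            s = [3, 7, 8, 12]
            i = bisect_left(s, 7 + 1)   # i == 2
            s[i : i + 2]                # [8, 12]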
@@ -1091,22 +1093,25 @@ def _next_local_layers(self, after_layer: int, count: int) -> List[int]: i = _bisect_left(s, after_layer + 1) return s[i : i + count] + def _compute_worker(self) -> None: """Compute thread worker.""" while self.running: try: # Get activation from queue (blocks until available) - activation_msg = self.activation_recv_queue.get(timeout=1.0) + with self.tracer.frame("compute", "dequeue"): + activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation with self.tracer.frame("compute", "forward"): - self._process_activation(activation_msg) + self._process_activation(activation_msg) except Empty: continue except Exception as e: logger.error("Compute worker error: %s", e) + async def shutdown(self) -> None: """Shutdown the node.""" self.running = False From b6643195e86807b3b3d05a9053f97f644ee1ffff Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 12:54:21 -0700 Subject: [PATCH 021/172] don't set _prefetch_pause --- src/dnet/ring/shard/compute.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index d25705fe..41111154 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -92,6 +92,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Prepare input activation with self.tracer.frame("compute.thread", "activations.process") as f: if activation_msg.dtype == "tokens": # embed locally on start shard + logger.debug(f"Embedding tokens.") f.event("embed_tokens") numel = int(np.prod(activation_msg.shape)) tok_view = input_buffer[:numel].reshape(activation_msg.shape) @@ -395,6 +396,7 @@ def _process_activation(self, activation_msg: ActivationMessage): except Exception as e: logger.error("End-shard sampling failed: %s", e) return + output_msg = ActivationMessage( nonce=activation_msg.nonce, layer_id=last_layer, From ba7e3aa4c7077b401a544296d09db1b21e966673 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 14:36:57 -0700 Subject: [PATCH 022/172] compute mean, p50, p99, etc. 
per trace symbol and print --- src/dnet/perf/trace.py | 7 +- src/dnet/perf/utils/aggregator.py | 154 ++++++++++++++++++++---------- src/repl.py | 40 +++++--- 3 files changed, 135 insertions(+), 66 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index e4a92cfb..d2ed43b6 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -131,14 +131,10 @@ def _agg_exec(self) -> None: except Exception: logger.warining("Unable to close httpx client.") - # We don't have the API addr at init time def update_api_addr(self, addr): self.config.aggregate_url = addr logger.debug(f"Updated API Address: {self.config.aggregate_url}") - def update_confi(self, config): - pass - def start(self, *, reset: bool = True) -> None: self._active = bool(self.config.enabled) if not self._active: @@ -214,8 +210,9 @@ def _emit(self, ev: Dict[str, Any]) -> None: if len(self._events) < self._agg_max_events: return logger.debug(f"Queuing tracer frame batch of {len(self._events)}") batch = { "run_id": (self._req_id or "NONE"), - "node_id": (self.config.node_id or "UNKNOWN_NODE"), + "node_id": (self.config.node_id or "NODE"), "events": list(self._events)} + logger.debug(batch) try: self._agg_q.put_nowait(batch) except queue.Full: diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index 2736a25b..010e6922 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -11,11 +11,24 @@ Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @dataclass -class _OpenFrame: +class _ActiveSpan: + """Per-instance active span used for self-time accounting on a call stack.""" name: str t0: int child: int = 0 - children: List[Dict[str, Any]] = field(default_factory=list) + + +@dataclass +class _SymbolAgg: + """Aggregated statistics for a single trace symbol (name).""" + total_ms: float = 0.0 + count: int = 0 + durations: deque = field(default_factory=lambda: deque(maxlen=10000)) + + def add(self, self_ms: float) -> None: + self.total_ms += float(self_ms) + self.count += 1 + self.durations.append(float(self_ms)) # Sort known frames and compute averages by key @dataclass @@ -24,73 +37,74 @@ class RunAggregator: counts_by_name: Dict[str, int] = field(default_factory=dict) last_batch_seq: Dict[str, int] = field(default_factory=dict) - stacks: Dict[Key, List[_OpenFrame]] = field(default_factory=dict) + stacks: Dict[Key, List[_ActiveSpan]] = field(default_factory=dict) drops: int = 0 - roots_by_req: DefaultDict[str, List[Dict[str, Any]]] = field(default_factory=lambda: defaultdict(list)) + # Aggregated stats per symbol (primary source of truth) + symbols: Dict[str, _SymbolAgg] = field(default_factory=dict) + # Back-compat mirrors for existing readers (e.g., legacy REPL code) + durations_by_name: Dict[str, deque] = field(default_factory=dict) def _key(self, node_id: str, pid: Optional[int], tid: Optional[int], req_id: Optional[str]) -> Key: return (node_id, pid, tid, req_id or "") - def _push(self, key: Key, f: _OpenFrame) -> None: + def _push(self, key: Key, f: _ActiveSpan) -> None: self.stacks.setdefault(key, []).append(f) - def _pop(self, key: Key) -> Optional[_OpenFrame]: + def _pop(self, key: Key) -> Optional[_ActiveSpan]: st = self.stacks.get(key) if not st: return None return st.pop() - def _peek(self, key: Key) -> Optional[_OpenFrame]: + def _peek(self, key: Key) -> Optional[_ActiveSpan]: st = self.stacks.get(key) return st[-1] if st else None - def _acc_annotate(self, name: str, self_ms: float) -> None: - 
self.sums_by_name[name] = self.sums_by_name.get(name, 0.0) + self_ms - self.counts_by_name[name] = self.counts_by_name.get(name, 0) + 1 + def _accumulate(self, name: str, self_ms: float) -> None: + sym = self.symbols.get(name) + if sym is None: + sym = _SymbolAgg() + self.symbols[name] = sym + sym.add(self_ms) + + # FIXME: Remove + self.sums_by_name[name] = sym.total_ms + self.counts_by_name[name] = sym.count + dq = self.durations_by_name.get(name) + if dq is None: + dq = deque(maxlen=10000) + self.durations_by_name[name] = dq + dq.append(float(self_ms)) def ingest_event(self, node_id: str, ev: Dict[str, Any]) -> None: if not ev.get("name"): logger.error(f"Received trace frame without name {ev.get('ts')}") return - # Normalize ts to microseconds (accept float seconds or int microseconds) + # Normalize timestamp to microseconds ts_raw = ev.get("ts") - ts_us = 0 + ts = 0 try: if isinstance(ts_raw, float): - ts_us = int(ts_raw * 1_000_000) + ts = int(ts_raw * 1_000_000) elif isinstance(ts_raw, int): - ts_us = ts_raw + ts = ts_raw else: - ts_us = int(ts_raw or 0) + ts = int(ts_raw or 0) except Exception: - ts_us = 0 + ts = 0 req_id = ev.get("req_id") or "" key = self._key(node_id, ev.get("pid"), ev.get("tid"), req_id) if ev.get("type") == "B": - self._push(key, _OpenFrame(name=ev.get("name"), t0=ts_us)) + self._push(key, _ActiveSpan(name=ev.get("name"), t0=ts)) elif ev.get("type") == "E": fr = self._pop(key) if not fr: return - dur_us = max(0, ts_us - fr.t0) + dur_us = max(0, ts - fr.t0) self_us = max(0, dur_us - fr.child) self_ms = self_us / 1000.0 - self._acc_annotate(fr.name, self_ms) + self._accumulate(fr.name, self_ms) parent = self._peek(key) - completed = { - "name": fr.name, - "ts": fr.t0, - "dur_ms": dur_us / 1000.0, - "self_ms": self_ms, - "children": fr.children, - "pid": ev.get("pid"), - "tid": ev.get("tid"), - "req_id": req_id, - "node_id": node_id, - } if parent: parent.child += dur_us - parent.children.append(completed) - else: - self.roots_by_req[req_id or ""].append(completed) else: # TODO :Process other events pass @@ -126,28 +140,68 @@ def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, rep agg = self._req.get(run_id) if not agg: return [] + def _stats(xs: List[float]) -> Dict[str, float]: + if not xs: + return {"mean": 0.0, "p50": 0.0, "p90": 0.0, "p99": 0.0, "min": 0.0, "max": 0.0} + n = len(xs) + srt = sorted(xs) + def q(p: float) -> float: + if n == 1: + return srt[0] + k = int(round(p * (n - 1))) + k = max(0, min(k, n - 1)) + return srt[k] + return { + "mean": (sum(xs) / n), + "p50": q(0.5), + "p90": q(0.9), + "p99": q(0.99), + "min": srt[0], + "max": srt[-1], + } + + rows: List[Dict[str, Any]] = [] if not mapping: - rows = [ - {"name": k, "self_ms": v, "total_ms": v, "count": repeats or agg.counts_by_name.get(k, 0), "max_ms": None} - for k, v in agg.sums_by_name.items() - ] + for name, sym in agg.symbols.items(): + total = sym.total_ms + samples = list(sym.durations) + st = _stats(samples) + rows.append({ + "name": name, + "total": total, + "max": st["max"], + "mean": st["mean"], + "p50": st["p50"], + "p90": st["p90"], + "p99": st["p99"], + "samples": len(samples), + }) else: sums: Dict[str, float] = {} counts: Dict[str, int] = {} - for raw, val in agg.sums_by_name.items(): + dists: Dict[str, List[float]] = {} + for raw, sym in agg.symbols.items(): disp = mapping.get(raw, raw) - sums[disp] = sums.get(disp, 0.0) + val - counts[disp] = counts.get(disp, 0) + agg.counts_by_name.get(raw, 0) - rows = [ - {"name": k, "self_ms": v, "total_ms": v, 
"count": repeats or counts.get(k, 0), "max_ms": None} - for k, v in sums.items() - ] - rows.sort(key=lambda r: r["self_ms"], reverse=True) + sums[disp] = sums.get(disp, 0.0) + sym.total_ms + counts[disp] = counts.get(disp, 0) + sym.count + if sym.durations: + dists.setdefault(disp, []).extend(sym.durations) + for name, total in sums.items(): + samples = dists.get(name, []) + st = _stats(samples) + rows.append({ + "name": name, + "total": total, + "max": st["max"], + "mean": st["mean"], + "p50": st["p50"], + "p90": st["p90"], + "p99": st["p99"], + "samples": len(samples), + }) + rows.sort(key=lambda r: r["total"], reverse=True) return rows def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: - with self._lock: - agg = self._req.get(run_id) - if not agg: - return [] - return list(agg.roots_by_req.get(req_id or "", [])) + # Call-tree storage is disabled to reduce memory; keep API for compatibility. + return [] diff --git a/src/repl.py b/src/repl.py index 862d7136..c2f081fb 100644 --- a/src/repl.py +++ b/src/repl.py @@ -540,7 +540,7 @@ def do_trace(self, cmd): def __trace_cb(self, data): self._trace_agg.enqueue(data) - def __print_tr(self, symbol, ms, counts): + def __print_tr(self, row): sym = " " + symbol.ljust(40, ' ') pms = f"{ms:.10}".ljust(10, ' ') cns = f"{counts}".ljust(4, ' ') @@ -552,16 +552,34 @@ def print_trace_annotate( mapping: Optional[Dict[str, str]] = None, repeats: int = 0, ) -> List[Dict[str, Any]]: - names = " "*17 + "symbol" + " "*21 + "ms" + " "*4 + "counts" - dots = " " + "."*41 + " " + "."*10 + " " + "."*4 - dprint(f"{names}\n{dots}\n\n") - sums = self._trace_agg._req[run_id].sums_by_name - cnts = self._trace_agg._req[run_id].counts_by_name - for n, d in sums.items(): - self.__print_tr(n, d, cnts[n]) - - def get_trace_roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: - return self._trace_agg.roots(run_id, req_id) + + rows = self._trace_agg.annotate(run_id) + headers = ["name", "total","max","mean","p50","p90","p99","samples"] + limits = {"name": 50,} + w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers} + w["name"] = max(w["name"], 35) + + line = " ".join(h.ljust(w[h]) for h in headers); sys.stdout.write("\n") + sys.stdout.write(line + "\n") + sys.stdout.write(" ".join("."*w[h] for h in headers)); sys.stdout.write("\n") + for r in rows: + name = str(r["name"]) + if len(name) > w["name"]: name = name[:w["name"]-1] + "..." 
+ vals = { + "name": r["name"], + "total": r["total"], + "max": r["max"], + "mean": r["mean"], + "p50": r["p50"], + "p90": r["p90"], + "p99": r["p99"], + "samples": r["samples"], + } + sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]])) + sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:])) + sys.stdout.write("\n") + sys.stdout.write("\n\n") + sys.stdout.flush() def _print_nodes_table(self, rows: List[Any]) -> None: headers = ["name", "role", "addr", "http", "grpc", "status", "head"] From fd13c7269dd2c79624453a340c178b9a97ce2138 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 22:35:29 -0700 Subject: [PATCH 023/172] append async symbols with 'wait' --- src/dnet/ring/shard/compute.py | 3 ++- src/dnet/ring/shard/node.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 41111154..794fc682 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -377,7 +377,8 @@ def _process_activation(self, activation_msg: ActivationMessage): except Exception: pass - with self.tracer.frame("compute.thread", "mdns.send"): + # Create and enqueue output message: either forward activations or finalize on end role + with self.tracer.frame("compute.thread", "grpc.send"): nxt = last_layer + 1 if nxt >= self.model_metadata.num_layers: # End of model try: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index a9aeba76..5f796d67 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -811,12 +811,13 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: while self.running: try: self.ingress_q.put_nowait(request) + logger.debug(f"[ENQUE] Enqueued activation request") return except asyncio.QueueFull: await asyncio.sleep(0) - # If we reached here, node is stopping; drop admission silently return + async def _ingress_worker(self): """Drains ingress queue and processes frames with heavy work offloaded. 
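
        Tracing convention in the hunks below: symbols that measure time
        spent blocked on a queue carry a ".wait" suffix ("get.wait",
        "deque.wait") so the aggregator can separate idle time from real
        work. Frames also carry request metadata; a minimal usage sketch
        with the calls as they appear in this file:

            with self.tracer.frame("grpc", "ingress") as f:
                f.set("nonce", req.nonce)
                f.event("received")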
@@ -826,9 +827,10 @@ async def _ingress_worker(self): """ while self.running: with self.tracer.frame("grpc", "ingress") as f: - with self.tracer.frame("grpc.ingress", "get"): + with self.tracer.frame("grpc.ingress", "get.wait"): try: req = await self.ingress_q.get() + logger.debug(f"[DEQUE]Dequeued activation for processing {req}") except asyncio.CancelledError: break @@ -1099,7 +1101,7 @@ def _compute_worker(self) -> None: while self.running: try: # Get activation from queue (blocks until available) - with self.tracer.frame("compute", "dequeue"): + with self.tracer.frame("compute", "deque.wait"): activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation From 1086f87d2669b817572c5aedd2c34673a74fb211 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 22:36:12 -0700 Subject: [PATCH 024/172] min bench wrapper --- src/dnet/perf/bench.py | 145 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 src/dnet/perf/bench.py diff --git a/src/dnet/perf/bench.py b/src/dnet/perf/bench.py new file mode 100644 index 00000000..ccbd040d --- /dev/null +++ b/src/dnet/perf/bench.py @@ -0,0 +1,145 @@ + +from __future__ import annotations + +import json +import os +import statistics +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, List, Optional + +from dnet.perf.trace import Tracer + + +def _percentile(xs: List[float], q: float) -> float: + if not xs: + return 0.0 + ys = sorted(xs) + k = int(round(q * (len(ys) - 1))) + k = max(0, min(k, len(ys) - 1)) + return ys[k] + + +def collect_stats(times_ms: List[float], *, bytes_total: float = 0.0, tokens_total: float = 0.0) -> Dict[str, Any]: + if not times_ms: + return { + "mean": 0.0, + "std": 0.0, + "min": 0.0, + "p50": 0.0, + "p90": 0.0, + "p99": 0.0, + "max": 0.0, + "samples": 0, + "mb_s": 0.0, + "tok_s": 0.0, + } + total_ms = sum(times_ms) + mean = total_ms / len(times_ms) + std = statistics.pstdev(times_ms) if len(times_ms) > 1 else 0.0 + total_s = max(total_ms / 1000.0, 1e-12) + return { + "mean": mean, + "std": std, + "min": min(times_ms), + "p50": _percentile(times_ms, 0.5), + "p90": _percentile(times_ms, 0.9), + "p99": _percentile(times_ms, 0.99), + "max": max(times_ms), + "samples": len(times_ms), + "mb_per_s": (bytes_total / 1_000_000.0) / total_s if bytes_total else 0.0, + "tokens_per_s": (tokens_total / total_s) if tokens_total else 0.0, + } + + +def _ensure_dir(path: str) -> None: + d = os.path.dirname(path) or "." 
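    # os.path.dirname("") is "", so the `or "."` above falls back to the
    # current directory and makedirs below never receives an empty path.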
+ os.makedirs(d, exist_ok=True) + + +@dataclass +class BenchCounters: + values: Dict[str, float] = field(default_factory=dict) + + def add_time(self, key: str, dt_ms: float) -> None: + self.values[key] = self.values.get(key, 0.0) + float(dt_ms) + + def add_bytes(self, *, direction: str, n: int) -> None: + k = "bytes_in" if direction == "in" else "bytes_out" + self.values[k] = self.values.get(k, 0.0) + float(n) + + def inc(self, key: str, delta: float = 1.0) -> None: + self.values[key] = self.values.get(key, 0.0) + float(delta) + + def snapshot(self, *, run_id: str, node: str, role: str = "shard") -> Dict[str, Any]: + snap = { + "run_id": run_id, + "node": node, + "role": role, + "counters": dict(self.values), + } + return snap + + +class TimedSpan: + __slots__ = ("_tracer", "_name", "_attrs", "_t0", "_frame", "_counters", "_counter_key") + + def __init__( + self, + tracer: Optional[Tracer], + name: str, + counters: Optional[BenchCounters] = None, + counter_key: Optional[str] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> None: + self._tracer = tracer + self._name = name + self._attrs = attrs or {} + self._t0 = 0.0 + self._frame = None + self._counters = counters + self._counter_key = counter_key + + def __enter__(self): + self._t0 = time.perf_counter() + if self._tracer is not None: + self._frame = self._tracer.frame("bench", self._name, self._attrs) + self._frame.__enter__() + return self + + def __exit__(self, ex_type, ex, tb) -> bool: + dt_ms = (time.perf_counter() - self._t0) * 1000.0 + if self._frame is not None: + try: + self._frame.__exit__(ex_type, ex, tb) + except Exception: + pass + if self._counters is not None and self._counter_key: + self._counters.add_time(self._counter_key, dt_ms) + return False + + +def aggregate_annotate( + snapshots: Iterable[Dict[str, Any]], + *, + mapping: Optional[Dict[str, str]] = None, + repeats: int = 0, +) -> List[Dict[str, Any]]: + + sums: Dict[str, float] = {} + for snap in snapshots: + ctr = snap.get("counters") if isinstance(snap, dict) else None + if not isinstance(ctr, dict): + continue + for k, v in ctr.items(): + name = mapping.get(k, k) if mapping else k + try: + sums[name] = sums.get(name, 0.0) + float(v) + except Exception: + continue + + rows = [ {"name": name, "self_ms": val, "total_ms": val, "count": repeats or 0, "max_ms": None} + for name, val in sums.items() if val > 0.0] + rows.sort(key=lambda r: r["self_ms"], reverse=True) + return rows + From 003b9f361c75a2584dffbbbff099352bfb6d8f1d Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 22:47:57 -0700 Subject: [PATCH 025/172] various --- src/dnet/perf/bench.py | 1 - src/dnet/perf/trace.py | 1 - src/dnet/ring/shard/startup.py | 1 - src/dnet/utils/logger.py | 2 +- 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/dnet/perf/bench.py b/src/dnet/perf/bench.py index ccbd040d..0cdbadd2 100644 --- a/src/dnet/perf/bench.py +++ b/src/dnet/perf/bench.py @@ -19,7 +19,6 @@ def _percentile(xs: List[float], q: float) -> float: k = max(0, min(k, len(ys) - 1)) return ys[k] - def collect_stats(times_ms: List[float], *, bytes_total: float = 0.0, tokens_total: float = 0.0) -> Dict[str, Any]: if not times_ms: return { diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index d2ed43b6..7b5d79c1 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -212,7 +212,6 @@ def _emit(self, ev: Dict[str, Any]) -> None: batch = { "run_id": (self._req_id or "NONE"), "node_id": (self.config.node_id or "NODE"), "events": list(self._events)} - 
logger.debug(batch) try: self._agg_q.put_nowait(batch) except queue.Full: diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py index f3b38fbb..20ba93ec 100644 --- a/src/dnet/ring/shard/startup.py +++ b/src/dnet/ring/shard/startup.py @@ -286,7 +286,6 @@ async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: logger.info("Updated tracer config.") self.api_address = cfg.aggregate_url self.tracer.start_aggregator() - logger.debug(cfg) return TraceConfigResponse(ok=True) except Exception as e: logger.error(f"Unable to setup tracing on shard: {e}") diff --git a/src/dnet/utils/logger.py b/src/dnet/utils/logger.py index 19f09c46..bac462f9 100644 --- a/src/dnet/utils/logger.py +++ b/src/dnet/utils/logger.py @@ -14,7 +14,7 @@ def get_logger() -> logging.Logger: Returns: Configured logger instance """ - logLevelEnv = os.getenv("DNET_LOG", "INFO").strip().upper() + logLevelEnv = os.getenv("DNET_LOG", "DEBUG").strip().upper() logLevel = logging.INFO # default if logLevelEnv in logging._nameToLevel: logLevel = logging._nameToLevel[logLevelEnv] From 71835fbc5e52e5b12f7bc242c0c55641f1514c0f Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 01:13:59 -0700 Subject: [PATCH 026/172] fix indent and other rebase errors --- src/dnet/ring/shard/comms.py | 87 +++++++---- src/dnet/ring/shard/node.py | 294 +++++++++++++++++++++-------------- 2 files changed, 237 insertions(+), 144 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index dc197637..db5817e5 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -293,8 +293,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): int(getattr(activation_msg, "token_id", -1)), rpc_ms, ) - except Exception as e: - logger.exception("Error sending token via gRPC: %s", e) + except Exception as e: + logger.exception("Error sending token via gRPC: %s", e) else: logger.error( "No valid gRPC callback for token; cannot deliver nonce=%s", @@ -322,7 +322,6 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) t_ser = time.perf_counter() t_cast = t_ser - _len_bytes = int(getattr(shaped, "nbytes", 0)) _do_compress = bool( self._compress and _len_bytes >= self._compress_min_bytes @@ -360,7 +359,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("mxarray.cast") data = tensor_to_bytes(shaped) - activation_msg.dtype = self._wire_dtype_str + activation_msg.dtype = self._wire_dtype_str nxt = activation_msg.layer_id + 1 if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): @@ -426,6 +425,38 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) ctx.disabled = True + # Prefer streaming if enabled/available; fallback to unary + stream_used = False + ctx = await self._ensure_stream(activation_msg.nonce) + if ( + ctx + and ctx.open + and not ctx.disabled + and hasattr(dnet_ring_pb2, "ActivationFrame") + ): + try: + ctx.last_seq += 1 + frame = dnet_ring_pb2.ActivationFrame( + request=request, + seq=ctx.last_seq, + end_of_request=False, + ) + await ctx.queue.put(frame) + ctx.last_activity_t = asyncio.get_running_loop().time() + stream_used = True + if self._profile: + logger.info( + "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", + activation_msg.nonce, + ctx.last_seq, + ctx.queue.qsize(), + ) + except Exception as e: + logger.warning( + "[STREAM] enqueue failed; fallback to unary: %s", e + ) + ctx.disabled = True + if not stream_used: # In fit mode, avoid long unary stalls: use 
short deadline and min retries # Streaming should be the norm; unary is a quick safety valve only. @@ -443,17 +474,17 @@ async def _send_activation(self, activation_msg: ActivationMessage): reason = "stream_closed" elif ctx.disabled: reason = "stream_disabled" - else: - reason = "enqueue_failed" - logger.warning( - "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", - self.node_id, - activation_msg.nonce, - reason, - self._mode, - ring_timeout, - ring_retries, - ) + else: + reason = "enqueue_failed" + logger.warning( + "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + self.node_id, + activation_msg.nonce, + reason, + self._mode, + ring_timeout, + ring_retries, + ) t0 = time.perf_counter() last_exc: Optional[Exception] = None for attempt in range(1, ring_retries + 1): @@ -481,19 +512,19 @@ async def _send_activation(self, activation_msg: ActivationMessage): await asyncio.sleep(min(0.25 * attempt, 1.0)) continue raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) + else: + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", + self.node_id, + activation_msg.nonce, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, + ) else: logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 5f796d67..743850b3 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -443,8 +443,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse with self.tracer.frame("network", "connect.next_node"): if self.next_node: await self._connect_next_node() - else: - logger.warning("Node %s: No next node configured", self.node_id) + else: + logger.warning("Node %s: No next node configured", self.node_id) # Warmup: compile hot path and stabilize allocators before first request with self.tracer.frame("memory", "warmup"): @@ -696,59 +696,22 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): if activation.dtype == "tokens": with self.tracer.frame("grpc.receive", "token_stream") as fr: try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, ) - shp = (int(len(tokens)),) except Exception as e: logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) - return - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, + "Decompression failed for nonce %s: %s", request.nonce, e ) return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - if self._profile: - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s 
alloc_copy_ms=%.3f (tokens)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - # Ensure dtype reflects token payload for compute path - activation_msg.dtype = "tokens" - activation_msg.shape = shp - else: - with self.tracer.frame("grpc.receive", "default") as fr: - # Safety: byte length must match shape*dtype - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize - ) - actual = len(request.activation.data) - except Exception: - pass + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) if pool_id is None: logger.warning( @@ -758,49 +721,135 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: - data = request.activation.data - input_data = np.frombuffer( - data, dtype=dtype_map[activation.dtype] - ) - buffer[: len(input_data)] = input_data + flat = deq.reshape(-1) + buffer[: flat.size] = flat alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", self.node_id, request.nonce, alloc_copy_ms, ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - - # Queue for processing — non-blocking back-off loop (cancellable) - if self._profile: - activation_msg.enq_perf_t = time.perf_counter() - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", + + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + else: + # Special token stream support: dtype='tokens' carries int32 token IDs + if activation.dtype == "tokens": + with self.tracer.frame("grpc.receive", "token_stream") as fr: + try: + tokens = np.frombuffer( + request.activation.data, dtype=np.int32 + ) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", + request.nonce, + e, + ) + return + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), + ) + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + if self._profile: + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)", + self.node_id, + request.nonce, + alloc_copy_ms, + ) + activation_msg = ActivationMessage.from_proto(request, pool_id) + # Ensure dtype reflects token payload for compute path + activation_msg.dtype = "tokens" + activation_msg.shape = shp + else: + with self.tracer.frame("grpc.receive", "default") as fr: + # Safety: byte length must match shape*dtype + try: + expected = ( + int(np.prod(activation.shape)) + * 
np.dtype(dtype_map[activation.dtype]).itemsize + ) + actual = len(request.activation.data) + except Exception: + pass + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mlx_dtype_map[activation.dtype], + shape=cast(tuple[int, ...], activation.shape), + ) + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + data = request.activation.data + input_data = np.frombuffer( + data, dtype=dtype_map[activation.dtype] + ) + buffer[: len(input_data)] = input_data + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", + self.node_id, + request.nonce, + alloc_copy_ms, + ) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + + # Queue for processing — non-blocking back-off loop (cancellable) + if self._profile: + activation_msg.enq_perf_t = time.perf_counter() + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, + ) + break + except Full: + await asyncio.sleep(0) + else: + logger.error( + "Failed to queue activation %s (node stopping)", activation_msg.nonce, ) - break - except Full: - await asyncio.sleep(0) + self.input_pool.release(pool_id) else: - logger.error( - "Failed to queue activation %s (node stopping)", - activation_msg.nonce, + # Forward to next node (not our layer) + logger.debug( + "Forwarding activation (layer %s) to next node, nonce: %s", + target_layer, + request.nonce, ) - self.input_pool.release(pool_id) - else: - # Forward to next node (not our layer) - logger.debug( - "Forwarding activation (layer %s) to next node, nonce: %s", - target_layer, - request.nonce, - ) - await self._forward_activation(request) + await self._forward_activation(request) + + except Exception as e: + logger.exception("Error receiving activation: %s", e) + async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: @@ -850,10 +899,10 @@ async def _ingress_worker(self): logger.error(f"Unable to read length of data for {req.nonce}") payload_bytes = -1 - f.set("nonce", req.nonce) - f.set("target", target_layer) - f.set("payload_bytes", payload_bytes) - f.event("received") + f.set("nonce", req.nonce) + f.set("target", target_layer) + f.set("payload_bytes", payload_bytes) + f.event("received") if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) @@ -877,41 +926,54 @@ async def _ingress_worker(self): with self.tracer.frame("grpc.ingress", "queue") as fr: while self.running: try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", + activation_msg = await loop.run_in_executor( + self.executor, + self._prepare_activation_message_blocking, + req, + ) + except Exception as e: + logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) + continue + if activation_msg is None: + continue + if self._profile: + activation_msg.recv_perf_t = t_recv + + # Enqueue for compute (cancellable back-off) + with self.tracer.frame("grpc.ingress", "queue") as fr: + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, + ) + 
break
+                            except Full:
+                                await asyncio.sleep(0)
+                        else:
+                            logger.error(
+                                "Failed to queue activation %s (node stopping)",
                                 activation_msg.nonce,
                             )
-                            break
-                        except Full:
-                            await asyncio.sleep(0)
+                            try:
+                                if self.input_pool:
+                                    # FIXME: !!!
+                                    self.input_pool.release(activation_msg.pool_id)
+                            except Exception:
+                                pass
                     else:
-                        logger.error(
-                            "Failed to queue activation %s (node stopping)",
-                            activation_msg.nonce,
-                        )
-                        try:
-                            if self.input_pool:
-                                # FIXME: !!!
-                                self.input_pool.release(activation_msg.pool_id)
-                        except Exception:
-                            pass
-                else:
-                    # Forward to next node (not our layer)
-                    logger.debug(
-                        "Forwarding activation (layer %s) to next node, nonce: %s",
-                        target_layer,
-                        req.nonce,
-                    )
-                    await self._forward_activation(req)
+                        # Forward to next node (not our layer)
+                        logger.debug(
+                            "Forwarding activation (layer %s) to next node, nonce: %s",
+                            target_layer,
+                            req.nonce,
+                        )
+                        await self._forward_activation(req)
+
+            except Exception as e:
+                logger.error("Ingress worker error: %s", e)
+

     def _get_or_make_kv(self, nonce: str) -> list:

From c04c69bb7de062739e67b46741ab7ae349f81d64 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Wed, 22 Oct 2025 01:36:29 -0700
Subject: [PATCH 027/172] add accidentally removed code

---
 src/dnet/ring/shard/node.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index 743850b3..e6fafe49 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -494,14 +494,6 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse
                 k,
             )

-        load_time_ms = (time.perf_counter() - start_time) * 1000.0
-        logger.info(
-            "Node %s: Successfully loaded model %s with layers %s in %.2fms",
-            self.node_id,
-            req.model_path,
-            req.layers,
-            load_time_ms,
-        )

         return ShardLoadModelResponse(
             success=True,

From 3a38948a4f469012c079a43debe8fa50ac9044c9 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Wed, 22 Oct 2025 02:48:51 -0700
Subject: [PATCH 028/172] fix indent for _send_activation last token

---
 src/dnet/ring/shard/comms.py | 26 ++++++++------------------
 src/dnet/ring/shard/node.py  |  7 ++-----
 2 files changed, 10 insertions(+), 23 deletions(-)

diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py
index db5817e5..bcc826c3 100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -226,6 +226,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
             )
             return
         try:
+            logger.debug("Sending activation")
             if activation_msg.is_final:
                 with self.tracer.frame("grpc", "send_activation.final") as f:
                     try:
@@ -339,10 +340,11 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                     if shaped.dtype != wire_np_dtype:
                         # FIXME: numpy vs mx array here
                         shaped = shaped.astype(wire_np_dtype, copy=False)
-                else:
-                    # MLX array -> cast to desired wire dtype
+
+                else:  # MLX array -> cast to desired wire dtype
                     if str(shaped.dtype) != self._wire_dtype_str:
                         shaped = shaped.astype(self._wire_mx_dtype)
+
                 activation_msg.dtype = self._wire_dtype_str

                 t_cast = time.perf_counter()
@@ -364,6 +366,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
             nxt = activation_msg.layer_id + 1
             if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set):
                 if self.next_node_stub:
+
                     with self.tracer.frame("grpc", "send_activation.next") as f:
                         request = activation_msg.to_proto(data)
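+                        # (the wire timestamp stamped on the next line is what the
+                        # receiver later subtracts from its arrival time to estimate
+                        # in-flight latency; see admit_frame / _rx_inflight_t)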
request.timestamp = utc_epoch_now() @@ -428,12 +431,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): # Prefer streaming if enabled/available; fallback to unary stream_used = False ctx = await self._ensure_stream(activation_msg.nonce) - if ( - ctx - and ctx.open - and not ctx.disabled - and hasattr(dnet_ring_pb2, "ActivationFrame") - ): + if (ctx and ctx.open and not ctx.disabled and hasattr(dnet_ring_pb2, "ActivationFrame")): + logger.debug(f"Sending activation with stream") try: ctx.last_seq += 1 frame = dnet_ring_pb2.ActivationFrame( @@ -444,17 +443,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): await ctx.queue.put(frame) ctx.last_activity_t = asyncio.get_running_loop().time() stream_used = True - if self._profile: - logger.info( - "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", - activation_msg.nonce, - ctx.last_seq, - ctx.queue.qsize(), - ) except Exception as e: - logger.warning( - "[STREAM] enqueue failed; fallback to unary: %s", e - ) + logger.warning("[STREAM] enqueue failed; fallback to unary: %s", e) ctx.disabled = True if not stream_used: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index e6fafe49..b9436158 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -845,10 +845,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: - """ - Lightweight admission for streaming: - enqueue protobuf frame to ingress queue, then return. - """ + """enqueue protobuf frame to ingress queue""" while self.running: try: self.ingress_q.put_nowait(request) @@ -871,7 +868,7 @@ async def _ingress_worker(self): with self.tracer.frame("grpc.ingress", "get.wait"): try: req = await self.ingress_q.get() - logger.debug(f"[DEQUE]Dequeued activation for processing {req}") + logger.debug(f"[DEQUE]Dequeued activation for processing") except asyncio.CancelledError: break From 7e342aed696bbb541c4802dd3fc813c4bdc703e7 Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 03:34:37 -0700 Subject: [PATCH 029/172] remove old startup file and add tracer endpoints in node.py:_setup_routes --- src/dnet/ring/shard/node.py | 32 +- src/dnet/ring/shard/startup.py | 574 --------------------------------- 2 files changed, 31 insertions(+), 575 deletions(-) delete mode 100644 src/dnet/ring/shard/startup.py diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index b9436158..3ed35f1e 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -32,6 +32,8 @@ ShardProfileRequest, ShardProfileResponse, ShardUnloadModelResponse, + TraceConfigRequest, + TraceConfigResponse, ) from ..model.base import BaseRingModel @@ -210,7 +212,7 @@ def __init__( enabled = True, record_pid_tid = True, aggregate=False, - aggregate_url=None, # FIXME: This is set when we get a /profile req + aggregate_url=None, ) self.tracer = Tracer(cfg) self.tracer.start() @@ -1533,6 +1535,34 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse: logger.error(f"Error in /profile endpoint: {e}") raise + + @self.app.post("/trace") + async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: + logger.debug("Updating trace config") + try: + cfg = TraceConfig( + file=req.file, + streaming=req.streaming, + include_prefixes=req.include_prefixes, + include_c_calls=req.include_c_calls, + budget=req.budget, + enabled=req.enabled, + node_id=req.node_id, + record_pid_tid=req.record_pid_tid, + 
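+                    # (aggregate / aggregate_url control trace push-back: when set,
+                    # this shard streams trace batches to the API node's aggregator,
+                    # started via start_aggregator below)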
aggregate=req.aggregate, + aggregate_url=req.aggregate_url, + agg_max_events=req.agg_max_events, + ) + self.tracer.config = cfg + logger.info("Updated tracer config.") + self.api_address = cfg.aggregate_url + self.tracer.start_aggregator() + return TraceConfigResponse(ok=True) + except Exception as e: + logger.error(f"Unable to setup tracing on shard: {e}") + return TraceConfigResponse(ok=False) + + @self.app.post("/load_model") async def load_model_endpoint( req: ShardLoadModelRequest, diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py deleted file mode 100644 index 20ba93ec..00000000 --- a/src/dnet/ring/shard/startup.py +++ /dev/null @@ -1,574 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from typing import Any, Dict, List, Mapping -import threading -from socket import gethostname -from secrets import token_hex - -import mlx.core as mx -from fastapi import Request -from fastapi.responses import JSONResponse -from grpc import aio as aio_grpc - -from hypercorn import Config -import hypercorn.asyncio as aio_hypercorn -from dnet_p2p.thunderbolt import ThunderboltConnection -from dnet_p2p import ( - DnetDeviceProperties, - discover_thunderbolt_connection, -) - -from dnet.perf.trace import TraceConfig - -from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server -from .servicer import ShardServicer -from ...utils.logger import logger -from ...utils.serialization import tensor_to_bytes -from ...utils.latency import ( - DeviceLatencyResult, - LatencyMeasurement, - LatencyResults, - calculate_median_latency_seconds, -) -from .models import ( - HealthResponse, - ShardLoadModelRequest, - ShardLoadModelResponse, - ShardProfileRequest, - ShardProfileResponse, - ShardUnloadModelResponse, - TraceConfigRequest, - TraceConfigResponse, -) -from ...protos import dnet_ring_pb2 - - -class StartupMixin: - async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): - self.running = True - - with self.tracer.frame("startup", "workers"): - self.background_tasks = [ - asyncio.create_task(self._ingress_worker()), - asyncio.create_task(self._prefetch_worker()), - asyncio.create_task(self._send_worker()) ] - - try: # Start idle sweeper to close silent streams - if getattr(self, "_streaming_enabled", False) and hasattr(self, "_stream_sweeper"): - self.background_tasks.append( asyncio.create_task(self._stream_sweeper())) - except Exception: - pass - - with self.tracer.frame("startup", "compute"): - self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) - self.compute_thread.start() - - with self.tracer.frame("startup", "discovery"): - self._start_discovery() - - logger.info( - "Shard node %s started on gRPC port %s HTTP port %s", - self.node_id, - self.grpc_port, - self.http_port, - ) - - def _start_discovery(self) -> None: - """Start mDNS discovery service.""" - hostname = gethostname() - # TODO: optionally take shard name from CLI - instance = f"shard-{token_hex(4)}-{hostname}" - self.discovery.create_instance( - instance, - hostname, - "0.0.0.0", # Binds to all addresses - self.http_port, # HTTP port - self.grpc_port, # gRPC port - is_manager=False, # Shard is never a manager - ) - self.discovery.start() - logger.info( - "Discovery service started for shard node %s with name %s", - self.node_id, - self.discovery.fullname(), - ) - - async def _start_grpc_server(self) -> None: - """Start gRPC server.""" - self.server = aio_grpc.server() - - # Add the ring servicer; shard acts as client for 
ShardApiService (to API) - servicer = ShardServicer(self) # type: ignore # FIXME: !!! - add_DnetRingServiceServicer_to_server(servicer, self.server) - - listen_addr = f"[::]:{self.grpc_port}" - self.server.add_insecure_port(listen_addr) - await self.server.start() - logger.info( - "Shard node %s gRPC server started on %s", self.node_id, listen_addr - ) - try: - await asyncio.get_running_loop().run_in_executor( - self.executor, self._warmup_serialization - ) - logger.info("Warmup serialization completed") - except Exception as e: - logger.warning("Warmup serialization failed: %s", e) - - def _warmup_serialization(self): - try: - dummy = mx.random.normal((1024, 1024), dtype=mx.float32) - dummy16 = dummy.astype(self._wire_mx_dtype) - _ = tensor_to_bytes(dummy16) - except Exception: - pass - - def _warmup_shard(self): - logger.info( - "[WARMUP] Starting shard warmup with window size %s", self.window_size - ) - batch_size, seq_len = 1, 1 - hidden_size = self.model_metadata.model_config.get("hidden_size", 2560) - x = mx.zeros((batch_size, seq_len, hidden_size), dtype=mx.bfloat16) - start_time = time.perf_counter() - try: - default_n = max(1, int(getattr(self, "_resident_windows", 1))) - except Exception: - default_n = 1 - try: - max_windows = max( - 1, - int( - getattr(self, "config", None).warmup_windows - if getattr(self, "config", None) - else default_n - ), - ) - except Exception: - max_windows = default_n - windows: list[list[int]] = [] - for window_start in range(0, len(self._assigned_sorted), self.window_size): - window_end = min( - window_start + self.window_size, len(self._assigned_sorted) - ) - windows.append(self._assigned_sorted[window_start:window_end]) - for wi, window_layers in enumerate(windows[:max_windows]): - weights_to_bind = {} - for layer_id in window_layers: - weights = self.weight_cache.get_weight(layer_id) - if weights: - for k, v in weights.items(): - weights_to_bind[k] = v - if weights_to_bind: - self.model.load_weights(list(weights_to_bind.items()), strict=False) - try: - for layer_id in window_layers: - x = self.model.apply_single_layer(layer_id, x, cache=None) - _s = mx.sum(x) - mx.eval(_s) - except Exception: - pass - try: - for lid in window_layers: - self.weight_cache.decrease_reference(lid) - except Exception: - pass - if not self._warmup_keep_flag: - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] - except Exception: - pass - try: - self.weight_cache.evict_layers(window_layers) - except Exception: - pass - total_time = (time.perf_counter() - start_time) * 1000 - self._warmup_completed = True - logger.info( - "[WARMUP] Shard warmup completed in %.2fms; windows=%s kept=%s", - total_time, - min(len(windows), max_windows), - int(self._warmup_keep_flag), - ) - - async def _start_http_server(self, shutdown_trigger: Any) -> None: - """Start HTTP server. 
- - Args: - shutdown_trigger: Shutdown trigger function - """ - await self._setup_routes() - - # Start HTTP server in background - config = Config.from_mapping( - bind=f"0.0.0.0:{self.http_port}", - log_level="info", - log_config=None, - use_reloader=False, - h2c=False, - ) - - # Start the server as a background task - self.http_server = asyncio.create_task( - aio_hypercorn.serve(self.app, config, shutdown_trigger=shutdown_trigger) # type: ignore - ) - logger.info( - "Shard node %s HTTP server started on port %s", self.node_id, self.http_port - ) - - async def _setup_routes(self) -> None: - """Setup HTTP routes.""" - - @self.app.get("/health") - async def health() -> HealthResponse: - try: - instance = self.discovery.instance_name() - except Exception: - instance = None - return HealthResponse( - status="ok", - node_id=self.node_id, - running=self.running, - model_loaded=self._check_model_loaded(), - model_path=self.model_path, - assigned_layers=self.assigned_layers, - queue_size=self.activation_recv_queue.qsize(), - grpc_port=self.grpc_port, - http_port=self.http_port, - instance=instance, - ) - - @self.app.post("/profile") - async def profile(req: ShardProfileRequest) -> ShardProfileResponse: - try: - latency_results = await self._measure_latency_to_devices( req.devices, req.thunderbolts, req.payload_sizes) - device_profile = await self._profile_device( req.repo_id, req.max_batch_exp) - - # Overwrite `t_comm` with median latency (subprocess returns a dict) - median_latency = calculate_median_latency_seconds(latency_results) - if median_latency is not None: - device_profile["t_comm"] = float(median_latency) - logger.info( - f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" - ) - else: - logger.warning( "No valid latency measurements, keeping default t_comm") - - # Return the dict payload directly - return ShardProfileResponse( - profile=device_profile, - latency=latency_results, - ) - except Exception as e: - logger.error(f"Error in /profile endpoint: {e}") - raise - - @self.app.post("/trace") - async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: - try: - cfg = TraceConfig( - file=req.file, - streaming=req.streaming, - include_prefixes=req.include_prefixes, - include_c_calls=req.include_c_calls, - budget=req.budget, - enabled=req.enabled, - node_id=req.node_id, - record_pid_tid=req.record_pid_tid, - aggregate=req.aggregate, - aggregate_url=req.aggregate_url, - agg_max_events=req.agg_max_events, - ) - self.tracer.config = cfg - logger.info("Updated tracer config.") - self.api_address = cfg.aggregate_url - self.tracer.start_aggregator() - return TraceConfigResponse(ok=True) - except Exception as e: - logger.error(f"Unable to setup tracing on shard: {e}") - return TraceConfigResponse(ok=False) - - @self.app.post("/load_model") - async def load_model_endpoint( - req: ShardLoadModelRequest, - ) -> ShardLoadModelResponse: - """Load model with specified layers.""" - try: - logger.info( - f"HTTP /load_model: model={req.model_path}, layers={req.layers}, " - f"next_node={req.next_node or 'none'}, window_size={req.window_size}, " - f"total_layers={req.total_layers}, api_callback={req.api_callback_address or 'none'}" - ) - result = await self.load_model(req) - return result - - except Exception as e: - logger.error(f"Error in /load_model endpoint: {e}") - return ShardLoadModelResponse( - success=False, - message=f"Error: {str(e)}", - layers_loaded=[], - load_time_ms=0.0, - ) - - @self.app.post("/unload_model") - async def unload_model_endpoint() -> 
ShardUnloadModelResponse: - """Unload current model.""" - try: - logger.info("HTTP /unload_model") - result = await self.unload_model() - return result - - except Exception as e: - logger.error(f"Error in /unload_model endpoint: {e}") - return ShardUnloadModelResponse( - success=False, - message=f"Error: {str(e)}", - ) - - @self.app.post("/warm") - async def warm(request: Request) -> JSONResponse: - try: - body = await request.json() - start = int(body.get("start", -1)) - window = int(body.get("window", self.window_size)) - if start < 0: - return JSONResponse( - status_code=400, content={"error": "missing/invalid start"} - ) - start_idx = 0 - for i, lyr in enumerate(self._assigned_sorted): - if lyr >= start: - start_idx = i - break - else: - return JSONResponse(content={"prefetched": []}) - window_layers = self._assigned_sorted[ - start_idx : start_idx + max(1, window) - ] - for wl in window_layers: - self._prefetch_to_ram(wl) - self._enqueue_weight_prefetch(wl) - return JSONResponse(content={"prefetched": window_layers}) - except Exception as e: - logger.error("/warm failed: %s", e) - return JSONResponse(status_code=500, content={"error": str(e)}) - - async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: - """Profile device using dperf in a subprocess and return a dict. - - Args: - repo_id: Hugging Face repository ID - max_batch_exp: Maximum batch size exponent (2^max_batch_exp) - - Returns: - Device profile information as a plain dict - """ - from ...utils.profile_subproc import profile_device_via_subprocess - - profile_dict = profile_device_via_subprocess( - repo_id, max_batch_exp=max_batch_exp, debug=0 - ) - logger.info("Device profiling completed for node %s", self.node_id) - return profile_dict - - async def _connect_next_node(self) -> bool: - """Connect to next node in ring. 
- - Returns: - True if connected or no next node, False on failure - """ - if not self.next_node: - logger.info(f"Shard node {self.node_id} is the final shard (no next node)") - return True - - if self.next_node_channel: - logger.debug(f"Shard node {self.node_id} already connected to next node.") - return True - - try: - # use thunderbolt here if available - this_properties = self.discovery.get_own_properties() - thunderbolt_conn = discover_thunderbolt_connection( - this_properties, - self.next_node, - ) - next_ip = ( - thunderbolt_conn.ip_addr - if thunderbolt_conn - else self.next_node.local_ip - ) - address = f"{next_ip}:{self.next_node.shard_port}" - logger.info( - f"Shard node {this_properties.instance} connecting to next node {self.next_node.instance} at {address}" - ) - - self.next_node_channel = aio_grpc.insecure_channel(address) - from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub - - self.next_node_stub = DnetRingServiceStub(self.next_node_channel) - return True - except Exception as e: - logger.warning( - f"Shard node {self.node_id} failed to connect to next node {address}: {e}" - ) - self.next_node_channel = None - self.next_node_stub = None - return False - - async def _reconnect_next_node(self) -> bool: - try: - if self.next_node_channel: - await self.next_node_channel.close() - except Exception: - pass - self.next_node_channel = None - self.next_node_stub = None - return await self._connect_next_node() - - async def _health_check(self): - try: - health_request = dnet_ring_pb2.HealthRequest(requester_id=str(self.node_id)) - response = await self.next_node_stub.HealthCheck(health_request) # type: ignore - logger.info( - "Shard node %s successfully pinged: %s, healthy: %s", - self.node_id, - response.node_id, - response.healthy, - ) - return True - except Exception as e: - logger.warning( - "Shard node %s failed to ping next node %s: %s", - self.node_id, - self.next_node_address, - e, - ) - return False - - async def _measure_latency_to_devices( - self, - devices: Mapping[str, DnetDeviceProperties], - thunderbolts: Mapping[str, ThunderboltConnection], - payload_sizes: List[int], - ) -> LatencyResults: - """Measure latency to all devices except self. 
- - Args: - devices: Device information mapping - thunderbolts: Thunderbolt connection information - payload_sizes: List of payload sizes to test - - Returns: - Latency measurement results - """ - latency_results_dict: Dict[str, DeviceLatencyResult] = {} - - for service_name, device_info in devices.items(): - # Skip measuring latency to ourselves - if service_name.startswith(self.discovery.instance_name()): - logger.debug("Skipping latency measurement to self: %s", service_name) - continue - - # Skip measuring latency to API (manager) devices - if device_info.is_manager: - logger.debug( - "Skipping latency measurement to manager/API: %s", service_name - ) - continue - - try: - shard_port = device_info.shard_port - - # Check for Thunderbolt connection - if service_name in thunderbolts: - tb_data = thunderbolts[service_name] - service_ip = tb_data.ip_addr - logger.info( - "Using Thunderbolt for %s at %s, connected to instance %s", - service_name, - service_ip, - tb_data.instance, - ) - else: - # No Thunderbolt, use WiFi - service_ip = device_info.local_ip - - if not shard_port or not service_ip: - logger.warning( - "No shard_port or local_ip for device %s", service_name - ) - continue - - # Connect to target shard's gRPC server - target_address = f"{service_ip}:{shard_port}" - channel = aio_grpc.insecure_channel(target_address) - from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub - - stub = DnetRingServiceStub(channel) - - # Measure latency for each payload size - latency_measurements: List[LatencyMeasurement] = [] - for payload_size in payload_sizes: - # Create dummy payload - dummy_data = b"x" * payload_size - - start_time = time.perf_counter() - timestamp_ms = int(time.time() * 1000) - - request = dnet_ring_pb2.LatencyMeasureRequest( - requester_id=str(self.node_id), - payload_size=payload_size, - dummy_data=dummy_data, - timestamp=timestamp_ms, - ) - - response = await stub.MeasureLatency(request) # type: ignore - end_time = time.perf_counter() - - if response.success: - latency_ms = (end_time - start_time) * 1000 - latency_measurements.append( - LatencyMeasurement( - payload_size=payload_size, - latency_ms=round(latency_ms, 2), - success=True, - error=None, - ) - ) - else: - latency_measurements.append( - LatencyMeasurement( - payload_size=payload_size, - success=False, - error=response.message, - latency_ms=0, - ) - ) - - # Store results - result = DeviceLatencyResult( - target_node_id=response.node_id if response.success else None, - measurements=latency_measurements, - success=True, - error=None, - ) - latency_results_dict[service_name] = result - - # Close channel - await channel.close() - - except Exception as e: - logger.error("Error measuring latency to %s: %s", service_name, e) - result = DeviceLatencyResult( - target_node_id=None, - success=False, - error=str(e), - measurements=[], - ) - latency_results_dict[service_name] = result - - return LatencyResults(results=latency_results_dict) From 44d929d9fe35fa5e130b6a30f2905571b11fa7fe Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 22:52:27 -0700 Subject: [PATCH 030/172] runtime stats high-level frames --- src/dnet/ring/shard/node.py | 74 +++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 3ed35f1e..d604087d 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -375,6 +375,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse 
resident_windows=self._resident_windows, use_mxload_fastpath=self.config.mxload_fastpath, prefetch_mode=self.config.prefetch_mode, + tracer=self.tracer, ) # Load the model @@ -1158,7 +1159,7 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - with self.tracer.frame("compute", "forward"): + with self.tracer.frame("compute", "forward"): # NOTE: Symbol hardcoded for runtime stats self._process_activation(activation_msg) except Empty: @@ -1301,9 +1302,7 @@ def _warmup_serialization(self): pass def _warmup_shard(self): - logger.info( - "[WARMUP] Starting shard warmup with window size %s", self.window_size - ) + logger.info("[WARMUP] Starting shard warmup with window size %s", self.window_size) if not self.model or not self.model_metadata or not self.weight_cache: logger.warning("[WARMUP] No model loaded; skipping warmup") return @@ -1325,10 +1324,9 @@ def _warmup_shard(self): max_windows = max(1, self.config.warmup_windows) windows: list[list[int]] = [] for window_start in range(0, len(self._assigned_sorted), self.window_size): - window_end = min( - window_start + self.window_size, len(self._assigned_sorted) - ) + window_end = min(window_start + self.window_size, len(self._assigned_sorted)) windows.append(self._assigned_sorted[window_start:window_end]) + for wi, window_layers in enumerate(windows[:max_windows]): weights_to_bind = {} for layer_id in window_layers: @@ -1336,6 +1334,7 @@ def _warmup_shard(self): if weights: for k, v in weights.items(): weights_to_bind[k] = v + if weights_to_bind: # Serialize MLX parameter binding with self._mlx_lock: @@ -1349,6 +1348,7 @@ def _warmup_shard(self): mx.eval(_s) except Exception: pass + try: for lid in window_layers: self.weight_cache.decrease_reference(lid) @@ -1575,7 +1575,8 @@ async def load_model_endpoint( f"total_layers={req.total_layers}, kv_bits={req.kv_bits or 'default'}, " f"api_callback={req.api_callback_address or 'none'}" ) - result = await self.load_model(req) + with self.tracer.frame("memory", "model.load"): # NOTE: Symbol hardcoded for runtime stats + result = await self.load_model(req) return result except Exception as e: @@ -1592,7 +1593,8 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: """Unload current model.""" try: logger.info("HTTP /unload_model") - result = await self.unload_model() + with self.tracer.frame("memory", "model.unload"): # NOTE: Symbol hardcoded for runtime stats + result = await self.unload_model() return result except Exception as e: @@ -1606,29 +1608,30 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: # FIXME: add pydantic type here async def warm(request: Request) -> JSONResponse: try: - body = await request.json() - start = int(body.get("start", -1)) - window = int(body.get("window", self.window_size)) - if start < 0: - return JSONResponse( - status_code=400, content={"error": "missing/invalid start"} - ) - start_idx = 0 - for i, lyr in enumerate(self._assigned_sorted): - if lyr >= start: - start_idx = i - break - else: - return JSONResponse(content={"prefetched": []}) - window_layers = self._assigned_sorted[ - start_idx : start_idx + max(1, window) - ] - for wl in window_layers: - # Prefetch disabled in fit mode; allow only when non-fit and enabled - if self._mode != "fit" and self.config.prefetch_mode != "off": - self._prefetch_to_ram(wl) - self._enqueue_weight_prefetch(wl) - return JSONResponse(content={"prefetched": window_layers}) + with self.tracer.frame("memory", "model.warm"): # NOTE: Symbol 
hardcoded for runtime stats + body = await request.json() + start = int(body.get("start", -1)) + window = int(body.get("window", self.window_size)) + if start < 0: + return JSONResponse( + status_code=400, content={"error": "missing/invalid start"} + ) + start_idx = 0 + for i, lyr in enumerate(self._assigned_sorted): + if lyr >= start: + start_idx = i + break + else: + return JSONResponse(content={"prefetched": []}) + window_layers = self._assigned_sorted[ + start_idx : start_idx + max(1, window) + ] + for wl in window_layers: + # Prefetch disabled in fit mode; allow only when non-fit and enabled + if self._mode != "fit" and self.config.prefetch_mode != "off": + self._prefetch_to_ram(wl) + self._enqueue_weight_prefetch(wl) + return JSONResponse(content={"prefetched": window_layers}) except Exception as e: logger.error("/warm failed: %s", e) return JSONResponse(status_code=500, content={"error": str(e)}) @@ -1670,9 +1673,10 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: Returns: Device profile information as a plain dict """ - profile_dict = profile_device_via_subprocess( - repo_id, max_batch_exp=max_batch_exp, debug=0 - ) + with self.tracer.frame("startup", "profile.device"): # NOTE: Symbol hardcoded for runtime stats + profile_dict = profile_device_via_subprocess( + repo_id, max_batch_exp=max_batch_exp, debug=0 + ) logger.info("Device profiling completed for node %s", self.node_id) return profile_dict From 12bd2271ae2977a418a52b3f872d25b27a19e03f Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 22:52:54 -0700 Subject: [PATCH 031/172] runtime stats aggregator --- src/dnet/perf/utils/aggregator.py | 64 +++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index 010e6922..34ba7fa8 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -7,6 +7,7 @@ from collections import defaultdict, deque from dnet.utils.logger import logger +from dnet.ring.common import LayerAssignment, TopologyInformation Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @@ -205,3 +206,66 @@ def q(p: float) -> float: def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: # Call-tree storage is disabled to reduce memory; keep API for compatibility. 
return [] + + +# Runtime statistics +# Use a RunAggregator to get raw frames per request, then transform into _RuntimeStats + +# Track a single request, use multiple for a full benchmark +@dataclass +class _RuntimeStats: + model: str # Model name + tokenizer: str # Tokenizer name + run_id: str # ID of request serviced (for later mapping) + ttft: Dict[str, float] # Time to first token, map: p50 : 0.0 (ms) + itl: Dict[str, float] # Inter-token latency, mapL p50 : 0.0 (ms) + requests: int # Number of requests serviced + failed: int # Number of failed requests + prompt_tokens: int # Number of prompt tokens per request (req_id: #) + generated_tokens: int # Number of generated tokens per request (req_id: #) + + latencys: Dict[List[str, str, str], int] # Map of latencys: [node0, node1, p50]: 0.0 + latency_per_layer: Dict[int, float] # Map of {layer: 0.0} + latency_per_shard: Dict[str, float] # Map of {shard: 0.0} + total_latency: int # Total runtime of requests + throughput: float # aaa + + topo: TopologyInfo = None # Topology information for this request (keep here since it might change) + assignment: LayerAssignment = None # Map of layer to shard IDs + startup_t: float # Time to start shard (ms) + layer_assignment_t: float # Time to layer assignment (ms) + + +# NOTE: Hardcodes some high-level trace frame symbols +def to_runstats(agg: RunAggregator): + pass + +# Process stats + handle per-request data +class StatsAggregator: + def __init__(self) -> None: + self._req: Dict[str, _RuntimeStats] = {} # Map req_id : RuntimeStats obj + self._lock = threading.Lock() + + # Ingest raw data from tracer + def add(self, run: _RuntimeStats) -> bool: + run_id = run.get("run_id") + + # Return data for total, per req, worker or model (maybe add per layer too?) + def stats( + self, + req_id: Optional[str], + worker: Optional[str], + model: Optional[str] + ): + + if req_id: + pass + + elif worker: + pass + + elif model: + pass + + else: # Return stats of all counters + From 69d3fe82f0f9567b6361a193622ae7d9d0279f1c Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 00:24:40 -0700 Subject: [PATCH 032/172] track per-nonce in-flight and in-wait times and append to ingress trace frame --- src/dnet/ring/shard/node.py | 38 +++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index d604087d..9a5e142d 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -217,6 +217,10 @@ def __init__( self.tracer = Tracer(cfg) self.tracer.start() + # Get in-flight and in-wait time per request + self._rx_ingress_t: Dict[str, float] = {} # Mapping of nonce -> perf_counter() + self._rx_inflight_t: Dict[str, float] = {} # Track per-request inflight + # Per-nonce KV caches (concurrent requests) self._kv_by_nonce: Dict[str, list] = {} self._kv_last_seen: Dict[str, float] = {} @@ -851,6 +855,10 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: """enqueue protobuf frame to ingress queue""" while self.running: try: + rx_t = time.perf_counter() + self._rx_ingress_t[request.nonce] = rx_t + self._rx_inflight_t[request.nonce] = rx_t - request.timestamp + self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") return @@ -867,17 +875,22 @@ async def _ingress_worker(self): finally enqueues for compute or forwards to the next shard. 
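+        Timing attached to each ingress frame: "inflight" is the arrival
+        perf-time minus the sender's wire timestamp, "inwait" is the time the
+        frame spent sitting in ingress_q before being dequeued here.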
""" while self.running: - with self.tracer.frame("grpc", "ingress") as f: - with self.tracer.frame("grpc.ingress", "get.wait"): - try: - req = await self.ingress_q.get() - logger.debug(f"[DEQUE]Dequeued activation for processing") - except asyncio.CancelledError: - break + with self.tracer.frame("network.ingress", "wait"): # NOTE: bad counter + try: + req = await self.ingress_q.get() + logger.debug(f"[DEQUE]Dequeued activation for processing") + except asyncio.CancelledError: + break + + # Trace processing of request, in-flight and in-wait times + with self.tracer.frame("network.ingress", "process") as f: + f.set("inflight", self._rx_inflight_t[req.nonce]) + f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) + f.set("nonce", req.nonce) try: - with self.tracer.frame("grpc.ingress", "connect_next_node"): - await self._connect_next_node() + #with self.tracer.frame("grpc.ingress", "connect_next_node"): + await self._connect_next_node() activation = req.activation target_layer = activation.layer_id + 1 @@ -1024,7 +1037,7 @@ def _prepare_activation_message_blocking( activation = request.activation if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - with self.tracer.frame("grpc.ingress.prepare_activation", "decompress") as f: + with self.tracer.frame("network.ingress.prepare_activation", "decompress") as f: try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1060,7 +1073,7 @@ def _prepare_activation_message_blocking( return activation_msg elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them - with self.tracer.frame("grpc.ingress.prepare_activation", "tokens") as f: + with self.tracer.frame("network.ingress.prepare_activation", "tokens") as f: try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1089,7 +1102,7 @@ def _prepare_activation_message_blocking( return activation_msg else: # Dense path: validate size and copy raw bytes view into pool buffer - with self.tracer.frame("grpc.ingress.prepare_activation", "default") as f: + with self.tracer.frame("network.ingress.prepare_activation", "default") as f: try: expected = ( int(np.prod(activation.shape)) @@ -1575,6 +1588,7 @@ async def load_model_endpoint( f"total_layers={req.total_layers}, kv_bits={req.kv_bits or 'default'}, " f"api_callback={req.api_callback_address or 'none'}" ) + self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name with self.tracer.frame("memory", "model.load"): # NOTE: Symbol hardcoded for runtime stats result = await self.load_model(req) return result From f4d957c605276f0db0338296a88799a24f1b1958 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 00:30:44 -0700 Subject: [PATCH 033/172] stop tracking bytes and target, change 'grpc' to 'network' for cleaner frame tagging --- src/dnet/ring/shard/node.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 9a5e142d..f2933e21 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -850,7 +850,6 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): logger.exception("Error receiving activation: %s", e) - async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: """enqueue protobuf frame to ingress queue""" while self.running: @@ -883,13 +882,13 @@ async def _ingress_worker(self): break # Trace processing of 
request, in-flight and in-wait times - with self.tracer.frame("network.ingress", "process") as f: + with self.tracer.frame("network", "ingress") as f: f.set("inflight", self._rx_inflight_t[req.nonce]) f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) f.set("nonce", req.nonce) try: - #with self.tracer.frame("grpc.ingress", "connect_next_node"): + #with self.tracer.frame("network.ingress", "connect_next_node"): await self._connect_next_node() activation = req.activation @@ -945,7 +944,7 @@ async def _ingress_worker(self): activation_msg.recv_perf_t = t_recv # Enqueue for compute (cancellable back-off) - with self.tracer.frame("grpc.ingress", "queue") as fr: + with self.tracer.frame("network.ingress", "enque") as fr: while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) @@ -957,18 +956,15 @@ async def _ingress_worker(self): except Full: await asyncio.sleep(0) else: - logger.error( - "Failed to queue activation %s (node stopping)", - activation_msg.nonce, - ) + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) try: if self.input_pool: # FIXME: !!! self.input_pool.release(activation_msg.pool_id) except Exception: pass - else: - # Forward to next node (not our layer) + + else: # Forward to next node (not our layer) logger.debug( "Forwarding activation (layer %s) to next node, nonce: %s", target_layer, From 2ec80fc3357cc518f216ca4d13b8d53b7b8fc86d Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 00:31:22 -0700 Subject: [PATCH 034/172] remove profiling logs --- src/dnet/ring/shard/compute.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 794fc682..fce05d21 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -65,16 +65,13 @@ def _delta_swap_eviction( return len(evicted) def _process_activation(self, activation_msg: ActivationMessage): - if ( - not self.model + if ( not self.model or not self.model_metadata or not self.weight_cache or not self.input_pool or not self.output_pool ): - logger.error( - "Node %s: Cannot process activation - model not loaded", self.node_id - ) + logger.error("Node %s: Cannot process activation - model not loaded", self.node_id) return try: @@ -172,7 +169,6 @@ def _process_activation(self, activation_msg: ActivationMessage): ) to_bind: Dict[str, mx.array] = {} if not skip_scan: - t_w_ready = time.perf_counter() for wl in window_layers: weights = self.weight_cache.get_weight(wl) if weights is None: @@ -267,7 +263,8 @@ def _process_activation(self, activation_msg: ActivationMessage): """ for lid in window_layers: - self.weight_cache.decrease_reference(lid) + #self.weight_cache.decrease_reference(lid) + pass with self.tracer.frame("compute.thread", "execute.evict_and_unload"): try: From 4b903819b3a7a15e9f661ea1ba15c93ce947348e Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 01:06:22 -0700 Subject: [PATCH 035/172] aggregate per-nonce --- src/dnet/perf/utils/aggregator.py | 87 ++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 19 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index 34ba7fa8..b9cd78c5 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -7,7 +7,8 @@ from collections import defaultdict, deque from dnet.utils.logger import logger -from dnet.ring.common import LayerAssignment, TopologyInformation +from dnet.ring import 
LayerAssignment, TopologyInfo +from dnet.perf import _Frame Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @@ -119,22 +120,13 @@ def __init__(self) -> None: def enqueue(self, batch: Dict[str, Any]) -> None: run_id = batch.get("run_id") node_id = batch.get("node_id") - logger.debug(f"Enquing trace buffer from {run_id}, {node_id}") - if not run_id or not node_id: - return events = batch.get("events") or [] - batch_seq = int(batch.get("batch_seq") or 0) + logger.debug(f"Enquing trace buffer from {run_id}, {node_id}") + if not run_id or not node_id: return # Drop batch with self._lock: agg = self._req.setdefault(run_id, RunAggregator()) - last = agg.last_batch_seq.get(node_id) - if (last is not None) and (batch_seq != last + 1): - agg.drops += abs(batch_seq - (last + 1)) - agg.last_batch_seq[node_id] = batch_seq for ev in events: - try: - agg.ingest_event(node_id, ev) - except Exception: - continue + agg.ingest_event(node_id, ev) def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]: with self._lock: @@ -216,7 +208,8 @@ def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: class _RuntimeStats: model: str # Model name tokenizer: str # Tokenizer name - run_id: str # ID of request serviced (for later mapping) + run_id: str # ID of session (for later mapping) + nonce: List[str] # List of serviced requests ttft: Dict[str, float] # Time to first token, map: p50 : 0.0 (ms) itl: Dict[str, float] # Inter-token latency, mapL p50 : 0.0 (ms) requests: int # Number of requests serviced @@ -229,11 +222,11 @@ class _RuntimeStats: latency_per_shard: Dict[str, float] # Map of {shard: 0.0} total_latency: int # Total runtime of requests throughput: float # aaa + startup_t: float # Time to start shard (ms) + layer_assignment_t: float # Time to layer assignment (ms) topo: TopologyInfo = None # Topology information for this request (keep here since it might change) assignment: LayerAssignment = None # Map of layer to shard IDs - startup_t: float # Time to start shard (ms) - layer_assignment_t: float # Time to layer assignment (ms) # NOTE: Hardcodes some high-level trace frame symbols @@ -243,12 +236,67 @@ def to_runstats(agg: RunAggregator): # Process stats + handle per-request data class StatsAggregator: def __init__(self) -> None: - self._req: Dict[str, _RuntimeStats] = {} # Map req_id : RuntimeStats obj self._lock = threading.Lock() + self._max_resident_rq = 50 # per node FIXME: modify from repl + self._workers: Dict[str, Dict[str, Dict[str, _Frame]]] = {} # Store frames per nonce, per node_id + + self._nonces = [] # Tracked nonces (either in-flight or done) + self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds + self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill + self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames + self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames + # Ingest raw data from tracer - def add(self, run: _RuntimeStats) -> bool: - run_id = run.get("run_id") + def add(self, data: Dict[str, Any]) -> None: + run_id = data.run_id + node_id = data.node_id + events = data.events or [] + name = data.name + if not run_id or not node_id: return # Drop the batch + + with self._lock: + for i, ev in enumerate(events): + nonce = ev.attrs["nonce"] or f"ERR_{i}" + + if node_id not in self._workers: + self._workers[node_id] = {} + + if nonce not in self._workers[node_id]: + self._workers[node_id][nonce] = {} + + if name 
not in self._workers[node_id][nonce]: + self._workers[node_id][nonce][name] = [ev, ] + continue + + if len(self._workers[node_id]) >= self._max_resident_req: # remove oldest entry + del self._workers[self._nonces[0]] + del self._nonces[0] + + self._workers[node_id][name].append(ev) + self._nonces.push(nonce) + + # Construct RuntimeStats + assert "model" in self._frames, "No model found in trace data." + + rt_stat = self._req.setdefault(run_id, _RuntimeStats) + #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] + rt_stat.tokenizer = + rt_stat.run_id = run_id + rt_stat.ttft = {} + + for n in self._nonces: # accumulate new data for each nonce + for shard in self._workers: + + if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]: + self._nonce_round_finish[nonce] = True + if not self._nonce_prefill[nonce]: # This is prefill, append ttft + + + acc_ttt = 0 # accumulated time to token + acc_ttt += shard["network.ingress"][-1] + inflight = shard['network.ingress'][] + # Return data for total, per req, worker or model (maybe add per layer too?) def stats( @@ -268,4 +316,5 @@ def stats( pass else: # Return stats of all counters + pass From 7b3efd9b02bce80e5f02290cdc3f1129c812dacc Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 02:08:31 -0700 Subject: [PATCH 036/172] construct new request on embedding event --- src/dnet/perf/utils/aggregator.py | 81 +++++++++++++++++++------------ 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index b9cd78c5..de1b508f 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -8,7 +8,6 @@ from dnet.utils.logger import logger from dnet.ring import LayerAssignment, TopologyInfo -from dnet.perf import _Frame Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @@ -210,8 +209,8 @@ class _RuntimeStats: tokenizer: str # Tokenizer name run_id: str # ID of session (for later mapping) nonce: List[str] # List of serviced requests - ttft: Dict[str, float] # Time to first token, map: p50 : 0.0 (ms) - itl: Dict[str, float] # Inter-token latency, mapL p50 : 0.0 (ms) + ttft: float # Time to first token + itl: float # Inter-token latency requests: int # Number of requests serviced failed: int # Number of failed requests prompt_tokens: int # Number of prompt tokens per request (req_id: #) @@ -229,23 +228,24 @@ class _RuntimeStats: assignment: LayerAssignment = None # Map of layer to shard IDs -# NOTE: Hardcodes some high-level trace frame symbols -def to_runstats(agg: RunAggregator): - pass - # Process stats + handle per-request data +# NOTE: Hardcodes some high-level trace frame symbols +# TODO: Use a bitmap to track the stages for each req and prevent double-count class StatsAggregator: def __init__(self) -> None: self._lock = threading.Lock() - self._max_resident_rq = 50 # per node FIXME: modify from repl - self._workers: Dict[str, Dict[str, Dict[str, _Frame]]] = {} # Store frames per nonce, per node_id + self._max_inflight_rq = 20 # per node FIXME: modify from repl + + # Frames are kept in here while in-flight, then remove the frame objects and append to _stats + self._workers: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id self._nonces = [] # Tracked nonces (either in-flight or done) self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill 
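+        # Frames for a nonce stay in _running_stats while the request is in
+        # flight; once the final frame arrives they are folded into _stats and
+        # the raw frame objects are evicted from _workers.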
self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames + self._open_frames: Dict[str, Dict[str, Any]] # We got 'B' event but not 'E' (per nonce) # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: @@ -256,8 +256,13 @@ def add(self, data: Dict[str, Any]) -> None: if not run_id or not node_id: return # Drop the batch with self._lock: + + # Ensure we register workers and nodes for i, ev in enumerate(events): - nonce = ev.attrs["nonce"] or f"ERR_{i}" + if "nonce" not in ev.attrs: ev.attrs["nonce"] = f"N_{i}" + nonce = ev.attrs["nonce"] + + new_frames.append(ev) if node_id not in self._workers: self._workers[node_id] = {} @@ -265,33 +270,45 @@ def add(self, data: Dict[str, Any]) -> None: if nonce not in self._workers[node_id]: self._workers[node_id][nonce] = {} - if name not in self._workers[node_id][nonce]: - self._workers[node_id][nonce][name] = [ev, ] - continue - if len(self._workers[node_id]) >= self._max_resident_req: # remove oldest entry - del self._workers[self._nonces[0]] - del self._nonces[0] + del self._workers[self._nonces[0]] + del self._nonces[0] - self._workers[node_id][name].append(ev) self._nonces.push(nonce) - # Construct RuntimeStats - assert "model" in self._frames, "No model found in trace data." - - rt_stat = self._req.setdefault(run_id, _RuntimeStats) - #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] - rt_stat.tokenizer = - rt_stat.run_id = run_id - rt_stat.ttft = {} - - for n in self._nonces: # accumulate new data for each nonce - for shard in self._workers: - + # Update in-flight events or register new ones + for e in new_events: + nonce = e.attrs["nonce"] + assert nonce is not None, "" + + if not node_id and nonce: return # Drop invalid frames + stats = self._running_stats[nonce] + + # Register new request + if e.name == "compute.embedding": + #assert "model" in self._frames, "No model found in trace data." 
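+                # The "compute.embedding" frame emitted by the first shard marks
+                # the start of a request, so a fresh stats record is opened here.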
+ rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats) + #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] + #rt_stat.tokenizer = + rt_stat.run_id = run_id + rt_stat.nonce = nonce + rt_stat.ttft = {} + + if e.name == "network.ingress": + if e.type == "B": self._open_frames[nonce][e.name] = e + n_rt = e.attrs["inflight"] + e.attrs["inwait"] + n_rt += self._open_frames[nonce][e.name].t0 + if self._nonce_prefill[nonce]: + stats.ttft += n_rt + continue + stats.itl += n_rt + + if f.name == "compute.forward": + + # Request is finished, construct _RuntimeStats and remove from memory if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]: - self._nonce_round_finish[nonce] = True - if not self._nonce_prefill[nonce]: # This is prefill, append ttft - + self._nonce_round_finish[nonce] = True + if not self._nonce_prefill[nonce]: # This is prefill, append ttft acc_ttt = 0 # accumulated time to token acc_ttt += shard["network.ingress"][-1] From df9b2e6781363dc51abb41127f39ea97a36a5a49 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 11:50:58 -0700 Subject: [PATCH 037/172] handle frame with custom cost function --- src/dnet/perf/utils/aggregator.py | 67 ++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index de1b508f..d1419cee 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -210,13 +210,11 @@ class _RuntimeStats: run_id: str # ID of session (for later mapping) nonce: List[str] # List of serviced requests ttft: float # Time to first token - itl: float # Inter-token latency - requests: int # Number of requests serviced - failed: int # Number of failed requests + itl: List[float] # Inter-token latency per round prompt_tokens: int # Number of prompt tokens per request (req_id: #) generated_tokens: int # Number of generated tokens per request (req_id: #) - latencys: Dict[List[str, str, str], int] # Map of latencys: [node0, node1, p50]: 0.0 + latencies: Dict[List[str, str, str], int] # Map of inter-node latencies: [node0, node1, p50]: 0.0 latency_per_layer: Dict[int, float] # Map of {layer: 0.0} latency_per_shard: Dict[str, float] # Map of {shard: 0.0} total_latency: int # Total runtime of requests @@ -246,6 +244,7 @@ def __init__(self) -> None: self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames self._open_frames: Dict[str, Dict[str, Any]] # We got 'B' event but not 'E' (per nonce) + self._model_per_run: Dict[str, str] = {} # Track model per run_id # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: @@ -277,33 +276,43 @@ def add(self, data: Dict[str, Any]) -> None: self._nonces.push(nonce) # Update in-flight events or register new ones - for e in new_events: + for e in events: nonce = e.attrs["nonce"] assert nonce is not None, "" if not node_id and nonce: return # Drop invalid frames - stats = self._running_stats[nonce] - # Register new request - if e.name == "compute.embedding": - #assert "model" in self._frames, "No model found in trace data." 
- rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats) - #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] - #rt_stat.tokenizer = - rt_stat.run_id = run_id - rt_stat.nonce = nonce - rt_stat.ttft = {} + if e.name == "embedding": # Register new request + rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats( + model="", + tokenizer="", + run_id=run_id, + nonce=nonce, + ttft=0.0, + itl=[0.0], + generated_tokens=0, + prompt_tokens=e.attrs["prompt_tokens"], + latencies={}, + latency_per_layer={}, + latency_per_shard={}, + total_latency=0.0, + assignment=None, + topo=None, + )) + + # FIXME: We might receive other frames then "embed" from shards + # so we need to handle the creation of this better + stats = self._running_stats[nonce] if e.name == "network.ingress": - if e.type == "B": self._open_frames[nonce][e.name] = e - n_rt = e.attrs["inflight"] + e.attrs["inwait"] - n_rt += self._open_frames[nonce][e.name].t0 - if self._nonce_prefill[nonce]: - stats.ttft += n_rt - continue - stats.itl += n_rt + _cost: lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"] + self._handle_frame(e, stats, _cost) - if f.name == "compute.forward": + if e.name == "compute.forward": + _cost = lambda e: e.attrs["ms"] + self._handle_frame(e, stats, _cost) + + if e.name == "" # Request is finished, construct _RuntimeStats and remove from memory if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]: @@ -314,6 +323,18 @@ def add(self, data: Dict[str, Any]) -> None: acc_ttt += shard["network.ingress"][-1] inflight = shard['network.ingress'][] + # Handle cost aggregation of frames + def _handle_frame(e: Any, stats: _RuntimeStats, _cost_fnc: Any): + if e.type == 'B': + self._open_frames[nonce][e.name] = e + return + elif e.type == 'E': + n_rt = _cost_fnc(e) # Custom cost function for each farme + if self._nonce_prefill[nonce]: + stats.ttft += n_rt + else: + stats.itl[-1] += n_rt + del self._open_frames[nonce][e.name] # Return data for total, per req, worker or model (maybe add per layer too?) 
     def stats(

From 3d5b4f580232f92b34406ea8dd6e65f8d1fb166b Mon Sep 17 00:00:00 2001
From: Octavian
Date: Thu, 23 Oct 2025 14:08:10 -0700
Subject: [PATCH 038/172] update canonical traces for stats, rename ingress to
 rx and egress to tx

---
 src/dnet/ring/shard/comms.py   |  52 ++++------
 src/dnet/ring/shard/compute.py |  14 ++-
 src/dnet/ring/shard/node.py    | 168 +++++++++++----------------------
 3 files changed, 89 insertions(+), 145 deletions(-)

diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py
index bcc826c3..b36f8250 100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -172,18 +172,21 @@ async def _send_worker(self):
         ):
             try:
                 activation_msg = await self.activation_computed_queue.get()
-                if activation_msg.tx_enq_perf_t and self._profile:
+                with self.tracer.frame("network", "tx") as f:
+                    if activation_msg.tx_enq_perf_t and self._profile:
+                        f.set("inwait", time.perf_counter() - activation_msg.tx_enq_perf_t)
+                        f.set("nonce", activation_msg.nonce)
                         q_wait_ms = (
                             time.perf_counter() - activation_msg.tx_enq_perf_t
                         ) * 1000.0
                         logger.info(
                             "[PROFILE][QUEUE-TX] node=%s nonce=%s wait_ms=%.3f size=%s",
                             self.node_id,
                             activation_msg.nonce,
                             q_wait_ms,
                             self.activation_computed_queue.qsize(),
                         )
                     await self._send_activation(activation_msg)
             except asyncio.CancelledError:
                 break
             except Exception as e:
@@ -247,28 +250,22 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                         pass
                 cb = activation_msg.callback_url or ""
                 parsed = urlparse(cb) if cb else None
-                t_rpc = time.perf_counter()
                 if parsed and parsed.scheme == "grpc":
                     addr = parsed.netloc
                     if not addr:
                         logger.error("Invalid gRPC callback URL for token: %s", cb)
                         return
-                    # Ensure API channel/stub
-                    if (self.api_channel is None) or (addr != self.api_address):
-                        # close existing channel if any
-                        try:
+
+                    if (self.api_channel is None) or (addr != self.api_address):  # Ensure API channel/stub
+                        try:  # close existing channel if any
                             if self.api_channel is not None:
                                 await self.api_channel.close()
                         except Exception:
                             pass

                         self.api_address = addr
-                        self.api_channel = aio_grpc.insecure_channel(
-                            addr, options=GRPC_AIO_OPTIONS
-                        )
-                        self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub(
-                            self.api_channel
-                        )
+                        self.api_channel = aio_grpc.insecure_channel(addr, options=GRPC_AIO_OPTIONS)
+                        self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub(self.api_channel)
                         f.event("reset_api")

                     with self.tracer.frame("grpc", "token_request") as fr:
@@ -279,21 +276,12 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                             timestamp=utc_epoch_now(),
                         )
                         resp = await self.api_stub.SendToken(req)  # type: ignore
-                        rpc_ms = (time.perf_counter() - t_rpc) * 1000.0
                         if not resp.success:
                             logger.error(
                                 "API SendToken failed for %s: %s",
                                 activation_msg.nonce,
                                 resp.message,
                             )
-                        if self._profile:
-                            logger.info(
-                                "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f",
-                                self.node_id,
-                                activation_msg.nonce,
-                                int(getattr(activation_msg, "token_id", -1)),
-                                rpc_ms,
-                            )
                 except Exception as e:
                     logger.exception("Error sending token via gRPC: %s", e)
             else:
diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py
index fce05d21..01825bac 100644
--- a/src/dnet/ring/shard/compute.py
diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py
index fce05d21..01825bac 100644
--- a/src/dnet/ring/shard/compute.py
+++ b/src/dnet/ring/shard/compute.py
@@ -65,7 +65,7 @@ def _delta_swap_eviction(
         return len(evicted)
 
     def _process_activation(self, activation_msg: ActivationMessage):
-        if ( not self.model
+        if (not self.model
             or not self.model_metadata
             or not self.weight_cache
             or not self.input_pool
@@ -88,13 +88,21 @@ def _process_activation(self, activation_msg: ActivationMessage):
 
         # Prepare input activation
         with self.tracer.frame("compute.thread", "activations.process") as f:
+            f.set("nonce", activation_msg.nonce)
             if activation_msg.dtype == "tokens":  # embed locally on start shard
                 logger.debug(f"Embedding tokens.")
-                f.event("embed_tokens")
                 numel = int(np.prod(activation_msg.shape))
                 tok_view = input_buffer[:numel].reshape(activation_msg.shape)
                 toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32)
+                x = self.model.embed(toks[None])
+
+                # NOTE: Used to track start of request in perf stats
+                self.tracer.mark("embedding", {
+                    "nonce": actication_msg.nonce,
+                    "prompt_tokens": toks.size,
+                })
+
                 if x.dtype != self._wire_mx_dtype:
                     x = x.astype(self._wire_mx_dtype)
 
@@ -382,6 +390,8 @@ def _process_activation(self, activation_msg: ActivationMessage):
             with self._mlx_lock:
                 y = self.model.normalize(x_cast)
                 y = self.model.lm_project(y)
+                self.tracer.mark("lm_head", {"nonce": actication_msg.nonce})  # NOTE: canonical stats end
+
             # Greedy sample last position
             if y.ndim == 3:
                 logits_2d = y[:, -1, :]
diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index f2933e21..fbfa52e0 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -107,7 +107,7 @@ def __init__(
         self._assigned_sorted = sorted(self.assigned_layers or [])
         self._assigned_set = set(self._assigned_sorted)
 
-        # Topology (configured later)
+        # Topology
         self.next_node: Optional[DnetDeviceProperties] = None
         self.total_layers: int = 0  # Total layers in model
         self.api_callback_address: Optional[str] = None
@@ -218,8 +218,10 @@ def __init__(
         self.tracer.start()
 
         # Get in-flight and in-wait time per request
-        self._rx_ingress_t: Dict[str, float] = {}  # Mapping of nonce -> perf_counter()
-        self._rx_inflight_t: Dict[str, float] = {}  # Track per-request inflight
+        self._rx_ingress_t: Dict[str, float] = {}   # Timestamp we enqueued the request
+        self._rx_inflight_t: Dict[str, float] = {}  # Per-request inflight time
+        self._ex_enque_t: Dict[str, float] = {}     # req is queued for execution
+        self._tx_enque_t: Dict[str, float] = {}     # req is queued for sendoff
 
         # Per-nonce KV caches (concurrent requests)
         self._kv_by_nonce: Dict[str, list] = {}
@@ -597,13 +599,11 @@ async def unload_model(self) -> ShardUnloadModelResponse:
     async def reset_cache(self) -> None:
         """Reset LLM KV cache."""
         if not self.model:
-            logger.warning(
-                "Node %s: Cannot reset cache - no model loaded", self.node_id
-            )
+            logger.warning("Node %s: Cannot reset cache - no model loaded", self.node_id)
             return
-        try:
-            with self.tracer.frame("memory", "cache.reset"):
+        with self.tracer.frame("memory", "cache.reset"):
+            try:
                 self.cache = make_cache(
                     self.model,  # type: ignore[arg-type]
                     kv_mode=self.config.kv_cache.mode,
@@ -611,26 +611,19 @@ async def reset_cache(self) -> None:
                     kv_group=self.config.kv_cache.group_size,
                 )
                 logger.info("Node %s: Cache reset successfully", self.node_id)
-        except Exception as e:
-            logger.error("Node %s: Error resetting cache: %s", self.node_id, e)
+            except Exception as e:
+                logger.error("Node %s: Error resetting cache: %s", self.node_id, e)
 
+    # FIXME This seems to still be dead code
     async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
         """Receive activation from previous node and queue for local compute or
forward.""" if self.input_pool is None: - logger.error( - "Node %s: Cannot receive activation - input pool not initialized", - self.node_id, - ) + logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id) return - t_recv = time.perf_counter() - await self._connect_next_node() - - with self.tracer.frame("grpc.receive", "connect_next_node"): + with self.tracer.frame("network.rx", "connect_next_node"): await self._connect_next_node() - with self.tracer.frame("grpc.receive", "process_activation") as f: + with self.tracer.frame("network.rx", "process_activation") as f: try: activation = request.activation target_layer = activation.layer_id + 1 @@ -698,87 +691,62 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) + dtype_with_metadata=activation.dtype) except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) + logger.error("Decompression failed for nonce %s: %s", request.nonce, e) return - with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: + with self.tracer.frame("network.rx", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) + shape=cast(tuple[int, ...], tuple(deq.shape))) + if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) return + buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: flat = deq.reshape(-1) buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) # Update activation message with true dtype/shape new_dtype_str = str(deq.dtype) activation_msg = ActivationMessage.from_proto(request, pool_id) activation_msg.dtype = new_dtype_str activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs + + else: # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": - with self.tracer.frame("grpc.receive", "token_stream") as fr: + with self.tracer.frame("network.rx", "token_stream") as fr: try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 - ) - shp = (int(len(tokens)),) + tokens = np.frombuffer(request.activation.data, dtype=np.int32) + shp = (int(len(tokens)), ) except Exception as e: - logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) + logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e,) return + pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) + shape=cast(tuple[int, ...], shp)) + if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) return + buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: buffer[: len(tokens)] = tokens - if self._profile: - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s 
alloc_copy_ms=%.3f (tokens)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) activation_msg = ActivationMessage.from_proto(request, pool_id) + # Ensure dtype reflects token payload for compute path activation_msg.dtype = "tokens" activation_msg.shape = shp + else: - with self.tracer.frame("grpc.receive", "default") as fr: + with self.tracer.frame("network.ex", "default") as fr: # Safety: byte length must match shape*dtype try: expected = ( @@ -792,58 +760,37 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), - ) + shape=cast(tuple[int, ...], activation.shape)) + if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) return + buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: data = request.activation.data - input_data = np.frombuffer( - data, dtype=dtype_map[activation.dtype] - ) + input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) buffer[: len(input_data)] = input_data - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", - self.node_id, - request.nonce, - alloc_copy_ms, - ) + activation_msg = ActivationMessage.from_proto(request, pool_id) activation_msg.dtype = new_dtype_str activation_msg.shape = tuple(deq.shape) # Queue for processing — non-blocking back-off loop (cancellable) - if self._profile: - activation_msg.enq_perf_t = time.perf_counter() while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", - activation_msg.nonce, - ) + self._ex_enque_t[activation_msg.nonce] = time.perf_counter() + logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) break except Full: await asyncio.sleep(0) else: - logger.error( - "Failed to queue activation %s (node stopping)", - activation_msg.nonce, - ) + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce) self.input_pool.release(pool_id) - else: - # Forward to next node (not our layer) - logger.debug( - "Forwarding activation (layer %s) to next node, nonce: %s", - target_layer, - request.nonce, - ) + + else: # Forward to next node (not our layer) + logger.debug("Forwarding activation (layer %s) to next node, nonce: %s", target_layer, request.nonce) await self._forward_activation(request) except Exception as e: @@ -871,10 +818,10 @@ async def _ingress_worker(self): Admission (servicer) is lightweight; this worker performs per-frame processing, offloading alloc/copy/decompress to the threadpool, and - finally enqueues for compute or forwards to the next shard. - """ + finally enqueues for compute or forwards to the next shard. 
""" + while self.running: - with self.tracer.frame("network.ingress", "wait"): # NOTE: bad counter + with self.tracer.frame("network.rx", "wait"): # NOTE: bad counter try: req = await self.ingress_q.get() logger.debug(f"[DEQUE]Dequeued activation for processing") @@ -882,9 +829,9 @@ async def _ingress_worker(self): break # Trace processing of request, in-flight and in-wait times - with self.tracer.frame("network", "ingress") as f: - f.set("inflight", self._rx_inflight_t[req.nonce]) + with self.tracer.frame("network", "rx") as f: f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) + f.set("inflight", self._rx_inflight_t[req.nonce]) f.set("nonce", req.nonce) try: @@ -940,11 +887,9 @@ async def _ingress_worker(self): continue if activation_msg is None: continue - if self._profile: - activation_msg.recv_perf_t = t_recv - # Enqueue for compute (cancellable back-off) - with self.tracer.frame("network.ingress", "enque") as fr: + # Enqueue for compute + with self.tracer.frame("network.rx", "enque") as fr: while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) @@ -1033,7 +978,7 @@ def _prepare_activation_message_blocking( activation = request.activation if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - with self.tracer.frame("network.ingress.prepare_activation", "decompress") as f: + with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1069,7 +1014,7 @@ def _prepare_activation_message_blocking( return activation_msg elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them - with self.tracer.frame("network.ingress.prepare_activation", "tokens") as f: + with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1098,7 +1043,7 @@ def _prepare_activation_message_blocking( return activation_msg else: # Dense path: validate size and copy raw bytes view into pool buffer - with self.tracer.frame("network.ingress.prepare_activation", "default") as f: + with self.tracer.frame("network.rx.prepare_activation", "default") as f: try: expected = ( int(np.prod(activation.shape)) @@ -1168,7 +1113,8 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - with self.tracer.frame("compute", "forward"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("inwait", time.perf_counter() - self._ex_enque_t) self._process_activation(activation_msg) except Empty: From 38edeae61f189d6851855c0b15f1f43381a9db2a Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 14:08:45 -0700 Subject: [PATCH 039/172] filter canonical frames --- src/dnet/perf/utils/aggregator.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index d1419cee..6807d47e 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -304,20 +304,27 @@ def add(self, data: Dict[str, Any]) -> None: # so we need to handle the creation of this better stats = self._running_stats[nonce] - if e.name == "network.ingress": - _cost: lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"] + if e.name == "network.rx": + # Time in transport, 
ingress queue and ingress_worker + _cost = lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"] self._handle_frame(e, stats, _cost) + #TODO: change shard in metadata if e.name == "compute.forward": - _cost = lambda e: e.attrs["ms"] + _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # compute queue + execution self._handle_frame(e, stats, _cost) - if e.name == "" + if e.name == "network.tx.send": + _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # tx queue + sendoff + self._handle_frame(e, stats, _cost) - # Request is finished, construct _RuntimeStats and remove from memory - if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]: + if e.name = "lm_head" and not self._nonce.round_finish[nonce]: # Finish request self._nonce_round_finish[nonce] = True - if not self._nonce_prefill[nonce]: # This is prefill, append ttft + + # TODO: Remove frame and stsats from working and append + st_obj = self._running_stats[nonce] + del self._running_stats[nonce] + self._stats.append(st_obj) acc_ttt = 0 # accumulated time to token acc_ttt += shard["network.ingress"][-1] From fdbf05bf30270a46a2a31e7a1fbafb6501b1af84 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:27:02 -0700 Subject: [PATCH 040/172] move to on-request vars to avoid race conditions --- src/dnet/protos/dnet_ring.proto | 6 ++++-- src/dnet/ring/data_types.py | 5 +++++ src/dnet/ring/shard/compute.py | 5 ++--- src/dnet/ring/shard/node.py | 20 ++++++++------------ 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index 0b46c5be..5452a559 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -32,8 +32,10 @@ message ActivationRequest { string nonce = 1; Activation activation = 2; int64 timestamp = 3; - string node_origin = 4; - string callback_url = 5; + float rx_enq_t = 4; + float rx_inflight_t = 5; + string node_origin = 6; + string callback_url = 7; } // Response message for activation sending diff --git a/src/dnet/ring/data_types.py b/src/dnet/ring/data_types.py index 5da3e65a..db9438e5 100644 --- a/src/dnet/ring/data_types.py +++ b/src/dnet/ring/data_types.py @@ -25,7 +25,12 @@ class ActivationMessage: recv_perf_t: float = 0.0 enq_perf_t: float = 0.0 # TX queue enqueue time (perf_counter seconds) + rx_enq_perf_t: float = 0.0 tx_enq_perf_t: float = 0.0 + rx_ingress_t: float = 0.0 + rx_inflight_t: float = 0.0 + ex_enq_t: float = 0.0 + tx_enq_t: float = 0.0 # Final token path (end-shard sampling) is_final: bool = False token_id: int = -1 diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 01825bac..d51343a3 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -94,12 +94,11 @@ def _process_activation(self, activation_msg: ActivationMessage): numel = int(np.prod(activation_msg.shape)) tok_view = input_buffer[:numel].reshape(activation_msg.shape) toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) - x = self.model.embed(toks[None]) # NOTE: Used to track start of request in perf stats self.tracer.mark("embedding", { - "nonce": actication_msg.nonce, + "nonce": activation_msg.nonce, "prompt_tokens": toks.size, }) @@ -390,7 +389,7 @@ def _process_activation(self, activation_msg: ActivationMessage): with self._mlx_lock: y = self.model.normalize(x_cast) y = self.model.lm_project(y) - self.tracer.mark("lm_head", {"nonce": actication_msg.nonce}) # NOTE: canonical stats end + #self.tracer.mark("lm_head", 
{"nonce": actication_msg.nonce}) # NOTE: canonical stats end # Greedy sample last position if y.ndim == 3: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index fbfa52e0..d52bebc3 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -217,12 +217,6 @@ def __init__( self.tracer = Tracer(cfg) self.tracer.start() - # Get in-flight and in-wait time per request - self._rx_ingress_t: Dict[str, float] = {} # Timestamp we enqued the request - self._rx_inflight_t: Dict[str, float] = {} # Per-request inflight time - self._ex_enque_t: Dict[str, float] = {} # req is queued for execution - self._tx_enque_t: Dict[str, float] = {} # req is queued for sendoff - # Per-nonce KV caches (concurrent requests) self._kv_by_nonce: Dict[str, list] = {} self._kv_last_seen: Dict[str, float] = {} @@ -780,7 +774,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) - self._ex_enque_t[activation_msg.nonce] = time.perf_counter() + activatino_msg.ex_enq_t = time.perf_counter() logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) break except Full: @@ -802,8 +796,8 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: while self.running: try: rx_t = time.perf_counter() - self._rx_ingress_t[request.nonce] = rx_t - self._rx_inflight_t[request.nonce] = rx_t - request.timestamp + request.rx_enq_t = rx_t + request.rx_inflight_t = rx_t - request.timestamp self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") @@ -830,8 +824,8 @@ async def _ingress_worker(self): # Trace processing of request, in-flight and in-wait times with self.tracer.frame("network", "rx") as f: - f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) - f.set("inflight", self._rx_inflight_t[req.nonce]) + f.set("inwait", time.perf_counter() - req.rx_enq_t) + f.set("inflight", req.rx_inflight_t) f.set("nonce", req.nonce) try: @@ -1114,7 +1108,9 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats - f.set("inwait", time.perf_counter() - self._ex_enque_t) + f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + if (self.model_metadata.num_layers - 1) in self.assigned_layers: + f.set("lm_head", True) self._process_activation(activation_msg) except Empty: From 82113568cc406b576b6e18186796483d62bab302 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:27:25 -0700 Subject: [PATCH 041/172] auto topo and load model --- src/repl.py | 146 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 108 insertions(+), 38 deletions(-) diff --git a/src/repl.py b/src/repl.py index c2f081fb..f9cb363b 100644 --- a/src/repl.py +++ b/src/repl.py @@ -29,8 +29,17 @@ logger = get_api_logger() from dnet.perf.trace import TraceConfig, Tracer -from dnet.perf.utils import TraceAggregator +from dnet.perf.utils import TraceAggregator, StatsAggregator #from dnet.perf.bench import +from dnet.ring.common import TopologyInfo + +from dnet.ring.api.models import ( + PrepareTopologyManualRequest, + PrepareTopologyRequest, + PrepareTopologyResponse, + APILoadModelRequest, + APILoadModelResponse, +) # Handle restricted repos from importlib import import_module @@ -58,6 +67,7 @@ class REPLState: api_http_port: int = 8080 api_grpc_port: int = 50500 window_size = 2 # Number of layers per node per visit (also 
From 82113568cc406b576b6e18186796483d62bab302 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Thu, 23 Oct 2025 21:27:25 -0700
Subject: [PATCH 041/172] auto topo and load model

---
 src/repl.py | 146 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 108 insertions(+), 38 deletions(-)

diff --git a/src/repl.py b/src/repl.py
index c2f081fb..f9cb363b 100644
--- a/src/repl.py
+++ b/src/repl.py
@@ -29,8 +29,17 @@
 logger = get_api_logger()
 
 from dnet.perf.trace import TraceConfig, Tracer
-from dnet.perf.utils import TraceAggregator
+from dnet.perf.utils import TraceAggregator, StatsAggregator
 #from dnet.perf.bench import
+from dnet.ring.common import TopologyInfo
+
+from dnet.ring.api.models import (
+    PrepareTopologyManualRequest,
+    PrepareTopologyRequest,
+    PrepareTopologyResponse,
+    APILoadModelRequest,
+    APILoadModelResponse,
+)
 
 # Handle restricted repos
 from importlib import import_module
@@ -58,6 +67,7 @@ class REPLState:
     api_http_port: int = 8080
     api_grpc_port: int = 50500
     window_size = 2  # Number of layers per node per visit (also number resident in cache)
+    topo: Optional[TopologyInfo] = None
 
 
 class REPL(cmd.Cmd):
@@ -92,13 +102,17 @@ def __init__(self, model="NULL", nodes=1):
             aggregate=True,
             agg_max_events=50,
         )
         self._tracing = threading.Event()
         self._tracer = None
         self._trace_file = f"trace-{time.strftime("%Y%m%d-%H%M%S")}"
         self._trace_cursor = 0 # keep track of registered events in buffer
         self._trace_agg = TraceAggregator()
 
+        # Runtime stats (ingests data from tracer)
+        self._stats_agg = StatsAggregator()
+        self._stats = threading.Event()
+        self._stats.set() # Trace runtime information by default
+
     def loop(self): # Main tty loop
         sys.stdout.write(self.WELCOME)
@@ -123,9 +137,15 @@ def loop(self): # Main tty loop
             elif cmd.startswith("nodes"):
                 self.print_mdns_nodes()
                 continue
+            elif cmd.startswith("load"):
+                self.load_model()
+                continue
             elif cmd.startswith(("trace", ".trace")):
                 self.do_trace(cmd.split(" "))
                 continue
+            elif cmd.startswith(("perf", ".perf")):
+                self.do_perf(cmd.split(" "))
+                continue
             elif cmd.startswith(("topo", ".topo")):
                 self.do_topo(cmd.split(" "))
                 continue
@@ -187,8 +207,10 @@ def do_topo(self, cmd: List[str]) -> None:
             dprint("Invalid topology command. Type 'help' for a list of valid commands.\n")
             return
         if cmd[1] == "search":
+            self.print_mdns_nodes()
             pass
-        elif cmd[1] == "auto":
+        elif cmd[1] == "auto" or cmd[1] == "build":
+            self.prepare_topo()
             pass
         elif cmd[1] == "setup":
             pass
@@ -238,7 +260,8 @@ def _print_hf(cmd, desc, examples=[""]):
         _print_hf("trace focus [SUBSYSTEM] ", "Focus the trace on [SUBSYSTEM]. Do 'trace focus' for a list of available subsystems.")
         _print_hf("trace stream [ON|OFF] ", "Stream the trace spans to current terminal.")
         _print_hf("trace set [BUDGET] ", "Set the maximum amount of recorded events.")
-        _print_hf("profile [REPO] ", "Estimate the total FLOPS of the model from [REPO]")
+        _print_hf("perf ", "Prints the current state of runtime performance tracking.")
+        _print_hf("perf stat [REQ_ID | WORKER_ID | MODEL] ", "Prints the runtime statistics of target system.")
         _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
         _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
         _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
@@ -502,43 +525,65 @@ async def _await_then_set():
     def do_trace(self, cmd):
         if len(cmd) < 2:
             dprint(f"Tracing is currently {"ON" if self._trace_cfg.enabled else "OFF"}\n")
-        elif cmd[1] in ("on", "ON"):
-            self._trace_cfg.enabled = True
-            if self._api_running:
-                self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
-            dprint("Tracing is now ON\n")
-        elif cmd[1] in ("off", "OFF"):
-            self._trace_cfg.enabled = False
-            if self._api_running:
-                self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
-            dprint("Tracing is not OFF\n")
-        elif cmd[1] == "focus":
-            #self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
-            dprint("Subsystems not yet implemented.\n")
-        elif cmd[1] == "stream":
-            if len(cmd) == 2:
-                dprint(f"Trace is {"streaming to file: "+str(self._trace_cfg.file) if self._trace_cfg.streaming else "not streaming."}\n")
-            elif cmd[2] == "on":
-                self._trace_cfg.streaming = True
-                dprint(f"Streaming trace frames to {self._trace_cfg.file}\n")
-            elif cmd[2] == "off":
-                self._trace_cfg.streaming = False
-                dprint("Trace streaming is OFF.\n")
-        elif cmd[1] == "set":
-            if len(cmd) == 2:
-                dprint("Use: trace set [BUDGET], eg. 
2000\n") - else: - dprint("Not implemented yet\n") - # FIXME: Implement - elif cmd[1] == "status": - dprint(f"Frames: {len(self._trace_agg._req)}\n") - - elif cmd[1] == "annotate": - self.print_trace_annotate("NONE") + return + + match cmd[1]: + case s if s in ["on", "ON"]: + self._trace_cfg.enabled = True + dprint("Tracing is now ON\n") + + case s if s in ["off", "OFF"]: + self._trace_cfg.enabled = False + dprint("Tracing is now OFF\n") + + case s if s == "focus": + dprint("Subsystems not yet implemented.\n") + + case s if s == "stream": + if len(cmd) == 2: + dprint(f"Trace is {"streaming to file: "+str(self._trace_cfg.file) if self._trace_cfg.streaming else "not streaming."}\n") + elif cmd[2] == "on": + self._trace_cfg.streaming = True + dprint(f"Streaming trace frames to {self._trace_cfg.file}\n") + elif cmd[2] == "off": + self._trace_cfg.streaming = False + dprint("Trace streaming is OFF.\n") + + case s if s == "set": + if len(cmd) == 2: + dprint("Use: trace set [BUDGET], eg. 2000\n") + else: + dprint("Not implemented yet\n") + + case s if s == "annotate": + self.print_trace_annotate("NONE") + + case _: + dprint("Unknown trace command. Type 'help' for a list of available commands.\n") + + if self._api_running.is_set() and self._api_ready.is_set(): + self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards + + # Performance trackers + def do_perf(self, cmd): + if len(cmd) < 2 or cmd[1] == "stat": + dprint("Runtime performance metrics are ON by default.\n") + dprint("Turn tracking off with 'perf off'. Do 'perf stat' for statistics on previous requests or 'help' for more commands.\n\n") + return + + match cmd[1]: + case s if s in "...": + pass + case _: + pass # Trace callback registered with API Thread + # This forwards the tracer frames back to the REPL for printing def __trace_cb(self, data): - self._trace_agg.enqueue(data) + if self._tracing.is_set(): + self._trace_agg.enqueue(data) + if self._stats.is_set(): + self._stats_agg.add(data) def __print_tr(self, row): sym = " " + symbol.ljust(40, ' ') @@ -638,6 +683,31 @@ def print_mdns_nodes(self) -> None: except Exception as e: dprint(f"Could not list nodes: {e}\n") + def print_topo(self, topo): + line = "="*20+" Topology " + "="*20 + sys.stdout.write(f"{line}\nModel: {topo.model}\nLayers: {topo.num_layers}\n") + sys.stdout.write(f"Devices: {topo.devices}\n\n") + # TODO: Better print here + + def prepare_topo(self): + req = PrepareTopologyRequest(model="Qwen/Qwen3-4B-MLX-4bit") + try: + topo = self.api_call("_handle_prepare_topology", req, timeout=30) + except Exception as e: + dprint(f"Unable to create topology: {e}\n\n") + return + self.state.topo = topo + self.print_topo(topo) + + def load_model(self): + req = APILoadModelRequest(model="Qwen/Qwen3-4B-MLX-4bit") + try: + res = self.api_call("_handle_load_model", req, timeout=30) + except Exception as e: + dprint(f"Failed to load model: {e}\n\n") + return + + # ===== Handle shutdown def handle_shutdown(self): From 7dafdb86532367e538e5f1210df0e12ce99575f5 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:27:50 -0700 Subject: [PATCH 042/172] wrap in trace frames --- src/dnet/ring/weight_cache.py | 128 +++++++++++++++++++++++----------- 1 file changed, 89 insertions(+), 39 deletions(-) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index f7c9ae93..fcb5f53a 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -26,19 +26,19 @@ def __init__( model_metadata: ModelMetadata, window_size: 
Optional[int] = None, prefetch_threads: int = 2, + tracer=None, *, resident_windows: int = 2, use_mxload_fastpath: bool = False, prefetch_mode: str = "off", ): self.assigned_layers = assigned_layers - # Resident budget: enforce up to N windows resident - resident_windows = max(1, int(resident_windows)) + resident_windows = max(1, int(resident_windows)) # Resident budget if window_size is not None and window_size > 0: self.max_weights = min( - len(self.assigned_layers), max(1, resident_windows * int(window_size)) - ) + len(self.assigned_layers), + max(1, resident_windows * int(window_size))) else: self.max_weights = len(self.assigned_layers) self.cache = {} # layer_id -> (data, access_time) @@ -51,14 +51,18 @@ def __init__( prefetch_mode=prefetch_mode, ) self.lock = threading.Lock() + + if not tracer: + logger.error("Invalid tracer object passed to WeightCache.") + self.tracer = tracer + # Track in-flight materializations so compute can wait on prefetch self.loading_futures: Dict[int, Future] = {} self.prefetch_futures: Dict[int, Future] = {} logger.info("WeightCache resident budget: max_weights=%d", self.max_weights) - def get_weight( - self, layer_id: int, *, inc_ref: bool = True - ) -> Optional[Dict[str, mx.array]]: + + def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" # First, fast path under lock for cache hit or in-flight load with self.lock: @@ -72,41 +76,85 @@ def get_weight( ) return data - # If a load is in-flight, wait on it outside the lock - inflight = self.loading_futures.get(layer_id) - if inflight is None: - # Prepare eviction decision now to avoid overfilling once loaded - need_evict = len(self.cache) >= self.max_weights - if need_evict: - # Evict under lock, then proceed to load - self._evict_lru() - # Install a new future marker so others wait - fut = Future() - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False - - if creator: - # Perform the blocking load without holding the cache lock - try: - t0 = time.perf_counter() - data = self.layer_manager.load_layer_to_gpu(layer_id) - dt_ms = (time.perf_counter() - t0) * 1000.0 - # Estimate bytes by summing tensor sizes for the layer + with self.tracer.frame("weights.cache", "search") as f: + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + logger.debug("Cache hit for layer %s, ref=%d inc=%d", + layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) + return data + + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False + + if creator: # Perform the blocking load without holding the cache lock + with self.tracer.frame("weights.cache", "load") as f: try: - winfo = self.layer_manager.weight_info.get(layer_id, {}) - total_bytes = sum(w.size_bytes for w in winfo.values()) - except Exception: - total_bytes = 0 - # Commit to cache under lock - with self.lock: + data = 
self.layer_manager.load_layer_to_gpu(layer_id) + f.event("load") + + try: # Estimate bytes by summing tensor sizes for the layer + winfo = self.layer_manager.weight_info.get(layer_id, {}) + total_bytes = sum(w.size_bytes for w in winfo.values()) + f.set("bytes", total_bytes) + except Exception: + total_bytes = 0 + + with self.lock: # Commit to cache under lock + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + else: + self.reference_counts.setdefault(layer_id, 0) + + try: # Resolve future and clear from tracking + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_result(True) + except Exception: + self.loading_futures.pop(layer_id, None) + return data + + except Exception as e: + with self.lock: # Signal failure to any waiters + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_exception(e) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.exception("Failed to load weight %s: %s", layer_id, e) + return None + + else: # Not the creator: wait for the in-flight load to complete + with self.tracer.frame("weights.cache", "wait") as f: + try: + inflight.result() # block until the creator completes + except Exception as e: + logger.error("Wait for layer %s load failed: %s", layer_id, e) + return None + + with self.lock: # Return from cache + data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] + if data is None: + logger.error("Wait for layer %s load failed: data not in cache", layer_id) + return None + self.cache[layer_id] = (data, time.time()) if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) else: self.reference_counts.setdefault(layer_id, 0) # Resolve future and clear from tracking @@ -184,6 +232,7 @@ def prefetch_to_ram(self, layer_id: int): except Exception: return None + def cancel_all_prefetch(self): """Cancel any in-flight prefetch tasks and clear tracking.""" with self.lock: @@ -195,6 +244,7 @@ def cancel_all_prefetch(self): pass self.prefetch_futures.clear() + def _evict_lru(self): """Evict least recently used weight with zero references""" candidates = [ From 0b3e80f35bd74fc4958729e425bfff8c511985e9 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:28:49 -0700 Subject: [PATCH 043/172] change file name --- src/dnet/perf/utils/__init__.py | 2 +- .../utils/{aggregator.py => aggregators.py} | 85 +++++++++---------- 2 files changed, 42 insertions(+), 45 deletions(-) rename src/dnet/perf/utils/{aggregator.py => aggregators.py} (82%) diff --git a/src/dnet/perf/utils/__init__.py b/src/dnet/perf/utils/__init__.py index 7228627d..0ee2f5e1 100644 --- a/src/dnet/perf/utils/__init__.py +++ b/src/dnet/perf/utils/__init__.py @@ -1 +1 @@ -from .aggregator import TraceAggregator +from .aggregators import TraceAggregator, StatsAggregator diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregators.py similarity index 82% rename from src/dnet/perf/utils/aggregator.py rename to src/dnet/perf/utils/aggregators.py index 6807d47e..78d5bbd6 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregators.py @@ -200,21 +200,21 @@ def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: # Runtime statistics -# Use a RunAggregator to get raw frames per request, then transform into 
_RuntimeStats +# Use a RunAggregator to get raw frames per request, then transform into ReqStats # Track a single request, use multiple for a full benchmark @dataclass -class _RuntimeStats: +class ReqStats: model: str # Model name tokenizer: str # Tokenizer name run_id: str # ID of session (for later mapping) - nonce: List[str] # List of serviced requests + nonce: str # List of serviced requests ttft: float # Time to first token itl: List[float] # Inter-token latency per round prompt_tokens: int # Number of prompt tokens per request (req_id: #) generated_tokens: int # Number of generated tokens per request (req_id: #) - latencies: Dict[List[str, str, str], int] # Map of inter-node latencies: [node0, node1, p50]: 0.0 + latencies: List[List[str, str, str, int]] # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] # Map of {layer: 0.0} latency_per_shard: Dict[str, float] # Map of {shard: 0.0} total_latency: int # Total runtime of requests @@ -233,47 +233,47 @@ class StatsAggregator: def __init__(self) -> None: self._lock = threading.Lock() - self._max_inflight_rq = 20 # per node FIXME: modify from repl + self._max_inflight_req = 20 # per node FIXME: modify from repl # Frames are kept in here while in-flight, then remove the frame objects and append to _stats - self._workers: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id + self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id self._nonces = [] # Tracked nonces (either in-flight or done) self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill - self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames - self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames - self._open_frames: Dict[str, Dict[str, Any]] # We got 'B' event but not 'E' (per nonce) + self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames + self._stats: Dict[str, ReqStats] = {} # Finished stat frames + self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) self._model_per_run: Dict[str, str] = {} # Track model per run_id # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: - run_id = data.run_id - node_id = data.node_id - events = data.events or [] - name = data.name + run_id = data["run_id"] + node_id = data["node_id"] + events = data["events"] or [] + name = data["name"] if not run_id or not node_id: return # Drop the batch with self._lock: # Ensure we register workers and nodes for i, ev in enumerate(events): - if "nonce" not in ev.attrs: ev.attrs["nonce"] = f"N_{i}" - nonce = ev.attrs["nonce"] + if "nonce" not in ev["attrs"]: ev["attrs"]["nonce"] = f"N_{i}" + nonce = ev["attrs"]["nonce"] new_frames.append(ev) - if node_id not in self._workers: - self._workers[node_id] = {} + if node_id not in self._frames: + self._frames[node_id] = {} - if nonce not in self._workers[node_id]: - self._workers[node_id][nonce] = {} + if nonce not in self._frames[node_id]: + self._frames[node_id][nonce] = {} - if len(self._workers[node_id]) >= self._max_resident_req: # remove oldest entry - del self._workers[self._nonces[0]] + if len(self._frames[node_id]) >= self._max_resident_req: # remove oldest entry + del self._frames[self._nonces[0]] del self._nonces[0] - self._nonces.push(nonce) + self._nonces.append(nonce) # Update in-flight events or register new ones for e in events: @@ -282,8 +282,8 
@@ def add(self, data: Dict[str, Any]) -> None: if not node_id and nonce: return # Drop invalid frames - if e.name == "embedding": # Register new request - rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats( + if e["name"] == "embedding": # Register new request + rt_stat = self._running_stats.setdefault(nonce, ReqStats( model="", tokenizer="", run_id=run_id, @@ -304,34 +304,31 @@ def add(self, data: Dict[str, Any]) -> None: # so we need to handle the creation of this better stats = self._running_stats[nonce] - if e.name == "network.rx": + if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - _cost = lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"] - self._handle_frame(e, stats, _cost) + _cost = lambda e: e["attrs"]["inflight"] + e["attrs"]["inwait"] + e["attrs"]["ms"] + self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata - if e.name == "compute.forward": - _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # compute queue + execution - self._handle_frame(e, stats, _cost) + if e["name"] == "compute.forward": + _cost = lambda e: e["attrs"]["inwait"] + e.attrs["ms"] # compute queue + execution + self._handle_frame(e, nonce, stats, _cost) - if e.name == "network.tx.send": - _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # tx queue + sendoff - self._handle_frame(e, stats, _cost) - - if e.name = "lm_head" and not self._nonce.round_finish[nonce]: # Finish request + # Finish request + if "lm_head" in e.attrs and not self._nonce_round_finish[nonce]: self._nonce_round_finish[nonce] = True - - # TODO: Remove frame and stsats from working and append st_obj = self._running_stats[nonce] + self._stats[nonce] = st_obj del self._running_stats[nonce] - self._stats.append(st_obj) - - acc_ttt = 0 # accumulated time to token - acc_ttt += shard["network.ingress"][-1] - inflight = shard['network.ingress'][] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API + + if e["name"] == "network.tx.send": + _cost = lambda e: e["attrs"]["inwait"] + e["attrs"]["ms"] # tx queue + sendoff + self._handle_frame(e, nonce, stats, _cost) # Handle cost aggregation of frames - def _handle_frame(e: Any, stats: _RuntimeStats, _cost_fnc: Any): + def _handle_frame(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): if e.type == 'B': self._open_frames[nonce][e.name] = e return @@ -360,6 +357,6 @@ def stats( elif model: pass - else: # Return stats of all counters + else: # Sort per model, per request (node info only when requested) pass From 44efaeeedac428e361bde148d38df02cd89c756f Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:48:57 -0700 Subject: [PATCH 044/172] various small stuff --- src/dnet/perf/trace.py | 6 +++++- src/dnet/ring/__init__.py | 2 ++ src/dnet/ring/api/api_logging.py | 13 +++---------- src/dnet/ring/api/node.py | 5 +++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 7b5d79c1..65accca1 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -224,8 +224,12 @@ def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): return _NoopFrame() return _Frame(self, f"{scope}.{name}", attrs) + # Same as normal frame but signals that this trace is a cannon event (required for runtime stats) + def canonical(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): + return self.frame(scope, name, attrs) + # Mark an event outside of a frame - def mark(self, name: str, **attrs: 
Any) -> None: + def mark(self, name: str, attrs: Any) -> None: self._emit({"type": "I", "name": name, "args": attrs}) # Helpers diff --git a/src/dnet/ring/__init__.py b/src/dnet/ring/__init__.py index e69de29b..adb40d20 100644 --- a/src/dnet/ring/__init__.py +++ b/src/dnet/ring/__init__.py @@ -0,0 +1,2 @@ + +from .common import LayerAssignment, TopologyInfo diff --git a/src/dnet/ring/api/api_logging.py b/src/dnet/ring/api/api_logging.py index d90c526c..48999dac 100644 --- a/src/dnet/ring/api/api_logging.py +++ b/src/dnet/ring/api/api_logging.py @@ -5,17 +5,9 @@ from logging.handlers import RotatingFileHandler from pathlib import Path - _CONFIGURED_FLAG = "_dnet_api_logger_configured" - def get_api_logger() -> logging.Logger: - """Return a process‑local logger for the API server. - - - Does not propagate to the root logger (so it won't spam the REPL TTY). - - Writes to logs/api.log with rotation. - - Level is controlled by DNET_API_LOG (default: INFO). - """ log = logging.getLogger("dnet.api") if getattr(log, _CONFIGURED_FLAG, False): return log @@ -23,7 +15,8 @@ def get_api_logger() -> logging.Logger: # Level from env, fallback INFO level_name = (os.getenv("DNET_API_LOG", "INFO") or "INFO").strip().upper() level = getattr(logging, level_name, logging.INFO) - log.setLevel(level) + #log.setLevel(level) + log.setLevel(logging.DEBUG) # Do not bubble to root (console) log.propagate = False @@ -36,7 +29,7 @@ def get_api_logger() -> logging.Logger: # Attach a rotating file handler try: - fh = RotatingFileHandler("logs/api.log", maxBytes=10_000_000, backupCount=5) + fh = RotatingFileHandler("logs/api.log", maxBytes=10000000, backupCount=5) fmt = logging.Formatter( "%(asctime)s %(levelname)s [%(threadName)s] %(name)s: %(message)s" ) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 08bc1cbd..e36dbab2 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -422,6 +422,7 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: return TraceIngestResponse(ok=False, accepted=0, message=str(e)) async def _forward_trace_config(self, cfg: Any) -> bool: + logger.debug("Forwarding Trace config") shards = self._get_shards_from_discovery() this = self.discovery.get_own_properties() api_endpoint = f"http://{this.local_ip}:{this.server_port}/trace/ingest" @@ -447,9 +448,9 @@ async def _forward_trace_config(self, cfg: Any) -> bool: try: res = await client.post(url, json=dict(payload)) if res.status_code != 200: - logger.warning(f"Failed to POST tracer config to node {name}.") + logger.error(f"Failed to POST tracer config to {url}.: {res.text}") except Exception as e: - logger.warning(f"Failed to POST tracer config: {e}") + logger.error(f"Failed to POST tracer config: {e}") return False return True From 57a2fddffcc2453e6fb15eca98635ef42e3abf1d Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 00:04:41 -0700 Subject: [PATCH 045/172] add tracer to api and send frames back to repl. 
emit special frames when api starts and ends a chat request --- src/dnet/perf/trace.py | 12 +++---- src/dnet/perf/utils/aggregators.py | 51 ++++++++++++++++-------------- src/dnet/ring/api/node.py | 23 ++++++++++++++ 3 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 65accca1..48f6e49b 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -229,7 +229,7 @@ def canonical(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = Non return self.frame(scope, name, attrs) # Mark an event outside of a frame - def mark(self, name: str, attrs: Any) -> None: + def mark(self, name: str, attrs: Any = {}) -> None: self._emit({"type": "I", "name": name, "args": attrs}) # Helpers @@ -250,7 +250,7 @@ def profile_block(self, outfile: Optional[str] = None, sort: str = "cumtime", li with open(outfile, "w", encoding="utf-8") as f: f.write(out) else: - self._emit({"type": "PROFILE", "name": "cprofile", "args": {"sort": sort, "limit": limit, "report": out}}) + self._emit({"type": "PROFILE", "name": "cprofile", "attrs": {"sort": sort, "limit": limit, "report": out}}) @contextmanager def callgraph( @@ -284,13 +284,13 @@ def prof(frame, event, arg): key = f"{filename}:{code.co_firstlineno}:{name}" if event == "call": stack.append((key, time.perf_counter())) - self._emit({"type": "B", "name": f"py.{name}", "args": {"file": filename, "line": code.co_firstlineno}}) + self._emit({"type": "B", "name": f"py.{name}", "attrs": {"file": filename, "line": code.co_firstlineno}}) emitted += 1 else: if stack and stack[-1][0] == key: _, t0 = stack.pop() dt_ms = (time.perf_counter() - t0) * 1000.0 - self._emit({"type": "E", "name": f"py.{name}", "args": {"ms": round(dt_ms, 3)}}) + self._emit({"type": "E", "name": f"py.{name}", "attrs": {"ms": round(dt_ms, 3)}}) emitted += 1 elif inc_c and event in ("c_call", "c_return"): func = getattr(arg, "__name__", None) @@ -298,10 +298,10 @@ def prof(frame, event, arg): if not func: return if event == "c_call": - self._emit({"type": "B", "name": f"c.{mod}.{func}", "args": {}}) + self._emit({"type": "B", "name": f"c.{mod}.{func}", "attrs": {}}) emitted += 1 else: - self._emit({"type": "E", "name": f"c.{mod}.{func}", "args": {}}) + self._emit({"type": "E", "name": f"c.{mod}.{func}", "attrs": {}}) emitted += 1 prev = sys.getprofile() diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 78d5bbd6..fad29a7d 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -238,30 +238,27 @@ def __init__(self) -> None: # Frames are kept in here while in-flight, then remove the frame objects and append to _stats self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id - self._nonces = [] # Tracked nonces (either in-flight or done) + self._nonces: List[str] = [] # Tracked nonces (either in-flight or done) self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill - self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames - self._stats: Dict[str, ReqStats] = {} # Finished stat frames + self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames + self._stats: Dict[str, ReqStats] = {} # Finished stat frames self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) self._model_per_run: Dict[str, str] = {} # Track model per run_id # Ingest raw data from tracer def add(self, 
data: Dict[str, Any]) -> None: - run_id = data["run_id"] - node_id = data["node_id"] + run_id = data["run_id"] or "NONE" + node_id = data["node_id"] or "NONE" events = data["events"] or [] - name = data["name"] if not run_id or not node_id: return # Drop the batch with self._lock: # Ensure we register workers and nodes for i, ev in enumerate(events): - if "nonce" not in ev["attrs"]: ev["attrs"]["nonce"] = f"N_{i}" - nonce = ev["attrs"]["nonce"] - - new_frames.append(ev) + if "nonce" not in ev["args"]: ev["args"]["nonce"] = f"N_" + nonce = ev["args"]["nonce"] if node_id not in self._frames: self._frames[node_id] = {} @@ -269,21 +266,24 @@ def add(self, data: Dict[str, Any]) -> None: if nonce not in self._frames[node_id]: self._frames[node_id][nonce] = {} - if len(self._frames[node_id]) >= self._max_resident_req: # remove oldest entry + if len(self._frames[node_id]) >= self._max_inflight_req: # remove oldest entry del self._frames[self._nonces[0]] del self._nonces[0] - - self._nonces.append(nonce) + if nonce not in self._nonces: + self._nonces.append(nonce) # Update in-flight events or register new ones for e in events: - nonce = e.attrs["nonce"] + nonce = e["args"]["nonce"] assert nonce is not None, "" - if not node_id and nonce: return # Drop invalid frames + if not node_id or not nonce: return # Drop invalid frames - if e["name"] == "embedding": # Register new request - rt_stat = self._running_stats.setdefault(nonce, ReqStats( + if e["name"] == "chat.request.end": + print(e) + if e["name"] == "chat.request.start": + print(e) + self._running_stats[nonce] = ReqStats( model="", tokenizer="", run_id=run_id, @@ -291,31 +291,36 @@ def add(self, data: Dict[str, Any]) -> None: ttft=0.0, itl=[0.0], generated_tokens=0, - prompt_tokens=e.attrs["prompt_tokens"], + prompt_tokens=e["args"]["prompt_tokens"], latencies={}, latency_per_layer={}, latency_per_shard={}, total_latency=0.0, assignment=None, topo=None, - )) + ) + if e["name"] == "embedding": # Register new request + pass # FIXME: We might receive other frames then "embed" from shards # so we need to handle the creation of this better + if nonce not in self._running_stats: + continue + stats = self._running_stats[nonce] if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - _cost = lambda e: e["attrs"]["inflight"] + e["attrs"]["inwait"] + e["attrs"]["ms"] + _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata if e["name"] == "compute.forward": - _cost = lambda e: e["attrs"]["inwait"] + e.attrs["ms"] # compute queue + execution + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution self._handle_frame(e, nonce, stats, _cost) # Finish request - if "lm_head" in e.attrs and not self._nonce_round_finish[nonce]: + if "lm_head" in e["args"] and not self._nonce_round_finish[nonce]: self._nonce_round_finish[nonce] = True st_obj = self._running_stats[nonce] self._stats[nonce] = st_obj @@ -324,7 +329,7 @@ def add(self, data: Dict[str, Any]) -> None: # TODO: Handle latency of transfer back to API if e["name"] == "network.tx.send": - _cost = lambda e: e["attrs"]["inwait"] + e["attrs"]["ms"] # tx queue + sendoff + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff self._handle_frame(e, nonce, stats, _cost) # Handle cost aggregation of frames diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index e36dbab2..60206b39 100644 --- a/src/dnet/ring/api/node.py +++ 
b/src/dnet/ring/api/node.py
@@ -86,6 +86,7 @@
 from .servicer import ShardApiServicer
 
 from ..common import TopologyInfo, LayerAssignment
+from dnet.perf import Tracer, TraceConfig
 
 async def arange(count: int):
     """Async range generator."""
@@ -148,12 +149,27 @@ def __init__(
         except Exception:
             pass
 
+        cfg = TraceConfig(
+            file="./trace.json",
+            streaming=False,
+            include_prefixes = ("src/dnet/",),
+            include_c_calls = False,
+            budget = 10000,
+            enabled = True,
+            record_pid_tid = True,
+            aggregate=False,
+            aggregate_url=None,
+        )
+        self.tracer = Tracer(cfg)
+        self.tracer.start()
+
         logger.info(
             "API node initialized on HTTP port %s, gRPC port %s",
             self.http_port,
             self.grpc_port,
         )
 
+
     async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None:
         """Start the API node.
 
@@ -401,6 +417,11 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type:
             if self._trace_ingest_cb is not None:
                 logger.debug("Forwarding trace batch to REPL.")
                 self._trace_ingest_cb(batch.model_dump())
+
+            _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) }
+            #self.tracer._events.clear()
+            self._trace_ingest_cb(_t_batch) # FIXME: Move this
+
             return TraceIngestResponse(ok=True, accepted=len(batch.events))
 
         try:
@@ -1197,6 +1218,7 @@ async def _handle_chat_completion(self, req: ChatRequestModel) -> ChatResponseModel:
         Returns:
             Chat response
         """
+        self.tracer.mark("chat.request.start")
         stop_id_sequences: List[List[int]] = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)  # type: ignore
             for stop_word in req.stop  # type: ignore
@@ -1327,6 +1349,7 @@ async def _handle_completion(
         )
 
         # Build optional metrics
+        self.tracer.mark("chat.request.end")
         metrics = None
         if profile_enabled:
             t_end = time.perf_counter()
From 933c3d8cb012a6b4babc16d8d6aff485ee21b195 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Fri, 24 Oct 2025 00:57:28 -0700
Subject: [PATCH 046/172] track prompt tokens

---
 src/dnet/ring/api/node.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index 60206b39..4d7bdad5 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -419,7 +419,7 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type:
                 self._trace_ingest_cb(batch.model_dump())
 
             _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) }
-            #self.tracer._events.clear()
+            self.tracer._events.clear()
             self._trace_ingest_cb(_t_batch) # FIXME: Move this
 
@@ -1218,7 +1218,7 @@ async def _handle_chat_completion(self, req: ChatRequestModel) -> ChatResponseModel:
         Returns:
             Chat response
         """
-        self.tracer.mark("chat.request.start")
+        self.tracer.mark("chat.request.start", {"prompt_tokens": len(req.messages[0].content)})
         stop_id_sequences: List[List[int]] = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)  # type: ignore
             for stop_word in req.stop  # type: ignore
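After patch 046 the `"prompt_tokens"` attribute is still a character count — `len(req.messages[0].content)` measures the string, not the tokenized prompt — which is what patch 047 below corrects by moving the mark after tokenization and using `prompt.size`. A toy illustration of the mismatch, with a whitespace tokenizer standing in for the real one (real tokenizers split differently, but the character/token gap is the same):

```python
# Character counts and token counts diverge, so TTFT and throughput stats
# computed from len(content) would be wrong by roughly a factor of the
# average token length.
def toy_tokenize(text: str) -> list[str]:
    return text.split()

prompt = "Explain ring-parallel inference in one sentence."
print(len(prompt))                # 48 characters -> what patch 046 reports
print(len(toy_tokenize(prompt)))  # 6 "tokens" -> what the stats actually want
```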
From 739522bf9938ebee17c034c8a8ef5252c55b4b5 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Fri, 24 Oct 2025 01:25:49 -0700
Subject: [PATCH 047/172] move trace frame and add correct nonce and other
 metadata

---
 src/dnet/ring/api/node.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index 4d7bdad5..81e1f198 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -1218,7 +1218,6 @@ async def _handle_chat_completion(self, req: ChatRequestModel) -> ChatResponseModel:
         Returns:
             Chat response
         """
-        self.tracer.mark("chat.request.start", {"prompt_tokens": len(req.messages[0].content)})
         stop_id_sequences: List[List[int]] = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)  # type: ignore
             for stop_word in req.stop  # type: ignore
@@ -1292,6 +1291,13 @@ async def _handle_completion(
         t_start = time.perf_counter()
         t_first_token = None
         nonce = f"chatcmpl-{uuid.uuid4()}"
+
+        self.tracer.mark("chat.request.start", {
+            "tokenizer": None,
+            "prompt_tokens": prompt.size,
+            "nonce": nonce,
+        })
+
         detokenizer = self.tokenizer.detokenizer  # type: ignore
         detokenizer.reset()
         tokens: List[int] = []
@@ -1348,8 +1354,12 @@ async def _handle_completion(
             else detokenizer.text[: -len(stop_sequence_suffix)]
         )
 
+        self.tracer.mark("chat.request.end", {
+            "generated_tokens": len(tokens),
+            "nonce": nonce,
+        })
+
         # Build optional metrics
-        self.tracer.mark("chat.request.end")
         metrics = None
         if profile_enabled:
             t_end = time.perf_counter()
From abe027510b61eaeb4afb4a8e59f7518c169019d5 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Fri, 24 Oct 2025 03:26:35 -0700
Subject: [PATCH 048/172] track nonce on all frames

---
 src/dnet/ring/shard/node.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index d52bebc3..00808f70 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -614,10 +614,12 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
             logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id)
             return
 
-        with self.tracer.frame("network.rx", "connect_next_node"):
+        with self.tracer.frame("network.rx", "connect_next_node") as f:
+            f.set("nonce", request.nonce)
             await self._connect_next_node()
 
         with self.tracer.frame("network.rx", "process_activation") as f:
+            f.set("nonce", request.nonce)
             try:
                 activation = request.activation
                 target_layer = activation.layer_id + 1
@@ -691,6 +693,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
                         return
 
                     with self.tracer.frame("network.rx", "alloc.buffer") as fr:
+                        fr.set("nonce", request.nonce)
                         pool_id = self.input_pool.allocate_for_layer(
                             layer_id=activation.layer_id,
                             dtype=deq.dtype,
@@ -714,6 +717,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
                     if activation.dtype == "tokens":
                         with self.tracer.frame("network.rx", "token_stream") as fr:
                             fr.set("nonce", request.nonce)
                             try:
                                 tokens = np.frombuffer(request.activation.data, dtype=np.int32)
                                 shp = (int(len(tokens)), )
@@ -741,6 +745,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
                     else:
                         with self.tracer.frame("network.rx", "default") as fr:
+                            fr.set("nonce", request.nonce)
                             # Safety: byte length must match shape*dtype
                             try:
                                 expected = (
@@ -851,7 +856,8 @@ async def _ingress_worker(self):
 
                 if target_layer in self._assigned_set:
                     # Heavy prep in executor (alloc/copy/decompress)
-                    with self.tracer.frame("grpc.ingress", "prepare"):
+                    with self.tracer.frame("grpc.ingress", "prepare") as fr:
+                        fr.set("nonce", req.nonce)
                         loop = asyncio.get_running_loop()
                         try:
                             activation_msg = await loop.run_in_executor(
@@ -884,6 +890,7 @@ async def _ingress_worker(self):
 
                 # Enqueue for compute
                 with self.tracer.frame("network.rx", "enque") as fr:
+                    fr.set("nonce", req.nonce)
                     while self.running:
                         try:
self.activation_recv_queue.put_nowait(activation_msg) @@ -973,6 +980,7 @@ def _prepare_activation_message_blocking( if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("nonce", request.nonce) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1009,6 +1017,7 @@ def _prepare_activation_message_blocking( elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("nonce", request.nonce) try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1038,6 +1047,7 @@ def _prepare_activation_message_blocking( else: # Dense path: validate size and copy raw bytes view into pool buffer with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("nonce", request.nonce) try: expected = ( int(np.prod(activation.shape)) @@ -1108,6 +1118,7 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("nonce", activation_msg.nonce) f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) if (self.model_metadata.num_layers - 1) in self.assigned_layers: f.set("lm_head", True) @@ -1527,7 +1538,7 @@ async def load_model_endpoint( f"api_callback={req.api_callback_address or 'none'}" ) self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name - with self.tracer.frame("memory", "model.load"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("memory", "model.load") as f: # NOTE: Symbol hardcoded for runtime stats result = await self.load_model(req) return result @@ -1545,7 +1556,7 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: """Unload current model.""" try: logger.info("HTTP /unload_model") - with self.tracer.frame("memory", "model.unload"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("memory", "model.unload") as f: # NOTE: Symbol hardcoded for runtime stats result = await self.unload_model() return result @@ -1560,7 +1571,9 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: # FIXME: add pydantic type here async def warm(request: Request) -> JSONResponse: try: - with self.tracer.frame("memory", "model.warm"): # NOTE: Symbol hardcoded for runtime stats + # FIXME: Append warmup config? 
Or something to distinguish + with self.tracer.frame("memory", "model.warm") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("nonce", request.nonce) body = await request.json() start = int(body.get("start", -1)) window = int(body.get("window", self.window_size)) @@ -1625,7 +1638,7 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: Returns: Device profile information as a plain dict """ - with self.tracer.frame("startup", "profile.device"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("startup", "profile.device") as f: # NOTE: Symbol hardcoded for runtime stats profile_dict = profile_device_via_subprocess( repo_id, max_batch_exp=max_batch_exp, debug=0 ) From 7d75a4e74818e41b25d3124aee49f615b3b7bc24 Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 14:02:50 -0700 Subject: [PATCH 049/172] better track in-wait time and default to 0 for single shard --- src/dnet/protos/dnet_ring.proto | 7 ++++--- src/dnet/protos/shard_api_comm.proto | 3 ++- src/dnet/ring/data_types.py | 1 + src/dnet/ring/shard/comms.py | 5 ++++- src/dnet/ring/shard/node.py | 10 ++++++++-- src/repl.py | 19 ++++++++++++++----- 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index 5452a559..8009601f 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -33,9 +33,10 @@ message ActivationRequest { Activation activation = 2; int64 timestamp = 3; float rx_enq_t = 4; - float rx_inflight_t = 5; - string node_origin = 6; - string callback_url = 7; + float tx_enq_prev_t = 5; + float rx_inflight_t = 6; + string node_origin = 7; + string callback_url = 8; } // Response message for activation sending diff --git a/src/dnet/protos/shard_api_comm.proto b/src/dnet/protos/shard_api_comm.proto index ec80cf18..6540f65f 100644 --- a/src/dnet/protos/shard_api_comm.proto +++ b/src/dnet/protos/shard_api_comm.proto @@ -35,6 +35,7 @@ message TokenRequest { string nonce = 1; int32 token_id = 2; int64 timestamp = 3; + float tx_enq_prev_t = 4; } // Response for token reception @@ -49,4 +50,4 @@ message RingError { string failed_node = 2; string error_code = 3; string error = 4; -} \ No newline at end of file +} diff --git a/src/dnet/ring/data_types.py b/src/dnet/ring/data_types.py index db9438e5..ee6f81e6 100644 --- a/src/dnet/ring/data_types.py +++ b/src/dnet/ring/data_types.py @@ -27,6 +27,7 @@ class ActivationMessage: # TX queue enqueue time (perf_counter seconds) rx_enq_perf_t: float = 0.0 tx_enq_perf_t: float = 0.0 + tx_enq_prev_t: float = 0.0 rx_ingress_t: float = 0.0 rx_inflight_t: float = 0.0 ex_enq_t: float = 0.0 diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index b36f8250..a8eff0ac 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -274,6 +274,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): nonce=activation_msg.nonce, token_id=int(getattr(activation_msg, "token_id", -1)), timestamp=utc_epoch_now(), + tx_enq_prev_t=time.perf_counter(), ) resp = await self.api_stub.SendToken(req) # type: ignore if not resp.success: @@ -355,7 +356,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): if self.next_node_stub: - with self.tracer.frame("grpc", "send_activation.next") as f: + with self.tracer.frame("network", "send_activation.next") as f: request = activation_msg.to_proto(data) request.timestamp 
= utc_epoch_now() if self._mode == "offload" and self.window_size > 0: @@ -416,6 +417,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) ctx.disabled = True + request.tx_enq_prev_t = time.perf_counter() + # Prefer streaming if enabled/available; fallback to unary stream_used = False ctx = await self._ensure_stream(activation_msg.nonce) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 00808f70..409ea241 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -802,7 +802,8 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: try: rx_t = time.perf_counter() request.rx_enq_t = rx_t - request.rx_inflight_t = rx_t - request.timestamp + request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request_enq_prev_t + logger.error(f"rx_t {rx_t} --- tx_enq {request.tx_enq_prev_t}") self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") @@ -1119,9 +1120,14 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats f.set("nonce", activation_msg.nonce) - f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + if activation_msg.ex_enq_t == 0.0: # FIXME float comparison + f.set("inwait", 0.0) + else: + f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + if (self.model_metadata.num_layers - 1) in self.assigned_layers: f.set("lm_head", True) + self._process_activation(activation_msg) except Empty: diff --git a/src/repl.py b/src/repl.py index f9cb363b..dc629123 100644 --- a/src/repl.py +++ b/src/repl.py @@ -536,6 +536,9 @@ def do_trace(self, cmd): self._trace_cfg.enabled = False dprint("Tracing is now OFF\n") + case s if s in ["status"]: + dprint(f"Captured {len(self._trace_agg._req)} frames.\n") + case s if s == "focus": dprint("Subsystems not yet implemented.\n") @@ -572,7 +575,10 @@ def do_perf(self, cmd): return match cmd[1]: - case s if s in "...": + case s if s in "stats": + print(f"{self._stats_agg._nonces}") + print(f"{self._stats_agg._running_stats}") + print(f"{self._stats_agg._stats}") pass case _: pass @@ -580,10 +586,13 @@ def do_perf(self, cmd): # Trace callback registered with API Thread # This forwards the tracer frames back to the REPL for printing def __trace_cb(self, data): - if self._tracing.is_set(): - self._trace_agg.enqueue(data) - if self._stats.is_set(): - self._stats_agg.add(data) + try: + if self._tracing.is_set(): + self._trace_agg.enqueue(data) + if self._stats.is_set(): + self._stats_agg.add(data) + except Exception as e: + print(f"Unable to ingest trace buffer into REPL: {e}") def __print_tr(self, row): sym = " " + symbol.ljust(40, ' ') From 7f8980065c8a10b5b243b69ac4e96080a6a9056d Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 15:55:13 -0700 Subject: [PATCH 050/172] started filtering counters and printing --- src/dnet/perf/utils/aggregators.py | 145 ++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 36 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index fad29a7d..6a0653d0 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -1,6 +1,7 @@ from __future__ import annotations +import sys import threading from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Optional, DefaultDict @@ -213,6 +214,7 @@ class ReqStats: itl: List[float] # Inter-token latency per round 
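     # Per-request token counters (prompt_tokens arrives via the chat.request.start mark)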
prompt_tokens: int # Number of prompt tokens per request (req_id: #) generated_tokens: int # Number of generated tokens per request (req_id: #) + total_tokens: int # Total number of tokens processed latencies: List[List[str, str, str, int]] # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] # Map of {layer: 0.0} @@ -248,10 +250,13 @@ def __init__(self) -> None: # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: - run_id = data["run_id"] or "NONE" - node_id = data["node_id"] or "NONE" + run_id = data["run_id"] + node_id = data["node_id"] events = data["events"] or [] - if not run_id or not node_id: return # Drop the batch + + if not run_id or not node_id: + print("Dropped batch") + return # Drop the batch with self._lock: @@ -269,6 +274,7 @@ def add(self, data: Dict[str, Any]) -> None: if len(self._frames[node_id]) >= self._max_inflight_req: # remove oldest entry del self._frames[self._nonces[0]] del self._nonces[0] + if nonce not in self._nonces: self._nonces.append(nonce) @@ -277,82 +283,120 @@ def add(self, data: Dict[str, Any]) -> None: nonce = e["args"]["nonce"] assert nonce is not None, "" + if not node_id or not nonce: return # Drop invalid frames - if e["name"] == "chat.request.end": - print(e) if e["name"] == "chat.request.start": - print(e) + print(e["args"]) + self._open_frames[nonce] = {} + self._nonce_prefill[nonce] = True self._running_stats[nonce] = ReqStats( - model="", - tokenizer="", + model=e["args"]["model"], + tokenizer=e["args"]["tokenizer"], run_id=run_id, nonce=nonce, - ttft=0.0, + ttft=e["args"]["t0"], # set to initial timestamp then compute itl=[0.0], generated_tokens=0, prompt_tokens=e["args"]["prompt_tokens"], + total_tokens=e["args"]["prompt_tokens"], latencies={}, latency_per_layer={}, latency_per_shard={}, total_latency=0.0, assignment=None, topo=None, + layer_assignment_t=None, + throughput=0.0, + startup_t=0.0, ) + + if e["name"] == "embedding": # Register new request pass # FIXME: We might receive other frames then "embed" from shards # so we need to handle the creation of this better - if nonce not in self._running_stats: - continue + if nonce not in self._running_stats: + continue stats = self._running_stats[nonce] - if e["name"] == "network.rx": - # Time in transport, ingress queue and ingress_worker + if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker + print(f"\n{e["name"]}\n{e["args"]["inflight"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata if e["name"] == "compute.forward": + print(f"\n{e["name"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution self._handle_frame(e, nonce, stats, _cost) + self._nonce_round_finish[nonce] = False + + # End a cycle on compute done (inter-node queue wait is computed in next) + if self._nonce_prefill[nonce]: + stats.ttft = e["args"]["t0"] - stats.ttft + else: + stats.itl[-1] = e["args"]["t0"] - stats.itl[-1] + stats.itl.append(e["args"]["t0"]) + + if e["name"] == "chat.request.end": + if self._nonce_round_finish[nonce]: + self._nonce_round_finish[nonce] = True + pass + self._nonce_round_finish[nonce] = True + st_obj = self._running_stats[nonce] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + self._stats[nonce] = st_obj + del 
self._running_stats[nonce] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API - # Finish request - if "lm_head" in e["args"] and not self._nonce_round_finish[nonce]: - self._nonce_round_finish[nonce] = True - st_obj = self._running_stats[nonce] - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - if e["name"] == "network.tx.send": _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff self._handle_frame(e, nonce, stats, _cost) # Handle cost aggregation of frames def _handle_frame(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): - if e.type == 'B': - self._open_frames[nonce][e.name] = e - return - elif e.type == 'E': - n_rt = _cost_fnc(e) # Custom cost function for each farme - if self._nonce_prefill[nonce]: - stats.ttft += n_rt - else: - stats.itl[-1] += n_rt - del self._open_frames[nonce][e.name] + try: + if e["type"] == 'B': + self._open_frames[nonce][e["name"]] = e + return + elif e["type"] == 'E': + n_rt = _cost_fnc(e) # Custom cost function for each farme + if self._nonce_prefill[nonce]: + stats.ttft += n_rt + else: + stats.itl[-1] += n_rt + if e["name"] in self._open_frames[nonce]: + del self._open_frames[nonce][e["name"]] + except Exception as ex: + print(f"{ex}") # Return data for total, per req, worker or model (maybe add per layer too?) def stats( self, - req_id: Optional[str], - worker: Optional[str], - model: Optional[str] + req_id: Optional[str] = None, + worker: Optional[str] = None, + model: Optional[str] = None ): + fields = [ # 0 is native, 1 compound + (0, "prompt_tokens", ""), + (0, "generated_tokens", ""), + (0, "total_tokens", ""), + (0, -1, ""), # special for empty line + (0, "ttft", "ms"), + (1, "tokens_per_second", "ms"), + (1, "inter_token_latency", "ms"), + (0, "throughput", "ms"), + (0, -1, ""), + (1, "workers", "ms"), + (1, "estimated_compute", "GFLOPs") + ] + if req_id: pass @@ -363,5 +407,34 @@ def stats( pass else: # Sort per model, per request (node info only when requested) - pass + if len(self._stats) < 1: + print("No tracked stats in memory. Track a request first.\n") + return + stats = self._stats[list(self._stats.keys())[-1]] + sys.stdout.write(f"\n Performance counters stats for model '{stats.model}':\n\n") + for tag, n, unit in fields: + if tag == 0: # Native counter + if n == -1: + sys.stdout.write("\n") + continue + nr = getattr(stats, n) + if isinstance(nr, int): + nr_str = f"{nr}" + elif isinstance(nr, float): + nr_str = f"{nr:15.5}" + elif isinstance(nr, str): + if len(nr) > 20: + nr_str = nr[:15] + "..." 
+ else: + nr_str = nr + elif tag == 1: + match n: + case "tokens_per_second": + case "inter_token_latency": + case _: + pass + sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(4)}\t{n}\n") + sys.stdout.write("\n\n") + return + From 74aacdf8fea1e7ee348115e91ae0a0363f810d36 Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 23:20:26 -0700 Subject: [PATCH 051/172] basic counters working, ttft, tps, itl, token_count --- src/dnet/perf/utils/aggregators.py | 238 ++++++++++++++++------------- src/dnet/ring/api/node.py | 8 +- src/repl.py | 11 +- 3 files changed, 143 insertions(+), 114 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 6a0653d0..51b4494d 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -3,13 +3,17 @@ import sys import threading +import statistics from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Optional, DefaultDict from collections import defaultdict, deque -from dnet.utils.logger import logger +#from dnet.utils.logger import logger +from dnet.ring.api.api_logging import get_api_logger from dnet.ring import LayerAssignment, TopologyInfo +logger = get_api_logger() + Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @dataclass @@ -206,23 +210,31 @@ def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: # Track a single request, use multiple for a full benchmark @dataclass class ReqStats: - model: str # Model name - tokenizer: str # Tokenizer name - run_id: str # ID of session (for later mapping) - nonce: str # List of serviced requests - ttft: float # Time to first token - itl: List[float] # Inter-token latency per round - prompt_tokens: int # Number of prompt tokens per request (req_id: #) - generated_tokens: int # Number of generated tokens per request (req_id: #) - total_tokens: int # Total number of tokens processed - - latencies: List[List[str, str, str, int]] # List of inter-node latencies: [node0, node1, p50, 0.0] - latency_per_layer: Dict[int, float] # Map of {layer: 0.0} - latency_per_shard: Dict[str, float] # Map of {shard: 0.0} - total_latency: int # Total runtime of requests - throughput: float # aaa - startup_t: float # Time to start shard (ms) - layer_assignment_t: float # Time to layer assignment (ms) + model: str = "" # Model name + tokenizer: str = "" # Tokenizer name + run_id: str = "" # ID of session (for later mapping) + nonce: str = "" # List of serviced requests + ttft: float = 0.0 # Time to first token + itl: List[float] = None # Inter-token latency per round + prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) + generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) + total_tokens: int = -1 # Total number of tokens processed + + latencies: List[List[str, str, str, int]] = None # List of inter-node latencies: [node0, node1, p50, 0.0] + latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} + latency_per_shard: Dict[str, float] = None # Map of {shard: 0.0} + total_latency: int = -1 # Total runtime of requests + startup_t: float = 0.0 # Time to start shard (ms) + layer_assignment_t: float = 0.0 # Time to layer assignment (ms) + + # Per-worker data + compute_per_worker: Dict[str, float] = None + inwait_per_worker: Dict[str, float] = None + inflight_per_worker: Dict[str, float] = None + + # Network info + tx_bytes_per_node: Dict[str, int] = None # Volume of trafic per node + rx_bytes_per_node: Dict[str, int] = None topo: 
TopologyInfo = None # Topology information for this request (keep here since it might change) assignment: LayerAssignment = None # Map of layer to shard IDs @@ -259,7 +271,6 @@ def add(self, data: Dict[str, Any]) -> None: return # Drop the batch with self._lock: - # Ensure we register workers and nodes for i, ev in enumerate(events): if "nonce" not in ev["args"]: ev["args"]["nonce"] = f"N_" @@ -283,11 +294,9 @@ def add(self, data: Dict[str, Any]) -> None: nonce = e["args"]["nonce"] assert nonce is not None, "" - if not node_id or not nonce: return # Drop invalid frames if e["name"] == "chat.request.start": - print(e["args"]) self._open_frames[nonce] = {} self._nonce_prefill[nonce] = True self._running_stats[nonce] = ReqStats( @@ -295,23 +304,19 @@ def add(self, data: Dict[str, Any]) -> None: tokenizer=e["args"]["tokenizer"], run_id=run_id, nonce=nonce, - ttft=e["args"]["t0"], # set to initial timestamp then compute - itl=[0.0], - generated_tokens=0, + ttft= e["args"]["t0"], + itl=[ e["args"]["t0"], ], prompt_tokens=e["args"]["prompt_tokens"], total_tokens=e["args"]["prompt_tokens"], latencies={}, latency_per_layer={}, latency_per_shard={}, - total_latency=0.0, assignment=None, - topo=None, - layer_assignment_t=None, - throughput=0.0, - startup_t=0.0, + compute_per_worker={}, + inwait_per_worker={}, + inflight_per_worker={}, ) - if e["name"] == "embedding": # Register new request pass @@ -323,55 +328,45 @@ def add(self, data: Dict[str, Any]) -> None: stats = self._running_stats[nonce] if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - print(f"\n{e["name"]}\n{e["args"]["inflight"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] - self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata if e["name"] == "compute.forward": - print(f"\n{e["name"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution - self._handle_frame(e, nonce, stats, _cost) - self._nonce_round_finish[nonce] = False - - # End a cycle on compute done (inter-node queue wait is computed in next) - if self._nonce_prefill[nonce]: - stats.ttft = e["args"]["t0"] - stats.ttft - else: - stats.itl[-1] = e["args"]["t0"] - stats.itl[-1] - stats.itl.append(e["args"]["t0"]) - - if e["name"] == "chat.request.end": - if self._nonce_round_finish[nonce]: - self._nonce_round_finish[nonce] = True - pass - self._nonce_round_finish[nonce] = True - st_obj = self._running_stats[nonce] - st_obj.generated_tokens = e["args"]["generated_tokens"] - st_obj.total_tokens += e["args"]["generated_tokens"] - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - - if e["name"] == "network.tx.send": - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff - self._handle_frame(e, nonce, stats, _cost) + try: + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution + self._handle_round(e, nonce, stats, _cost) + except Exception as e: + print(f"{e}") + + try: + if e["name"] == "chat.request.end": + st_obj = self._running_stats[nonce] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + print("Adding to stats") + self._stats[nonce] = st_obj + del self._running_stats[nonce] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API + + + 
if e["name"] == "network.tx.send": + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff + + except Exception as e: + print(f"{e}") # Handle cost aggregation of frames - def _handle_frame(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): + def _handle_round(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): try: - if e["type"] == 'B': - self._open_frames[nonce][e["name"]] = e - return - elif e["type"] == 'E': - n_rt = _cost_fnc(e) # Custom cost function for each farme - if self._nonce_prefill[nonce]: - stats.ttft += n_rt - else: - stats.itl[-1] += n_rt - if e["name"] in self._open_frames[nonce]: - del self._open_frames[nonce][e["name"]] + logger.error(f"TTFT: {e["args"]["t0"]} - {stats.ttft}") + if self._nonce_prefill[nonce]: + stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 + self._nonce_prefill[nonce] = False + else: + if e["args"]["t0"] > 0.0: + stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1]) + stats.itl.append(e["args"]["t0"]) except Exception as ex: print(f"{ex}") @@ -383,6 +378,7 @@ def stats( model: Optional[str] = None ): + # FIXME: Allow manual selection of counters (and push to tracer) fields = [ # 0 is native, 1 compound (0, "prompt_tokens", ""), (0, "generated_tokens", ""), @@ -391,50 +387,78 @@ def stats( (0, "ttft", "ms"), (1, "tokens_per_second", "ms"), (1, "inter_token_latency", "ms"), - (0, "throughput", "ms"), (0, -1, ""), - (1, "workers", "ms"), - (1, "estimated_compute", "GFLOPs") + (1, "estimated_compute", "GFLOPs"), + (1, "compute_time_per_worker", "ms"), + (1, "inwait_time_per_worker", "ms"), + (1, "inflight_time_per_worker", "ms"), + (0, -1, ""), + (1, "network_latency", "ms"), ] - if req_id: - pass - - elif worker: - pass - - elif model: - pass + # FIXME: Allow filtering by these + if req_id: pass + elif worker: pass + elif model: pass else: # Sort per model, per request (node info only when requested) if len(self._stats) < 1: print("No tracked stats in memory. Track a request first.\n") return stats = self._stats[list(self._stats.keys())[-1]] - sys.stdout.write(f"\n Performance counters stats for model '{stats.model}':\n\n") - for tag, n, unit in fields: - if tag == 0: # Native counter - if n == -1: - sys.stdout.write("\n") - continue - nr = getattr(stats, n) - if isinstance(nr, int): - nr_str = f"{nr}" - elif isinstance(nr, float): - nr_str = f"{nr:15.5}" - elif isinstance(nr, str): - if len(nr) > 20: - nr_str = nr[:15] + "..." - else: - nr_str = nr - elif tag == 1: - match n: - case "tokens_per_second": - case "inter_token_latency": - case _: - pass - sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(4)}\t{n}\n") - sys.stdout.write("\n\n") + #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n") + sys.stdout.write(f"Performance stats for request '{stats.nonce}':\n\n") + try: + for tag, n, unit in fields: + if tag == 0: # Native counter + if n == -1: + sys.stdout.write("\n") + continue + nr = getattr(stats, n) + if isinstance(nr, int): + nr_str = f"{nr}" + elif isinstance(nr, float): + nr_str = f"{nr:.2f}" + elif isinstance(nr, str): + if len(nr) > 20: + nr_str = nr[:15] + "..." 
+ else: + nr_str = nr + sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(5)}\t{n}\n") + + # Compound trackers + elif tag == 1: + match n: + case "tokens_per_second": + tps = [ 1 / rt for rt in stats.itl ] + #median = statistics.median(tps) + mean = sum(tps) / len(tps) + sys.stdout.write(f"{mean:.2f}".rjust(20)+" tok/s".rjust(5)+" \ttokens_per_second") + sys.stdout.write(f"\t# {statistics.median(stats.itl):.3f} s/tok\n") + + case "inter_token_latency": + itl = stats.itl[:-1] # FIXME: last element is super big + median = statistics.median(itl) + p90 = statistics.quantiles(itl, n=100)[89] + p99 = statistics.quantiles(itl, n=100)[98] + sys.stdout.write(f"{median:.4f}".rjust(20) + " ms".ljust(5) + "\tmean_inter_token_latency (ITL)\n") + sys.stdout.write(" "*35 + f"{p90:.3f} (p90), {p99:.3f} (p99)\n") + sys.stdout.write(" "*35 +f"{min(itl):.3f} (min), {max(itl):.3f} (max)\n") + + case "estimated_compute": + sys.stdout.write(f"UNKNOWN":rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") + + case "compute_time_per_worker": + pass + + case _: + pass + + except Exception as e: + logger.error(f"{e}") + + # Per-node information + sys.stdout.write("\n") return diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 81e1f198..2ec5bcce 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -419,8 +419,8 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: self._trace_ingest_cb(batch.model_dump()) _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) } - self.tracer._events.clear() self._trace_ingest_cb(_t_batch) # FIXME: Move this + self.tracer._events.clear() return TraceIngestResponse(ok=True, accepted=len(batch.events)) @@ -1293,9 +1293,12 @@ async def _handle_completion( nonce = f"chatcmpl-{uuid.uuid4()}" self.tracer.mark("chat.request.start", { - "tokenizer": None, + "tokenizer": "", + "model": req.model, + "temperature": req.temperature, "prompt_tokens": prompt.size, "nonce": nonce, + "t0": time.perf_counter(), }) detokenizer = self.tokenizer.detokenizer # type: ignore @@ -1357,6 +1360,7 @@ async def _handle_completion( self.tracer.mark("chat.request.end", { "generated_tokens": len(tokens), "nonce": nonce, + "t0": time.perf_counter(), }) # Build optional metrics diff --git a/src/repl.py b/src/repl.py index dc629123..f0e9fb6d 100644 --- a/src/repl.py +++ b/src/repl.py @@ -1,4 +1,5 @@ +import io import os import sys import logging @@ -6,6 +7,7 @@ import time import argparse import subprocess +import contextlib from dataclasses import dataclass from typing import Optional, List, Any, Dict @@ -146,7 +148,7 @@ def loop(self): # Main tty loop elif cmd.startswith(("perf", ".perf")): self.do_perf(cmd.split(" ")) continue - elif cmd.startswith(("topo", ".topo")): + elif cmd.startswith(("topo", ".topo", "t ")): self.do_topo(cmd.split(" ")) continue elif cmd.startswith((".model", "model", "m ")): @@ -209,7 +211,7 @@ def do_topo(self, cmd: List[str]) -> None: if cmd[1] == "search": self.print_mdns_nodes() pass - elif cmd[1] == "auto" or cmd[1] == "build": + elif cmd[1] in ("auto", "build", "b"): self.prepare_topo() pass elif cmd[1] == "setup": @@ -391,6 +393,7 @@ def handle_start_worker(self): # ===== Handle API server async def _api_main(self) -> None: # main thread loop + logging.disable(logging.CRITICAL) self._api_loop = asyncio.get_running_loop() self._api_shutdown_e = asyncio.Event() self._node = RingApiNode( @@ -576,9 +579,7 @@ def do_perf(self, cmd): match cmd[1]: case s if s in "stats": - 
print(f"{self._stats_agg._nonces}") - print(f"{self._stats_agg._running_stats}") - print(f"{self._stats_agg._stats}") + self._stats_agg.stats() pass case _: pass From d9946196282ade0c7601415460a402998925a0f2 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 00:51:11 -0700 Subject: [PATCH 052/172] track node_id for every frame --- src/dnet/ring/shard/comms.py | 6 +++++- src/dnet/ring/shard/compute.py | 20 +++++++++++++------ src/dnet/ring/shard/node.py | 36 +++++++++++++++++++++++++++++----- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index a8eff0ac..b827bcf0 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -174,8 +174,9 @@ async def _send_worker(self): activation_msg = await self.activation_computed_queue.get() with self.tracer.frame("network", "tx") as f: if activation_msg.tx_enq_perf_t and self._profile: - f.set("inwait", time.perf_counter() - self._rx_enque_t) + f.set("inwait", time.perf_counter() - activation_msg.tx_enq_t) f.set("nonce", activation_msg.nonce) + f.set("node", self._instance_name) q_wait_ms = ( time.perf_counter() - activation_msg.tx_enq_perf_t ) * 1000.0 @@ -232,6 +233,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): logger.debug(f"Sending activation") if activation_msg.is_final: with self.tracer.frame("grpc", "send_activation.final") as f: + f.set("node", self._instance_name) try: if self._mode == "offload" and self.window_size > 0: first_window = self._assigned_sorted[: self.window_size] @@ -269,6 +271,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("reset_api") with self.tracer.frame("grpc", "token_request") as fr: + fr.set("node", self._instance_name) try: req = shard_api_comm_pb2.TokenRequest( nonce=activation_msg.nonce, @@ -357,6 +360,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): if self.next_node_stub: with self.tracer.frame("network", "send_activation.next") as f: + f.set("node", self._instance_name) request = activation_msg.to_proto(data) request.timestamp = utc_epoch_now() if self._mode == "offload" and self.window_size > 0: diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index d51343a3..50744b55 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -76,11 +76,13 @@ def _process_activation(self, activation_msg: ActivationMessage): try: # per-nonce kvcache for concurrent requests - with self.tracer.frame("compute.thread", "kvcache.init"): + with self.tracer.frame("compute.thread", "kvcache.init") as f: + f.set("node", self._instance_name) kv = self._get_or_make_kv(activation_msg.nonce) # Get input activation from pool - with self.tracer.frame("compute.thread", "activations.load"): + with self.tracer.frame("compute.thread", "activations.load") as f: + f.set("node", self._instance_name) input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) if input_buffer is None: logger.error("Failed to get input buffer %s", activation_msg.pool_id) @@ -89,6 +91,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Prepare input activation with self.tracer.frame("compute.thread", "activations.process") as f: f.set("nonce", activation_msg.nonce) + f.set("node", self._instance_name) if activation_msg.dtype == "tokens": # embed locally on start shard logger.debug(f"Embedding tokens.") numel = int(np.prod(activation_msg.shape)) @@ -124,6 +127,7 @@ def _process_activation(self, 
activation_msg: ActivationMessage): did_early_swap = False with self.tracer.frame("compute.thread", "weights.prepare") as f: + f.set("node", self._instance_name) # Determine contiguous local window starting at current_layer window_layers: List[int] = [] @@ -217,7 +221,8 @@ def _process_activation(self, activation_msg: ActivationMessage): pass # Execute the window - with self.tracer.frame("compute.thread", "execute"): + with self.tracer.frame("compute.thread", "execute") as f: + f.set("node", self._instance_name) self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) try: # Prevent prefetch touching during encode/compute to minimize UMA pressure @@ -273,7 +278,8 @@ def _process_activation(self, activation_msg: ActivationMessage): #self.weight_cache.decrease_reference(lid) pass - with self.tracer.frame("compute.thread", "execute.evict_and_unload"): + with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: + f.set("node", self._instance_name) try: # Sliding-fit delta swap: maintain a single resident set by evicting # only what's needed to fit the next window into the budget. Prefer @@ -368,7 +374,8 @@ def _process_activation(self, activation_msg: ActivationMessage): continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - with self.tracer.frame("compute.thread", "execute.enqueue_prefetch"): + with self.tracer.frame("compute.thread", "execute.enqueue_prefetch") as f: + f.set("node", self._instance_name) x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) try: self._compute_busy.clear() @@ -382,7 +389,8 @@ def _process_activation(self, activation_msg: ActivationMessage): pass # Create and enqueue output message: either forward activations or finalize on end role - with self.tracer.frame("compute.thread", "grpc.send"): + with self.tracer.frame("compute.thread", "grpc.send") as f: + f.set("node", self._instance_name) nxt = last_layer + 1 if nxt >= self.model_metadata.num_layers: # End of model try: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 409ea241..8878f00b 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -192,6 +192,7 @@ def __init__( # Discovery self.discovery = AsyncDnetP2P("lib/dnet-p2p/lib") + self._instance_name = "" # Background tasks self.background_tasks: List[asyncio.Task] = [] @@ -258,7 +259,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse and (self.model_path != req.model_path or self.assigned_layers != req.layers) ): logger.info("Node %s: Unloading current model to load new configuration", self.node_id) - with self.tracer.frame("memory", "model.unload"): + with self.tracer.frame("memory", "model.unload") as f: + f.set("node", self._instance_name) await self.unload_model() # Load model metadata @@ -366,7 +368,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Initialize weight cache - with self.tracer.frame("memory", "weight_cache.init"): + with self.tracer.frame("memory", "weight_cache.init") as f: + f.set("node", self._instance_name) self.weight_cache = WeightCache( self.assigned_layers, self.model_metadata, @@ -379,7 +382,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Load the model - with self.tracer.frame("memory", "model.load"): + with self.tracer.frame("memory", "model.load") as f: + f.set("node", self._instance_name) self.model = get_ring_model( self.model_metadata.model_type, 
self.model_metadata.model_config, @@ -403,7 +407,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse logger.warning("[QUANT] apply failed: %s", e) self.model.eval() - with self.tracer.frame("memory", "make_cache"): + with self.tracer.frame("memory", "make_cache") as f: + f.set("node", self._instance_name) self.cache = make_cache( self.model, kv_mode=self.config.kv_cache.mode, @@ -443,7 +448,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - with self.tracer.frame("network", "connect.next_node"): + with self.tracer.frame("network", "connect.next_node") as f: + f.set("node", self._instance_name) if self.next_node: await self._connect_next_node() else: @@ -616,6 +622,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): with self.tracer.frame("network.rx", "connect_next_node") as f: f.set("nonce", request.nonce) + f.set("node", self._instance_name) await self._connect_next_node() with self.tracer.frame("network.rx", "process_activation") as f: @@ -638,6 +645,8 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): t_alloc = time.perf_counter() if "|" in activation.dtype: with self.tracer.frame("grpc.receive", "decompress") as fr: + fr.set("nonce", request.nonce) + fr.set("node", self._instance_name) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -694,6 +703,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): with self.tracer.frame("network.rx", "alloc.buffer") as fr: fr.set("nonce", request.nonce) + fr.set("node", self._instance_name) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, @@ -718,6 +728,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): if activation.dtype == "tokens": with self.tracer.frame("network.rx", "token_stream") as fr: fr.set("nonce", request.nonce) + fr.set("node", self._instance_name) try: tokens = np.frombuffer(request.activation.data, dtype=np.int32) shp = (int(len(tokens)), ) @@ -745,6 +756,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: with self.tracer.frame("network.ex", "default") as fr: + fr.set("node", self._instance_name) fr.set("nonce", request.nonce) # Safety: byte length must match shape*dtype try: @@ -821,6 +833,7 @@ async def _ingress_worker(self): finally enqueues for compute or forwards to the next shard. 
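        Frames emitted along this path carry the request nonce and the
        shard's instance name, so the REPL-side aggregators can attribute
        per-request cost to individual nodes.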
""" while self.running: + logger.error(f"NODE_ID {self.node_id}") with self.tracer.frame("network.rx", "wait"): # NOTE: bad counter try: req = await self.ingress_q.get() @@ -832,6 +845,7 @@ async def _ingress_worker(self): with self.tracer.frame("network", "rx") as f: f.set("inwait", time.perf_counter() - req.rx_enq_t) f.set("inflight", req.rx_inflight_t) + f.set("node", self._instance_name) f.set("nonce", req.nonce) try: @@ -858,6 +872,7 @@ async def _ingress_worker(self): if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) with self.tracer.frame("grpc.ingress", "prepare") as fr: + fr.set("node", self._instance_name) fr.set("nonce", req.nonce) loop = asyncio.get_running_loop() try: @@ -891,6 +906,7 @@ async def _ingress_worker(self): # Enqueue for compute with self.tracer.frame("network.rx", "enque") as fr: + fr.set("node", self._instance_name) fr.set("nonce", req.nonce) while self.running: try: @@ -981,6 +997,7 @@ def _prepare_activation_message_blocking( if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("node", self._instance_name) f.set("nonce", request.nonce) try: deq = decompress_tensor_from_protobuf_data( @@ -1018,6 +1035,7 @@ def _prepare_activation_message_blocking( elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("node", self._instance_name) f.set("nonce", request.nonce) try: tokens = np.frombuffer(activation.data, dtype=np.int32) @@ -1048,6 +1066,7 @@ def _prepare_activation_message_blocking( else: # Dense path: validate size and copy raw bytes view into pool buffer with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("node", self._instance_name) f.set("nonce", request.nonce) try: expected = ( @@ -1120,6 +1139,7 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats f.set("nonce", activation_msg.nonce) + f.set("node", self._instance_name) if activation_msg.ex_enq_t == 0.0: # FIXME float comparison f.set("inwait", 0.0) else: @@ -1129,6 +1149,7 @@ def _compute_worker(self) -> None: f.set("lm_head", True) self._process_activation(activation_msg) + f.set("t0", time.perf_counter()) except Empty: continue @@ -1226,6 +1247,7 @@ async def _start_discovery(self) -> None: hostname = gethostname() # TODO: optionally take shard name from CLI instance = f"shard-{token_hex(4)}-{hostname}" + self._instance_name = instance self.discovery.create_instance( instance, self.http_port, @@ -1545,6 +1567,7 @@ async def load_model_endpoint( ) self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name with self.tracer.frame("memory", "model.load") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) result = await self.load_model(req) return result @@ -1563,6 +1586,7 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: try: logger.info("HTTP /unload_model") with self.tracer.frame("memory", "model.unload") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) result = await self.unload_model() return result @@ -1579,6 +1603,7 @@ async def warm(request: Request) -> JSONResponse: try: # FIXME: Append warmup config? 
Or something to distinguish with self.tracer.frame("memory", "model.warm") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) f.set("nonce", request.nonce) body = await request.json() start = int(body.get("start", -1)) @@ -1645,6 +1670,7 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: Device profile information as a plain dict """ with self.tracer.frame("startup", "profile.device") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) profile_dict = profile_device_via_subprocess( repo_id, max_batch_exp=max_batch_exp, debug=0 ) From cbc9ab1bd0ca7ab293e5782209a88cb60b89c099 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 01:23:35 -0700 Subject: [PATCH 053/172] fix trace annotate --- src/repl.py | 68 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/src/repl.py b/src/repl.py index f0e9fb6d..231cecb1 100644 --- a/src/repl.py +++ b/src/repl.py @@ -393,7 +393,7 @@ def handle_start_worker(self): # ===== Handle API server async def _api_main(self) -> None: # main thread loop - logging.disable(logging.CRITICAL) + #logging.disable(logging.CRITICAL) self._api_loop = asyncio.get_running_loop() self._api_shutdown_e = asyncio.Event() self._node = RingApiNode( @@ -532,10 +532,12 @@ def do_trace(self, cmd): match cmd[1]: case s if s in ["on", "ON"]: + self._tracing.set() self._trace_cfg.enabled = True dprint("Tracing is now ON\n") case s if s in ["off", "OFF"]: + self._tracing.clear() self._trace_cfg.enabled = False dprint("Tracing is now OFF\n") @@ -562,7 +564,11 @@ def do_trace(self, cmd): dprint("Not implemented yet\n") case s if s == "annotate": - self.print_trace_annotate("NONE") + if len(self._trace_agg._req) < 1: + dprint("No trace frames captured. Is tracing enabled?\n") + return + last = list(self._trace_agg._req.keys())[-1] + self.print_trace_annotate(last) case _: dprint("Unknown trace command. Type 'help' for a list of available commands.\n") @@ -608,33 +614,37 @@ def print_trace_annotate( repeats: int = 0, ) -> List[Dict[str, Any]]: - rows = self._trace_agg.annotate(run_id) - headers = ["name", "total","max","mean","p50","p90","p99","samples"] - limits = {"name": 50,} - w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers} - w["name"] = max(w["name"], 35) - - line = " ".join(h.ljust(w[h]) for h in headers); sys.stdout.write("\n") - sys.stdout.write(line + "\n") - sys.stdout.write(" ".join("."*w[h] for h in headers)); sys.stdout.write("\n") - for r in rows: - name = str(r["name"]) - if len(name) > w["name"]: name = name[:w["name"]-1] + "..." 
-            vals = {
-                "name": r["name"],
-                "total": r["total"],
-                "max": r["max"],
-                "mean": r["mean"],
-                "p50": r["p50"],
-                "p90": r["p90"],
-                "p99": r["p99"],
-                "samples": r["samples"],
-            }
-            sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]]))
-            sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:]))
-            sys.stdout.write("\n")
-        sys.stdout.write("\n\n")
-        sys.stdout.flush()
+        try:
+            rows = self._trace_agg.annotate(run_id)
+            logger.debug(f"rows")
+            headers = ["name", "total","max","mean","p50","p90","p99","samples"]
+            limits = {"name": 50,}
+            w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers}
+            w["name"] = max(w["name"], 35)
+
+            line = " ".join(h.ljust(w[h]) for h in headers); sys.stdout.write("\n")
+            sys.stdout.write(line + "\n")
+            sys.stdout.write(" ".join("."*w[h] for h in headers)); sys.stdout.write("\n")
+            for r in rows:
+                name = str(r["name"])
+                if len(name) > w["name"]: name = name[:w["name"]-1] + "..."
+                vals = {
+                    "name": r["name"],
+                    "total": r["total"],
+                    "max": r["max"],
+                    "mean": r["mean"],
+                    "p50": r["p50"],
+                    "p90": r["p90"],
+                    "p99": r["p99"],
+                    "samples": r["samples"],
+                }
+                sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]]))
+                sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:]))
+                sys.stdout.write("\n")
+            sys.stdout.write("\n\n")
+            sys.stdout.flush()
+        except Exception as e:
+            logger.error(f"{e}")
 
     def _print_nodes_table(self, rows: List[Any]) -> None:
         headers = ["name", "role", "addr", "http", "grpc", "status", "head"]

From 30bd452ac6370aaa41598f8bb87684c59e252113 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sat, 25 Oct 2025 01:24:23 -0700
Subject: [PATCH 054/172] aggregate frame symbols into sub-groups for focusing and compute total time per subsection

---
 src/dnet/perf/utils/aggregators.py | 42 +++++++++++++++++++++++++++++-
 1 file changed, 41 insertions(+), 1 deletion(-)

diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py
index 51b4494d..50d7ddd5 100644
--- a/src/dnet/perf/utils/aggregators.py
+++ b/src/dnet/perf/utils/aggregators.py
@@ -122,6 +122,7 @@ def __init__(self) -> None:
         self._lock = threading.Lock()
 
     def enqueue(self, batch: Dict[str, Any]) -> None:
+        try:
         run_id = batch.get("run_id")
         node_id = batch.get("node_id")
         events = batch.get("events") or []
@@ -131,6 +132,8 @@ def enqueue(self, batch: Dict[str, Any]) -> None:
             agg = self._req.setdefault(run_id, RunAggregator())
             for ev in events:
                 agg.ingest_event(node_id, ev)
+        except Exception as e:
+            logger.error(f"Trace aggregator enqueue error: {e}")
 
     def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]:
         with self._lock:
@@ -260,6 +263,43 @@ def __init__(self) -> None:
         self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce)
         self._model_per_run: Dict[str, str] = {} # Track model per run_id
 
+        # Maps of frames to higher-level sub-systems
+        self._compute_set = [
+            "compute.forward",
+            "compute.thread.kvcache.init",
+            "compute.thread.weights.prepare",
+            "compute.thread.activations.process",
+            "compute.thread.activations.load",
+            "compute.thread.execute",
+            "compute.thread.execute.enqueue_prefetch",
+            "compute.thread.execute.evict_and_unload",
+            "compute.thread.cleanup",
+            "compute.thread.mdns.send",
+        ]
+
+        self._network_set = [
+            "network.tx",
+            "network.token_request",
+            "network.rx.prepare",
+            "network.rx.prepare_activation.tokens",
+            "network.rx.enque",
"network.send_activation.final", + "network.rx", + "network.connect.next_node", + "network.rx.prefetch", + ] + + self._memory_set = [ + "memory.model.load", + "memory.model.load_metadata", + "memory.warmup", + "memory.weight_cache.init", + "memory.prefetch", + "memory.memory_pools.init", + "memory.cache.reset", + "memory.make_cache", + ] + # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: run_id = data["run_id"] @@ -446,7 +486,7 @@ def stats( sys.stdout.write(" "*35 +f"{min(itl):.3f} (min), {max(itl):.3f} (max)\n") case "estimated_compute": - sys.stdout.write(f"UNKNOWN":rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") + sys.stdout.write(f"UNKNOWN".rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") case "compute_time_per_worker": pass From 8471fb1c0a03d04935e427473610b89a9b07fd30 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 02:57:30 -0700 Subject: [PATCH 055/172] typo in activation timestamp --- src/dnet/ring/shard/node.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 8878f00b..8c21aedb 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -814,8 +814,7 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: try: rx_t = time.perf_counter() request.rx_enq_t = rx_t - request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request_enq_prev_t - logger.error(f"rx_t {rx_t} --- tx_enq {request.tx_enq_prev_t}") + request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request.tx_enq_prev_t self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") @@ -833,7 +832,6 @@ async def _ingress_worker(self): finally enqueues for compute or forwards to the next shard. 
""" while self.running: - logger.error(f"NODE_ID {self.node_id}") with self.tracer.frame("network.rx", "wait"): # NOTE: bad counter try: req = await self.ingress_q.get() From 81bac7d05e34aaa2fd472cb391e2d9e5f156f06e Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 03:51:27 -0700 Subject: [PATCH 056/172] per-node info and restructured counting --- src/dnet/perf/utils/aggregators.py | 80 ++++++++++++++++++------------ 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 50d7ddd5..b5753589 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -232,8 +232,8 @@ class ReqStats: # Per-worker data compute_per_worker: Dict[str, float] = None - inwait_per_worker: Dict[str, float] = None - inflight_per_worker: Dict[str, float] = None + network_per_worker: Dict[str, float] = None + memory_per_worker: Dict[str, float] = None # Network info tx_bytes_per_node: Dict[str, int] = None # Volume of trafic per node @@ -263,6 +263,8 @@ def __init__(self) -> None: self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) self._model_per_run: Dict[str, str] = {} # Track model per run_id + self.nodes = [] # Keep track of active nodes + # Maps of frames to higher-level sub-systems self._compute_set = [ "compute.forward", @@ -353,13 +355,10 @@ def add(self, data: Dict[str, Any]) -> None: latency_per_shard={}, assignment=None, compute_per_worker={}, - inwait_per_worker={}, - inflight_per_worker={}, + network_per_worker={}, + memory_per_worker={}, ) - if e["name"] == "embedding": # Register new request - pass - # FIXME: We might receive other frames then "embed" from shards # so we need to handle the creation of this better if nonce not in self._running_stats: @@ -367,6 +366,31 @@ def add(self, data: Dict[str, Any]) -> None: stats = self._running_stats[nonce] + if "node" not in e["args"]: + if e["name"] == "chat.request.end": + print(f"{e}") + st_obj = self._running_stats[nonce] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + print("Adding to stats") + self._stats[nonce] = st_obj + del self._running_stats[nonce] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API + + else: + continue # Drop frames without "node" + + node_id = e["args"]["node"] + if node_id not in self.nodes: + self.nodes.append(node_id) + stats.compute_per_worker[node_id] = 0.0 + stats.network_per_worker[node_id] = 0.0 + stats.memory_per_worker[node_id] = 0.0 + + if e["name"] in self._memory_set: + stats.memory_per_worker[node_id] += e["args"]["ms"] + if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] #TODO: change shard in metadata @@ -378,29 +402,20 @@ def add(self, data: Dict[str, Any]) -> None: except Exception as e: print(f"{e}") - try: - if e["name"] == "chat.request.end": - st_obj = self._running_stats[nonce] - st_obj.generated_tokens = e["args"]["generated_tokens"] - st_obj.total_tokens += e["args"]["generated_tokens"] - print("Adding to stats") - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - + if e["name"] in self._compute_set: # Aggregate for compute total + stats.compute_per_worker[node_id] += e["args"]["ms"] - if e["name"] == "network.tx.send": - _cost = 
lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff + if e["name"] in self._network_set: + stats.network_per_worker[node_id] += e["args"]["ms"] - except Exception as e: - print(f"{e}") + if e["name"] in self._memory_set: + stats.memory_per_worker[node_id] += e["args"]["ms"] # Handle cost aggregation of frames def _handle_round(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): try: - logger.error(f"TTFT: {e["args"]["t0"]} - {stats.ttft}") if self._nonce_prefill[nonce]: + logger.error(f"TTFT: {stats.ttft}") stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 self._nonce_prefill[nonce] = False else: @@ -429,11 +444,6 @@ def stats( (1, "inter_token_latency", "ms"), (0, -1, ""), (1, "estimated_compute", "GFLOPs"), - (1, "compute_time_per_worker", "ms"), - (1, "inwait_time_per_worker", "ms"), - (1, "inflight_time_per_worker", "ms"), - (0, -1, ""), - (1, "network_latency", "ms"), ] # FIXME: Allow filtering by these @@ -477,6 +487,7 @@ def stats( sys.stdout.write(f"\t# {statistics.median(stats.itl):.3f} s/tok\n") case "inter_token_latency": + assert len(stats.itl) > 1, "Not enough trace frames" itl = stats.itl[:-1] # FIXME: last element is super big median = statistics.median(itl) p90 = statistics.quantiles(itl, n=100)[89] @@ -488,12 +499,19 @@ def stats( case "estimated_compute": sys.stdout.write(f"UNKNOWN".rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") - case "compute_time_per_worker": - pass - case _: pass + for i, n in enumerate(self.nodes): + comp = stats.compute_per_worker[n] + net = stats.network_per_worker[n] + mem = stats.memory_per_worker[n] + total = comp + net + mem + sys.stdout.write(f"\n node{i} [{n}]\n") + sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") + sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") + sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + except Exception as e: logger.error(f"{e}") From 70f37cd8db5bdf0b64c294e70b3a7a7dddab9dce Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 03:59:25 -0700 Subject: [PATCH 057/172] reformat frames --- src/dnet/ring/api/node.py | 8 +++---- src/dnet/ring/shard/comms.py | 5 ++++- src/dnet/ring/shard/compute.py | 28 ++++++++++++++++++++++- src/dnet/ring/shard/node.py | 41 +++++++++++++++++----------------- src/dnet/ring/weight_cache.py | 6 ++--- 5 files changed, 58 insertions(+), 30 deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 2ec5bcce..34cf8ae7 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -1292,12 +1292,12 @@ async def _handle_completion( t_first_token = None nonce = f"chatcmpl-{uuid.uuid4()}" - self.tracer.mark("chat.request.start", { + self.tracer.mark("request.start", { "tokenizer": "", "model": req.model, "temperature": req.temperature, "prompt_tokens": prompt.size, - "nonce": nonce, + "req_id": nonce, "t0": time.perf_counter(), }) @@ -1357,9 +1357,9 @@ async def _handle_completion( else detokenizer.text[: -len(stop_sequence_suffix)] ) - self.tracer.mark("chat.request.end", { + self.tracer.mark("request.end", { "generated_tokens": len(tokens), - "nonce": nonce, + "req_id": nonce, "t0": time.perf_counter(), }) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index b827bcf0..bca9bdc6 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -175,7 +175,7 @@ async def _send_worker(self): with 
self.tracer.frame("network", "tx") as f: if activation_msg.tx_enq_perf_t and self._profile: f.set("inwait", time.perf_counter() - activation_msg.tx_enq_t) - f.set("nonce", activation_msg.nonce) + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) q_wait_ms = ( time.perf_counter() - activation_msg.tx_enq_perf_t @@ -233,6 +233,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): logger.debug(f"Sending activation") if activation_msg.is_final: with self.tracer.frame("grpc", "send_activation.final") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) try: if self._mode == "offload" and self.window_size > 0: @@ -271,6 +272,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("reset_api") with self.tracer.frame("grpc", "token_request") as fr: + fr.set("req_id", activation_msg.nonce) fr.set("node", self._instance_name) try: req = shard_api_comm_pb2.TokenRequest( @@ -360,6 +362,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): if self.next_node_stub: with self.tracer.frame("network", "send_activation.next") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) request = activation_msg.to_proto(data) request.timestamp = utc_epoch_now() diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 50744b55..f7554de0 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -77,11 +77,13 @@ def _process_activation(self, activation_msg: ActivationMessage): try: # per-nonce kvcache for concurrent requests with self.tracer.frame("compute.thread", "kvcache.init") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) kv = self._get_or_make_kv(activation_msg.nonce) # Get input activation from pool with self.tracer.frame("compute.thread", "activations.load") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) if input_buffer is None: @@ -90,7 +92,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Prepare input activation with self.tracer.frame("compute.thread", "activations.process") as f: - f.set("nonce", activation_msg.nonce) + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) if activation_msg.dtype == "tokens": # embed locally on start shard logger.debug(f"Embedding tokens.") @@ -127,6 +129,7 @@ def _process_activation(self, activation_msg: ActivationMessage): did_early_swap = False with self.tracer.frame("compute.thread", "weights.prepare") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) # Determine contiguous local window starting at current_layer @@ -222,6 +225,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Execute the window with self.tracer.frame("compute.thread", "execute") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) @@ -279,6 +283,7 @@ def _process_activation(self, activation_msg: ActivationMessage): pass with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) try: # Sliding-fit delta swap: maintain a single resident set by evicting @@ -375,6 +380,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Boundary reached — directly pass tensor to TX 
to avoid extra copy/sync with self.tracer.frame("compute.thread", "execute.enqueue_prefetch") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) try: @@ -390,6 +396,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Create and enqueue output message: either forward activations or finalize on end role with self.tracer.frame("compute.thread", "grpc.send") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) nxt = last_layer + 1 if nxt >= self.model_metadata.num_layers: # End of model @@ -506,10 +513,29 @@ def _process_activation(self, activation_msg: ActivationMessage): # Clean up input resources self.input_pool.release(activation_msg.pool_id) +<<<<<<< HEAD # Optional unload/evict after stage with self.tracer.frame("compute.thread", "cleanup"): if self._mode != "sliding_fit": if self._defer_unload: +======= + # Clean up input resources + with self.tracer.frame("compute.thread", "cleanup") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + self.input_pool.release(activation_msg.pool_id) + # After queuing TX, schedule prefetch and eviction in the background + # to avoid stalling the handoff to the next shard. + try: + self._prefetch_pause.set() + except Exception: + pass + next_window = self._next_local_layers(last_layer, self.window_size) + for nl in next_window: + self._prefetch_to_ram(nl) + self._enqueue_weight_prefetch(nl) + if getattr(self, "_defer_unload", False): +>>>>>>> 6c40e99 (reformat frames) try: while len(self._recent_windows) > max( 1, int(self._resident_windows) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 8c21aedb..740a613d 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -259,12 +259,12 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse and (self.model_path != req.model_path or self.assigned_layers != req.layers) ): logger.info("Node %s: Unloading current model to load new configuration", self.node_id) - with self.tracer.frame("memory", "model.unload") as f: + with self.tracer.frame("memory.model", "unload") as f: f.set("node", self._instance_name) await self.unload_model() # Load model metadata - with self.tracer.frame("memory", "model.load_metadata"): + with self.tracer.frame("memory.model", "load_metadata"): self.model_metadata = get_model_metadata(req.model_path) self.assigned_layers = req.layers @@ -368,7 +368,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Initialize weight cache - with self.tracer.frame("memory", "weight_cache.init") as f: + with self.tracer.frame("memory.weights", "cache.init") as f: f.set("node", self._instance_name) self.weight_cache = WeightCache( self.assigned_layers, @@ -382,7 +382,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Load the model - with self.tracer.frame("memory", "model.load") as f: + with self.tracer.frame("memory.model", "load") as f: f.set("node", self._instance_name) self.model = get_ring_model( self.model_metadata.model_type, @@ -407,7 +407,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse logger.warning("[QUANT] apply failed: %s", e) self.model.eval() - with self.tracer.frame("memory", "make_cache") as f: + with self.tracer.frame("memory.cache", "make_cache") as f: f.set("node", self._instance_name) self.cache = make_cache( 
self.model, @@ -448,7 +448,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - with self.tracer.frame("network", "connect.next_node") as f: + with self.tracer.frame("network.connect", "next_node") as f: f.set("node", self._instance_name) if self.next_node: await self._connect_next_node() @@ -621,12 +621,12 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return with self.tracer.frame("network.rx", "connect_next_node") as f: - f.set("nonce", request.nonce) + f.set("req_id", request.nonce) f.set("node", self._instance_name) await self._connect_next_node() with self.tracer.frame("network.rx", "process_activation") as f: - f.set("nonce", request.nonce) + f.set("req_id", request.nonce) try: activation = request.activation target_layer = activation.layer_id + 1 @@ -645,7 +645,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): t_alloc = time.perf_counter() if "|" in activation.dtype: with self.tracer.frame("grpc.receive", "decompress") as fr: - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) fr.set("node", self._instance_name) try: deq = decompress_tensor_from_protobuf_data( @@ -702,7 +702,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return with self.tracer.frame("network.rx", "alloc.buffer") as fr: - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) fr.set("node", self._instance_name) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, @@ -727,7 +727,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": with self.tracer.frame("network.rx", "token_stream") as fr: - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) fr.set("node", self._instance_name) try: tokens = np.frombuffer(request.activation.data, dtype=np.int32) @@ -757,7 +757,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: with self.tracer.frame("network.ex", "default") as fr: fr.set("node", self._instance_name) - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) # Safety: byte length must match shape*dtype try: expected = ( @@ -844,10 +844,9 @@ async def _ingress_worker(self): f.set("inwait", time.perf_counter() - req.rx_enq_t) f.set("inflight", req.rx_inflight_t) f.set("node", self._instance_name) - f.set("nonce", req.nonce) + f.set("req_id", req.nonce) try: - #with self.tracer.frame("network.ingress", "connect_next_node"): await self._connect_next_node() activation = req.activation @@ -862,7 +861,7 @@ async def _ingress_worker(self): logger.error(f"Unable to read length of data for {req.nonce}") payload_bytes = -1 - f.set("nonce", req.nonce) + fr.set("req_id", req.nonce) f.set("target", target_layer) f.set("payload_bytes", payload_bytes) f.event("received") @@ -904,8 +903,8 @@ async def _ingress_worker(self): # Enqueue for compute with self.tracer.frame("network.rx", "enque") as fr: + fr.set("req_id", req.nonce) fr.set("node", self._instance_name) - fr.set("nonce", req.nonce) while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) @@ -995,8 +994,8 @@ def _prepare_activation_message_blocking( if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool with 
self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1033,8 +1032,8 @@ def _prepare_activation_message_blocking( elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1064,8 +1063,8 @@ def _prepare_activation_message_blocking( else: # Dense path: validate size and copy raw bytes view into pool buffer with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) try: expected = ( int(np.prod(activation.shape)) @@ -1136,7 +1135,7 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats - f.set("nonce", activation_msg.nonce) + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) if activation_msg.ex_enq_t == 0.0: # FIXME float comparison f.set("inwait", 0.0) @@ -1601,8 +1600,8 @@ async def warm(request: Request) -> JSONResponse: try: # FIXME: Append warmup config? Or something to distinguish with self.tracer.frame("memory", "model.warm") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) body = await request.json() start = int(body.get("start", -1)) window = int(body.get("window", self.window_size)) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index fcb5f53a..2a9a22b0 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -76,7 +76,7 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s ) return data - with self.tracer.frame("weights.cache", "search") as f: + with self.tracer.frame("memory.weights", "cache.search") as f: with self.lock: if layer_id in self.cache: data, _ = self.cache[layer_id] @@ -100,7 +100,7 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s creator = False if creator: # Perform the blocking load without holding the cache lock - with self.tracer.frame("weights.cache", "load") as f: + with self.tracer.frame("memory.weights", "cache.load") as f: try: data = self.layer_manager.load_layer_to_gpu(layer_id) f.event("load") @@ -139,7 +139,7 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s return None else: # Not the creator: wait for the in-flight load to complete - with self.tracer.frame("weights.cache", "wait") as f: + with self.tracer.frame("memory.weights", "cache.wait") as f: try: inflight.result() # block until the creator completes except Exception as e: From ea079ca84275a1279d8b236611289b54c439b19f Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 04:00:02 -0700 Subject: [PATCH 058/172] better sorting --- src/dnet/perf/utils/aggregators.py | 258 ++++++++++++----------------- 1 file changed, 108 insertions(+), 150 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index b5753589..be627a0d 100644 --- a/src/dnet/perf/utils/aggregators.py +++ 
b/src/dnet/perf/utils/aggregators.py @@ -216,7 +216,7 @@ class ReqStats: model: str = "" # Model name tokenizer: str = "" # Tokenizer name run_id: str = "" # ID of session (for later mapping) - nonce: str = "" # List of serviced requests + req_id: str = "" # List of serviced requests ttft: float = 0.0 # Time to first token itl: List[float] = None # Inter-token latency per round prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) @@ -253,171 +253,129 @@ def __init__(self) -> None: self._max_inflight_req = 20 # per node FIXME: modify from repl # Frames are kept in here while in-flight, then remove the frame objects and append to _stats - self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id + self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per req_id, per node_id - self._nonces: List[str] = [] # Tracked nonces (either in-flight or done) - self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds - self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill + self._req: List[str] = [] # Tracked requests (in-flight or done) + self._req_round_finish: Dict[str, bool] = {} # Track in-flight requests + self._req_prefill: Dict[str, bool] = {} # Track if this request round is prefill + self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per request) + + # Staging environment for events that arrive before + # the request.start of the request they belong to + self._staging: Dict[str, Any] = {} + + # Finished and running requests self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames self._stats: Dict[str, ReqStats] = {} # Finished stat frames - self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) - self._model_per_run: Dict[str, str] = {} # Track model per run_id - - self.nodes = [] # Keep track of active nodes - - # Maps of frames to higher-level sub-systems - self._compute_set = [ - "compute.forward", - "compute.thread.kvcache.init", - "compute.thread.weights.prepare", - "compute.thread.activations.process", - "compute.thread.activations.load", - "compute.thread.execute", - "compute.thread.execute.enqueue_prefetch", - "compute.thread.execute.evict_and_unload", - "compute.thread.cleanup", - "compute.thread.mdns.send", - ] - self._network_set = [ - "network.tx", - "network.token_request", - "network.rx.prepare", - "network.rx.prepare_activation.tokens", - "network.rx.enque", - "network.send_activation.final", - "network.rx", - "network.connect.next_node", - "network.rx.prefetch", - ] - - self._memory_set = [ - "memory.model.load", - "memory.model.load_metadata", - "memory.warmup", - "memory.weight_cache.init", - "memory.prefetch", - "memory.memory_pools.init", - "memory.cache.reset", - "memory.make_cache", - ] + self.nodes = [] # Keep track of active nodes in the network # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: - run_id = data["run_id"] - node_id = data["node_id"] events = data["events"] or [] - - if not run_id or not node_id: - print("Dropped batch") - return # Drop the batch + if not events: return # Nothing to do with self._lock: - # Ensure we register workers and nodes - for i, ev in enumerate(events): - if "nonce" not in ev["args"]: ev["args"]["nonce"] = f"N_" - nonce = ev["args"]["nonce"] - - if node_id not in self._frames: - self._frames[node_id] = {} - - if nonce not in self._frames[node_id]: - self._frames[node_id][nonce] = {} - - if len(self._frames[node_id]) >= 
self._max_inflight_req: # remove oldest entry - del self._frames[self._nonces[0]] - del self._nonces[0] - - if nonce not in self._nonces: - self._nonces.append(nonce) # Update in-flight events or register new ones - for e in events: - nonce = e["args"]["nonce"] - assert nonce is not None, "" - - if not node_id or not nonce: return # Drop invalid frames - - if e["name"] == "chat.request.start": - self._open_frames[nonce] = {} - self._nonce_prefill[nonce] = True - self._running_stats[nonce] = ReqStats( - model=e["args"]["model"], - tokenizer=e["args"]["tokenizer"], - run_id=run_id, - nonce=nonce, - ttft= e["args"]["t0"], - itl=[ e["args"]["t0"], ], - prompt_tokens=e["args"]["prompt_tokens"], - total_tokens=e["args"]["prompt_tokens"], - latencies={}, - latency_per_layer={}, - latency_per_shard={}, - assignment=None, - compute_per_worker={}, - network_per_worker={}, - memory_per_worker={}, - ) - - # FIXME: We might receive other frames then "embed" from shards - # so we need to handle the creation of this better - if nonce not in self._running_stats: - continue - - stats = self._running_stats[nonce] - - if "node" not in e["args"]: - if e["name"] == "chat.request.end": - print(f"{e}") - st_obj = self._running_stats[nonce] - st_obj.generated_tokens = e["args"]["generated_tokens"] - st_obj.total_tokens += e["args"]["generated_tokens"] - print("Adding to stats") - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - - else: - continue # Drop frames without "node" - - node_id = e["args"]["node"] - if node_id not in self.nodes: - self.nodes.append(node_id) - stats.compute_per_worker[node_id] = 0.0 - stats.network_per_worker[node_id] = 0.0 - stats.memory_per_worker[node_id] = 0.0 - - if e["name"] in self._memory_set: - stats.memory_per_worker[node_id] += e["args"]["ms"] - - if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] - #TODO: change shard in metadata - - if e["name"] == "compute.forward": - try: - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution - self._handle_round(e, nonce, stats, _cost) - except Exception as e: - print(f"{e}") - - if e["name"] in self._compute_set: # Aggregate for compute total - stats.compute_per_worker[node_id] += e["args"]["ms"] - - if e["name"] in self._network_set: - stats.network_per_worker[node_id] += e["args"]["ms"] + for i, e in enumerate(events): + symbol = e["name"].split(".") + + req_id = e["args"].get("req_id") + if not req_id: + print(f"Dropping {e["name"]}: {e["args"]}") + continue # Drop anonymous frames + + if symbol[0] == "request": + if symbol[1] == "start": # Start request, setup buffers and ingest staged frames + self._req.append(req_id) + self._open_frames[req_id] = {} + self._req_prefill[req_id] = True + stats = ReqStats( + model=e["args"]["model"], + tokenizer=e["args"]["tokenizer"], + req_id=req_id, + ttft= e["args"]["t0"], + itl=[ e["args"]["t0"], ], + prompt_tokens=e["args"]["prompt_tokens"], + total_tokens=e["args"]["prompt_tokens"], + latencies={}, + latency_per_layer={}, + latency_per_shard={}, + assignment=None, + compute_per_worker={}, + network_per_worker={}, + memory_per_worker={}, + ) + self._running_stats[req_id] = stats + + # Process all events in staging + if req_id in self._staging: + for pe in self._staging[req_id]: + node_id = e["args"].get("node_id") + if not node_id: continue + 
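The staging branch being added here buffers frames that arrive before their request.start mark and drains them once the request is registered. The discipline can be shown as a small self-contained sketch (EventStager is a hypothetical name; the aggregator inlines this logic):

from typing import Dict, List

class EventStager:
    # Buffer events per req_id until the request is registered, then drain.
    def __init__(self) -> None:
        self.known: set = set()
        self.staged: Dict[str, List[dict]] = {}

    def offer(self, req_id: str, event: dict) -> List[dict]:
        # Returns the events that are ready to process (possibly none).
        if req_id in self.known:
            return [event]
        self.staged.setdefault(req_id, []).append(event)
        return []

    def register(self, req_id: str) -> List[dict]:
        # Called on request.start: mark the request known and drain backlog.
        self.known.add(req_id)
        return self.staged.pop(req_id, [])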
self._process_frame(pe, req_id, node_id, stats)
+
+                            del self._staging[req_id]
+                        continue
+
+                    elif symbol[1] == "end": # Finish processing request
+                        st_obj = self._running_stats[req_id]
+                        st_obj.generated_tokens = e["args"]["generated_tokens"]
+                        st_obj.total_tokens += e["args"]["generated_tokens"]
+                        self._stats[req_id] = st_obj
+                        del self._running_stats[req_id]
+                        # TODO: Handle latency of transfer back to API
+                        continue
+
+                if req_id in self._stats: continue # Already finished processing request
+                if req_id not in self._req: # If unknown request, stage frames
+                    if req_id not in self._staging:
+                        self._staging[req_id] = []
+                    self._staging[req_id].append(e)
+                    continue
+
+                node_id = e["args"].get("node_id")
+                if not node_id: return # Drop unknown node
+
+                stats = self._running_stats[req_id]
+                self._process_frame(e, req_id, node_id, stats)
+
+
+    def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats):
+        if node_id not in self.nodes:
+            self.nodes.append(node_id)
+            stats.compute_per_worker[node_id] = 0.0
+            stats.network_per_worker[node_id] = 0.0
+            stats.memory_per_worker[node_id] = 0.0
+
+        if symbol[0] == "compute":
+            if symbol[1] == "forward":
+                try:
+                    _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"]
+                    self._handle_round(e, req_id, stats, _cost) # compute queue + execution
+                    print(f"TTFT: {stats.ttft}")
+                except Exception as e:
+                    print(f"{e}")
+            stats.compute_per_worker[node_id] += e["args"]["ms"]
+
+        elif symbol[0] == "network":
+            if symbol[1] == "rx": # Time in transport, ingress queue and ingress_worker
+                _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"]
+                #TODO: change shard in metadata
+            stats.network_per_worker[node_id] += e["args"]["ms"]
+
+        elif symbol[0] == "memory":
+            stats.memory_per_worker[node_id] += e["args"]["ms"]
+        return
 
-        if e["name"] in self._memory_set:
-            stats.memory_per_worker[node_id] += e["args"]["ms"]
 
     # Handle cost aggregation of frames
-    def _handle_round(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any):
+    def _handle_round(self, e: Any, req_id, stats: ReqStats, _cost_fnc: Any):
         try:
-            if self._nonce_prefill[nonce]:
+            if self._req_prefill[req_id]:
                 stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0
-                self._nonce_prefill[nonce] = False
+                self._req_prefill[req_id] = False
             else:
                 if e["args"]["t0"] > 0.0:
                     stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1])
@@ -457,7 +415,7 @@ def stats(
             return
         stats = self._stats[list(self._stats.keys())[-1]]
         #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n")
-        sys.stdout.write(f"Performance stats for request '{stats.nonce}':\n\n")
+        sys.stdout.write(f"Performance stats for request '{stats.req_id}':\n\n")
         try:
             for tag, n, unit in fields:
                 if tag == 0: # Native counter

From 320dcfe8327f9e4115aef790b22d68e365a48333 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sun, 26 Oct 2025 11:56:39 -0700
Subject: [PATCH 059/172] use epoch for t0

---
 src/dnet/perf/trace.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
index 48f6e49b..4037d987 100644
--- a/src/dnet/perf/trace.py
+++ b/src/dnet/perf/trace.py
@@ -51,7 +51,7 @@ def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]])
     self.attrs = dict(attrs or {})
     self._t0 = 0.0
 def __enter__(self):
-    self._t0 = time.perf_counter()
+    self._t0 = time.time_ns() # cross-node timekeeping
     self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)})
     return self
 def __exit__(self, ex_type, ex, tb):
@@ -96,7
+96,7 @@ def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None: if flush and self._events: try: self._agg_q.put_nowait({ - "run_id": (self._req_id or "run"), + "req_id": (self._req_id or "run"), "node_id": (self.config.node_id or "node"), "events": list(self._events), }) except queue.Full: From 1d645c9c7c8d90abc909966b26accc919b5062b2 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 11:57:18 -0700 Subject: [PATCH 060/172] add stats nodes and fix node registration --- src/dnet/perf/utils/aggregators.py | 33 ++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index be627a0d..e64deb88 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -216,12 +216,13 @@ class ReqStats: model: str = "" # Model name tokenizer: str = "" # Tokenizer name run_id: str = "" # ID of session (for later mapping) - req_id: str = "" # List of serviced requests + req_id: str = "" # List of serviced requests ttft: float = 0.0 # Time to first token itl: List[float] = None # Inter-token latency per round prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) total_tokens: int = -1 # Total number of tokens processed + nodes: List[str] = None # Nodes that participated in computation latencies: List[List[str, str, str, int]] = None # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} @@ -258,7 +259,7 @@ def __init__(self) -> None: self._req: List[str] = [] # Tracked requests (in-flight or done) self._req_round_finish: Dict[str, bool] = {} # Track in-flight requests self._req_prefill: Dict[str, bool] = {} # Track if this request round is prefill - self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per request) + self._open_frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Staging environment for events that arrive before # the request.start of the request they belong to @@ -275,15 +276,24 @@ def add(self, data: Dict[str, Any]) -> None: events = data["events"] or [] if not events: return # Nothing to do + node_id = data.get("node_id") + if not node_id: return # Drop unknown node + with self._lock: # Update in-flight events or register new ones for i, e in enumerate(events): symbol = e["name"].split(".") + if e["type"] == 'B': + req_id = data.get("req_id") + if not req_id or not node_id: continue + self._open_frames[req_id][node_id][e["name"]] = e + continue + req_id = e["args"].get("req_id") if not req_id: - print(f"Dropping {e["name"]}: {e["args"]}") + #print(f"Dropping {e}") continue # Drop anonymous frames if symbol[0] == "request": @@ -306,6 +316,7 @@ def add(self, data: Dict[str, Any]) -> None: compute_per_worker={}, network_per_worker={}, memory_per_worker={}, + nodes=[], ) self._running_stats[req_id] = stats @@ -335,26 +346,27 @@ def add(self, data: Dict[str, Any]) -> None: self._staging[req_id].append(e) continue - node_id = e["args"].get("node_id") - if not node_id: return # Drop unknown node + #node_id = e["args"].get("node_id") + #if not node_id: return # Drop unknown node stats = self._running_stats[req_id] self._process_frame(e, req_id, node_id, stats) def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): - if node_id not in self.nodes: - self.nodes.append(node_id) + symbol = e["name"].split(".") + if 
node_id not in self.nodes: self.nodes.append(node_id) + if node_id not in stats.nodes: + stats.nodes.append(node_id) stats.compute_per_worker[node_id] = 0.0 stats.network_per_worker[node_id] = 0.0 stats.memory_per_worker[node_id] = 0.0 - + if symbol[0] == "compute": if symbol[1] == "forward": try: _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] self._handle_round(e, req_id, stats, _cost) # compute queue + execution - print(f"TTFT: {stats.ttft}") except Exception as e: print(f"{e}") stats.compute_per_worker[node_id] += e["args"]["ms"] @@ -375,11 +387,14 @@ def _handle_round(self, e: Any, req_id, stats: ReqStats, _cost_fnc: Any): try: if self._req_prefill[req_id]: stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 + print(f"TTFT: {stats.ttft}") self._req_prefill[req_id] = False else: if e["args"]["t0"] > 0.0: stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1]) + print(f"ITL: {e["args"]["t0"]} - {stats.itl[-1]}") stats.itl.append(e["args"]["t0"]) + print(f"ITL: {stats.itl[-1]}") except Exception as ex: print(f"{ex}") From 2a11f02307091fcfb612f540ba9327b8a510a68f Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 12:01:06 -0700 Subject: [PATCH 061/172] always register t0 --- src/dnet/perf/trace.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 4037d987..84e90a60 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -52,11 +52,12 @@ def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]) self._t0 = 0.0 def __enter__(self): self._t0 = time.time_ns() # cross-node timekeeping + self.attrs.update({"t0": self._t0}) self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) return self def __exit__(self, ex_type, ex, tb): - dt_ms = (time.perf_counter() - self._t0) * 1000.0 - self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex)}) + dt_ms = (time.time_ns() - self._t0) + self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex), "t0": time.time_ns()}) self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) return False def event(self, name: str, **attrs): From 7a09d9134e83b6f72ada9b2501401a57726e74e5 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 13:20:21 -0700 Subject: [PATCH 062/172] mark round --- src/dnet/ring/api/node.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 34cf8ae7..3afc126b 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -1319,6 +1319,8 @@ async def _handle_completion( ), # type: ignore arange(req.max_tokens or 0), ): + self.tracer.mark("request.round", {"req_id": nonce,"t0": time.time_ns()}) + if profile_enabled and t_first_token is None: t_first_token = time.perf_counter() detokenizer.add_token(token) From 2c2366080ac6f02a749575fbb2cce7b1858ff857 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:09:16 -0700 Subject: [PATCH 063/172] don't track queue wait --- src/dnet/ring/shard/comms.py | 2 +- src/dnet/ring/shard/node.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index bca9bdc6..11a53fe5 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -174,7 +174,7 @@ async def _send_worker(self): activation_msg = await self.activation_computed_queue.get() with self.tracer.frame("network", "tx") as f: if activation_msg.tx_enq_perf_t and self._profile: - f.set("inwait", time.perf_counter() - 
activation_msg.tx_enq_t) + f.set("inwait", (time.perf_counter() - activation_msg.tx_enq_t)*1000) f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) q_wait_ms = ( diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 740a613d..f2f985e5 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -1130,8 +1130,7 @@ def _compute_worker(self) -> None: while self.running: try: # Get activation from queue (blocks until available) - with self.tracer.frame("compute", "deque.wait"): - activation_msg = self.activation_recv_queue.get(timeout=1.0) + activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats From cc94dc95570662d1380f026671e67a264d234b42 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:10:00 -0700 Subject: [PATCH 064/172] break down tx.enque into correct compute-network frames --- src/dnet/ring/shard/compute.py | 77 +++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index f7554de0..67f035aa 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -395,11 +395,13 @@ def _process_activation(self, activation_msg: ActivationMessage): pass # Create and enqueue output message: either forward activations or finalize on end role - with self.tracer.frame("compute.thread", "grpc.send") as f: + with self.tracer.frame("network.tx", "send") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model + + nxt = last_layer + 1 + if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "sampling") as f: try: with self._mlx_lock: y = self.model.normalize(x_cast) @@ -492,33 +494,35 @@ def _process_activation(self, activation_msg: ActivationMessage): except Exception: output_msg.tx_enq_perf_t = 0.0 # Enqueue to appropriate asyncio TX queue from compute thread - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue - ) - fut = asyncio.run_coroutine_threadsafe( - target_q.put(output_msg), self._loop + with self.tracer.frame("network.tx", "enque") as f: + output_msg.tx_enq_t = time.perf_counter() + try: + if self._loop is not None: + target_q = ( + self.activation_token_queue + if output_msg.is_final + else self.activation_computed_queue + ) + fut = asyncio.run_coroutine_threadsafe( + target_q.put(output_msg), self._loop + ) + fut.result() + else: + raise RuntimeError("Event loop not available for TX queue") + except Exception as e: + logger.error( + "Failed to queue computed activation for sending: %s", e ) - fut.result() - else: - raise RuntimeError("Event loop not available for TX queue") - except Exception as e: - logger.error( - "Failed to queue computed activation for sending: %s", e - ) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) -<<<<<<< HEAD # Optional unload/evict after stage with self.tracer.frame("compute.thread", "cleanup"): + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) if self._mode != "sliding_fit": if self._defer_unload: -======= # Clean up input resources with 
self.tracer.frame("compute.thread", "cleanup") as f: f.set("req_id", activation_msg.nonce) @@ -527,15 +531,28 @@ def _process_activation(self, activation_msg: ActivationMessage): # After queuing TX, schedule prefetch and eviction in the background # to avoid stalling the handoff to the next shard. try: - self._prefetch_pause.set() + while len(self._recent_windows) > max(1, int(getattr(self, "_resident_windows", 2))): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) + except Exception: + pass + except Exception: + pass + if getattr(self, "_resident_windows", 2) <= 1: + try: + evicted = self.weight_cache.evict_layers(window_layers) + if hasattr(self.mode, "unload_layers"): + self.model.unload_layers(window_layers) except Exception: pass - next_window = self._next_local_layers(last_layer, self.window_size) - for nl in next_window: - self._prefetch_to_ram(nl) - self._enqueue_weight_prefetch(nl) - if getattr(self, "_defer_unload", False): ->>>>>>> 6c40e99 (reformat frames) try: while len(self._recent_windows) > max( 1, int(self._resident_windows) From d2a55c9c01a7e41e1d562da1ebb81c3ec1e44797 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:10:49 -0700 Subject: [PATCH 065/172] aggregate compound subsytem metrics correctly per node --- src/dnet/perf/utils/aggregators.py | 51 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index e64deb88..44ae1b60 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -223,6 +223,7 @@ class ReqStats: generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) total_tokens: int = -1 # Total number of tokens processed nodes: List[str] = None # Nodes that participated in computation + _rounds_t0: List[int] = None # Keep round times for post-processing latencies: List[List[str, str, str, int]] = None # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} @@ -305,8 +306,8 @@ def add(self, data: Dict[str, Any]) -> None: model=e["args"]["model"], tokenizer=e["args"]["tokenizer"], req_id=req_id, - ttft= e["args"]["t0"], - itl=[ e["args"]["t0"], ], + ttft=0.0, + itl=[], prompt_tokens=e["args"]["prompt_tokens"], total_tokens=e["args"]["prompt_tokens"], latencies={}, @@ -317,6 +318,7 @@ def add(self, data: Dict[str, Any]) -> None: network_per_worker={}, memory_per_worker={}, nodes=[], + _rounds_t0=[], ) self._running_stats[req_id] = stats @@ -339,6 +341,10 @@ def add(self, data: Dict[str, Any]) -> None: # TODO: Handle latency of transfer back to API continue + elif symbol[1] == "round": + stats = self._running_stats[req_id] + stats._rounds_t0.append(e["args"]["t0"]) + if req_id in self._stats: continue # Already finidhed processing request if req_id not in self._req: # If unknown request, stage frames if req_id not in self._staging: @@ -364,12 +370,11 @@ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): if symbol[0] == "compute": if symbol[1] == "forward": - try: - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] - self._handle_round(e, req_id, stats, _cost) # compute queue + execution - except Exception as e: - print(f"{e}") + pass + 
#_cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] + # Defer stats compute until after we sort the times (async is kil) stats.compute_per_worker[node_id] += e["args"]["ms"] + print(f"COMPUTE_PER_WORKER: {e["name"]} : {stats.compute_per_worker}") elif symbol[0] == "network": if symbol[1] == "rx": # Time in transport, ingress queue and ingress_worker @@ -381,22 +386,17 @@ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): stats.memory_per_worker[node_id] += e["args"]["ms"] return - - # Handle cost aggregation of frames - def _handle_round(self, e: Any, req_id, stats: ReqStats, _cost_fnc: Any): - try: - if self._req_prefill[req_id]: - stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 - print(f"TTFT: {stats.ttft}") - self._req_prefill[req_id] = False - else: - if e["args"]["t0"] > 0.0: - stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1]) - print(f"ITL: {e["args"]["t0"]} - {stats.itl[-1]}") - stats.itl.append(e["args"]["t0"]) - print(f"ITL: {stats.itl[-1]}") - except Exception as ex: - print(f"{ex}") + def _compute_round_stats(self, stats): + rounds = stats._rounds_t0 + #rounds.sort() + assert len(rounds) > 2, "Not enough data." + stats.ttft = (rounds[1] - rounds[0]) * 1e-6 + stats.itl.append(rounds[1]) + for i in range(1, len(rounds)): + stats.itl[-1] = (rounds[i] - rounds[i-1]) * 1e-6 + stats.itl.append(rounds[i]) + stats.itl = stats.itl[:-1] + print(stats.itl) # Return data for total, per req, worker or model (maybe add per layer too?) def stats( @@ -429,6 +429,7 @@ def stats( print("No tracked stats in memory. Track a request first.\n") return stats = self._stats[list(self._stats.keys())[-1]] + self._compute_round_stats(stats) #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n") sys.stdout.write(f"Performance stats for request '{stats.req_id}':\n\n") try: @@ -453,11 +454,11 @@ def stats( elif tag == 1: match n: case "tokens_per_second": - tps = [ 1 / rt for rt in stats.itl ] + tps = [ 1 / (rt/1000) for rt in stats.itl ] #median = statistics.median(tps) mean = sum(tps) / len(tps) sys.stdout.write(f"{mean:.2f}".rjust(20)+" tok/s".rjust(5)+" \ttokens_per_second") - sys.stdout.write(f"\t# {statistics.median(stats.itl):.3f} s/tok\n") + sys.stdout.write(f"\t# {statistics.median(stats.itl)/1000:.3f} s/tok\n") case "inter_token_latency": assert len(stats.itl) > 1, "Not enough trace frames" From f91293d3c6eaca0f691e8413642263bb768f2709 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:11:23 -0700 Subject: [PATCH 066/172] fix ms scaling --- src/dnet/perf/trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 84e90a60..8afcca47 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -56,7 +56,7 @@ def __enter__(self): self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) return self def __exit__(self, ex_type, ex, tb): - dt_ms = (time.time_ns() - self._t0) + dt_ms = (time.time_ns() - self._t0) * 1e-6 self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex), "t0": time.time_ns()}) self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) return False From 72b3b29488d56ad8785bbace9fc2395a8aacddde Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:37:34 -0700 Subject: [PATCH 067/172] rename old grpc frames to network --- src/dnet/ring/shard/comms.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 11a53fe5..1725fd4f 
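Taken together, the trace.py changes in patches 059, 061, and 066 leave the frame timer stamping wall-clock t0 values and millisecond durations. Condensed into one runnable sketch; timed_frame and emit are stand-ins for the Frame class and Tracer._emit, not the actual API:

import time
from contextlib import contextmanager

@contextmanager
def timed_frame(emit, name, **attrs):
    # Epoch nanoseconds are comparable across hosts (wall clock), unlike
    # perf_counter(), whose origin is per-process; scale ns -> ms with 1e-6.
    t0 = time.time_ns()
    emit({"type": "B", "name": name, "args": {**attrs, "t0": t0}})
    try:
        yield
    finally:
        dt_ms = (time.time_ns() - t0) * 1e-6
        emit({"type": "E", "name": name,
              "args": {**attrs, "t0": time.time_ns(), "ms": round(dt_ms, 3)}})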
100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -232,7 +232,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
         try:
             logger.debug(f"Sending activation")
             if activation_msg.is_final:
-                with self.tracer.frame("grpc", "send_activation.final") as f:
+                with self.tracer.frame("network", "send_activation.final") as f:
                     f.set("req_id", activation_msg.nonce)
                     f.set("node", self._instance_name)
                     try:
@@ -271,7 +271,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                         self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( self.api_channel)
                         f.event("reset_api")
-                    with self.tracer.frame("grpc", "token_request") as fr:
+                    with self.tracer.frame("network", "token_request") as fr:
                         fr.set("req_id", activation_msg.nonce)
                         fr.set("node", self._instance_name)
                         try:
@@ -301,13 +301,16 @@ async def _send_activation(self, activation_msg: ActivationMessage):
         # FIXME: shaped var is a bit weird (is it np_array or mlx_array), @andthattoo shall check
         shaped = activation_msg.tensor
-        if shaped is None:
-            output_buffer = self.output_pool.get_buffer(activation_msg.pool_id)
-            if output_buffer is None:
-                logger.error(
-                    "Failed to get output buffer %s", activation_msg.pool_id
-                )
-                return
+        with self.tracer.frame("grpc.send_activations.default", "get_buffer") as fr:
+            fr.set("req_id", activation_msg.nonce)
+            fr.set("node", self._instance_name)
+            if shaped is None:
+                output_buffer = self.output_pool.get_buffer(activation_msg.pool_id)
+                if output_buffer is None:
+                    logger.error(
+                        "Failed to get output buffer %s", activation_msg.pool_id
+                    )
+                    return

         if self._profile:
             logger.info(

From 465d398c3460d7d411730009bf494290b783c223 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sun, 26 Oct 2025 17:38:09 -0700
Subject: [PATCH 068/172] correctly aggregate global memory use per node

---
 src/dnet/perf/utils/aggregators.py | 31 ++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py
index 44ae1b60..23d8c5ee 100644
--- a/src/dnet/perf/utils/aggregators.py
+++ b/src/dnet/perf/utils/aggregators.py
@@ -261,6 +261,7 @@ def __init__(self) -> None:
         self._req_round_finish: Dict[str, bool] = {} # Track in-flight requests
         self._req_prefill: Dict[str, bool] = {} # Track if this request round is prefill
         self._open_frames: Dict[str, Dict[str, Dict[str, Any]]] = {}
+        self._global_memory_per_worker: Dict[str, float] = {}

         # Staging environment for events that arrive before
         # the request.start of the request they belong to
@@ -294,7 +295,10 @@ def add(self, data: Dict[str, Any]) -> None:

                 req_id = e["args"].get("req_id")
                 if not req_id:
-                    #print(f"Dropping {e}")
+                    if symbol[0] == "memory": # Global memory frames are not request-based
+                        if node_id not in self._global_memory_per_worker:
+                            self._global_memory_per_worker[node_id] = 0.0
+                        self._global_memory_per_worker[node_id] += e["args"]["ms"]
                     continue # Drop anonymous frames

                 if symbol[0] == "request":
@@ -374,7 +378,6 @@ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats):
                 #_cost = lambda e: e["args"]["inwait"] + e["args"]["ms"]
                 # Defer stats compute until after we sort the times (async is kil)
                 stats.compute_per_worker[node_id] += e["args"]["ms"]
-                print(f"COMPUTE_PER_WORKER: {e["name"]} : {stats.compute_per_worker}")

         elif symbol[0] == "network":
             if symbol[1] == "rx": # Time in transport, ingress queue and ingress_worker
@@ -383,7 +386,10 @@ def _process_frame(self, e: Any, req_id: str, node_id: str,
stats: ReqStats): stats.network_per_worker[node_id] += e["args"]["ms"] elif symbol[0] == "memory": + print(f"MEMORY_PER_WORKER: {e["name"]} : {stats.memory_per_worker}") stats.memory_per_worker[node_id] += e["args"]["ms"] + else: + print(f"UNTRACKED: {e["name"]}") return def _compute_round_stats(self, stats): @@ -477,14 +483,19 @@ def stats( pass for i, n in enumerate(self.nodes): - comp = stats.compute_per_worker[n] - net = stats.network_per_worker[n] - mem = stats.memory_per_worker[n] - total = comp + net + mem - sys.stdout.write(f"\n node{i} [{n}]\n") - sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") - sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") - sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + try: + comp = stats.compute_per_worker[n] + net = stats.network_per_worker[n] + req_mem = stats.memory_per_worker[n] + g_mem = self._global_memory_per_worker[n] + mem = req_mem + g_mem + total = comp + net + mem + sys.stdout.write(f"\n node{i} [{n}]\n") + sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") + sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") + sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + except Exception as e: + print(f"{e}") except Exception as e: logger.error(f"{e}") From eb80eabe43492098395490ee2eac48bf31c96e63 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:41:31 -0700 Subject: [PATCH 069/172] request.round continue --- src/dnet/perf/utils/aggregators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 23d8c5ee..a75518a1 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -348,6 +348,7 @@ def add(self, data: Dict[str, Any]) -> None: elif symbol[1] == "round": stats = self._running_stats[req_id] stats._rounds_t0.append(e["args"]["t0"]) + continue if req_id in self._stats: continue # Already finidhed processing request if req_id not in self._req: # If unknown request, stage frames From c8f4ed63519feeb374265c9ebda62d305bc9784f Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 21:04:09 -0700 Subject: [PATCH 070/172] update signature and unload --- src/dnet/ring/model/llama3.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index 81626d51..978a7fe5 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -18,7 +18,8 @@ def __init__( self, model_config: Any, assigned_layers: Optional[List[int]] = [], - is_api_layer: bool = False + is_api_layer: bool = False, + shard_config: Optional[Any] = None, ): super().__init__() @@ -154,7 +155,6 @@ def apply_single_layer( layer = self.layers[local_idx] ret = self.layers[local_idx](x, mask, cache[local_idx] if local_idx < len(cache) else None) - logger.debug(f"Executed layer:{layer_idx} with output shape: {ret.shape}") return ret def load_weights(self, weights, strict=False): @@ -199,3 +199,15 @@ def load_weights(self, weights, strict=False): logger.error(f"Failed to load weights: {e}") logger.error(f"Weight keys: {list(shard_weights.keys())}") raise + + def unload_layers(self, layers: List[int]): + for l in layers: + local = self.abs_to_local[l] 
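The per-node breakdown printed by stats() above folds each node's request-scoped memory time together with its global (model-load and cache) memory time before computing percentage shares. The same arithmetic as a zero-guarded sketch; node_time_shares is a hypothetical helper, not part of the patch:

def node_time_shares(compute_ms, network_ms, request_mem_ms, global_mem_ms):
    # Share of the node's tracked wall time per subsystem, in percent.
    mem_ms = request_mem_ms + global_mem_ms
    total = compute_ms + network_ms + mem_ms
    if total <= 0.0:
        return {"compute": 0.0, "network": 0.0, "memory": 0.0}
    return {
        "compute": 100.0 * compute_ms / total,
        "network": 100.0 * network_ms / total,
        "memory": 100.0 * mem_ms / total,
    }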
+ for name, mod in self.layers[local].named_modules(): + if name in ['self_attn', 'mlp']: + for pname in mod.parameters(): + setattr(mod, pname, None) + logger.debug(f"Unloaded {pname}") + elif name in ['input_layernorm', 'post_attention_layernorm']: + mod.weight = None + logger.debug(f"Unloaded {name}") From 92e11e4001df690bf29ba5982cf536b6d43c94d9 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 21:41:40 -0700 Subject: [PATCH 071/172] force quantization field in model config (mlx_lm doesn't have it) --- src/dnet/ring/model/llama3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index 978a7fe5..0ee26280 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -27,6 +27,7 @@ def __init__( raise RuntimeError(f"API Service doesn't handle layers") self.config = ModelArgs.from_dict(model_config) + self.config.quantization = model_config["quantization"] # lmao self.is_api_layer = is_api_layer self._converted_to_quantized = False @@ -87,13 +88,15 @@ def lm_project(self, x: mx.array): def quantize_layers(self): self.quantization = None + logger.debug(f"{self.config}") if hasattr(self.config, "quantization"): self.quantization = getattr(self.config, "quantization") elif hasattr(self.config, "quantization_config"): self.quantization = getattr(self.config, "quantization_config") + logger.debug(f"QUANTIZING {self.quantization}") if self.quantization is not None: - bits = int(self.quantization.get("bits", 8)) + bits = int(self.quantization.get("bits", 4)) group = int(self.quantization.get("group_size", 64)) try: from mlx.nn.layers.quantized import QuantizedEmbedding From 6dc05bb048bac2ea1b1d551022ce6c73883741be Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 00:50:29 -0700 Subject: [PATCH 072/172] add ShardConfig to __init__ --- src/dnet/ring/shard/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dnet/ring/shard/__init__.py b/src/dnet/ring/shard/__init__.py index ddc69862..00b76b81 100644 --- a/src/dnet/ring/shard/__init__.py +++ b/src/dnet/ring/shard/__init__.py @@ -2,5 +2,6 @@ from .node import RingShardNode from .servicer import ShardServicer +from .config import ShardConfig __all__ = ["RingShardNode", "ShardServicer"] From 255ea0ae00a1d69f20a6f7543b5916972ee166c0 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 01:15:30 -0700 Subject: [PATCH 073/172] rm old bench framework --- src/dnet/perf/bench.py | 144 ----------------------------------------- 1 file changed, 144 deletions(-) delete mode 100644 src/dnet/perf/bench.py diff --git a/src/dnet/perf/bench.py b/src/dnet/perf/bench.py deleted file mode 100644 index 0cdbadd2..00000000 --- a/src/dnet/perf/bench.py +++ /dev/null @@ -1,144 +0,0 @@ - -from __future__ import annotations - -import json -import os -import statistics -import time -from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional - -from dnet.perf.trace import Tracer - - -def _percentile(xs: List[float], q: float) -> float: - if not xs: - return 0.0 - ys = sorted(xs) - k = int(round(q * (len(ys) - 1))) - k = max(0, min(k, len(ys) - 1)) - return ys[k] - -def collect_stats(times_ms: List[float], *, bytes_total: float = 0.0, tokens_total: float = 0.0) -> Dict[str, Any]: - if not times_ms: - return { - "mean": 0.0, - "std": 0.0, - "min": 0.0, - "p50": 0.0, - "p90": 0.0, - "p99": 0.0, - "max": 0.0, - "samples": 0, - "mb_s": 0.0, - "tok_s": 0.0, - } - total_ms = sum(times_ms) - mean = 
total_ms / len(times_ms) - std = statistics.pstdev(times_ms) if len(times_ms) > 1 else 0.0 - total_s = max(total_ms / 1000.0, 1e-12) - return { - "mean": mean, - "std": std, - "min": min(times_ms), - "p50": _percentile(times_ms, 0.5), - "p90": _percentile(times_ms, 0.9), - "p99": _percentile(times_ms, 0.99), - "max": max(times_ms), - "samples": len(times_ms), - "mb_per_s": (bytes_total / 1_000_000.0) / total_s if bytes_total else 0.0, - "tokens_per_s": (tokens_total / total_s) if tokens_total else 0.0, - } - - -def _ensure_dir(path: str) -> None: - d = os.path.dirname(path) or "." - os.makedirs(d, exist_ok=True) - - -@dataclass -class BenchCounters: - values: Dict[str, float] = field(default_factory=dict) - - def add_time(self, key: str, dt_ms: float) -> None: - self.values[key] = self.values.get(key, 0.0) + float(dt_ms) - - def add_bytes(self, *, direction: str, n: int) -> None: - k = "bytes_in" if direction == "in" else "bytes_out" - self.values[k] = self.values.get(k, 0.0) + float(n) - - def inc(self, key: str, delta: float = 1.0) -> None: - self.values[key] = self.values.get(key, 0.0) + float(delta) - - def snapshot(self, *, run_id: str, node: str, role: str = "shard") -> Dict[str, Any]: - snap = { - "run_id": run_id, - "node": node, - "role": role, - "counters": dict(self.values), - } - return snap - - -class TimedSpan: - __slots__ = ("_tracer", "_name", "_attrs", "_t0", "_frame", "_counters", "_counter_key") - - def __init__( - self, - tracer: Optional[Tracer], - name: str, - counters: Optional[BenchCounters] = None, - counter_key: Optional[str] = None, - attrs: Optional[Dict[str, Any]] = None, - ) -> None: - self._tracer = tracer - self._name = name - self._attrs = attrs or {} - self._t0 = 0.0 - self._frame = None - self._counters = counters - self._counter_key = counter_key - - def __enter__(self): - self._t0 = time.perf_counter() - if self._tracer is not None: - self._frame = self._tracer.frame("bench", self._name, self._attrs) - self._frame.__enter__() - return self - - def __exit__(self, ex_type, ex, tb) -> bool: - dt_ms = (time.perf_counter() - self._t0) * 1000.0 - if self._frame is not None: - try: - self._frame.__exit__(ex_type, ex, tb) - except Exception: - pass - if self._counters is not None and self._counter_key: - self._counters.add_time(self._counter_key, dt_ms) - return False - - -def aggregate_annotate( - snapshots: Iterable[Dict[str, Any]], - *, - mapping: Optional[Dict[str, str]] = None, - repeats: int = 0, -) -> List[Dict[str, Any]]: - - sums: Dict[str, float] = {} - for snap in snapshots: - ctr = snap.get("counters") if isinstance(snap, dict) else None - if not isinstance(ctr, dict): - continue - for k, v in ctr.items(): - name = mapping.get(k, k) if mapping else k - try: - sums[name] = sums.get(name, 0.0) + float(v) - except Exception: - continue - - rows = [ {"name": name, "self_ms": val, "total_ms": val, "count": repeats or 0, "max_ms": None} - for name, val in sums.items() if val > 0.0] - rows.sort(key=lambda r: r["self_ms"], reverse=True) - return rows - From 68fb93923124d82b2576d4374e43d3b52cec28c3 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 01:22:50 -0700 Subject: [PATCH 074/172] remove old memory frame path --- src/dnet/perf/utils/aggregators.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index a75518a1..86e07f41 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -386,11 +386,6 @@ def _process_frame(self, e: Any, 
req_id: str, node_id: str, stats: ReqStats): #TODO: change shard in metadata stats.network_per_worker[node_id] += e["args"]["ms"] - elif symbol[0] == "memory": - print(f"MEMORY_PER_WORKER: {e["name"]} : {stats.memory_per_worker}") - stats.memory_per_worker[node_id] += e["args"]["ms"] - else: - print(f"UNTRACKED: {e["name"]}") return def _compute_round_stats(self, stats): @@ -403,7 +398,6 @@ def _compute_round_stats(self, stats): stats.itl[-1] = (rounds[i] - rounds[i-1]) * 1e-6 stats.itl.append(rounds[i]) stats.itl = stats.itl[:-1] - print(stats.itl) # Return data for total, per req, worker or model (maybe add per layer too?) def stats( From 313312a92803a9ad0a903b0423c41b60b40af9cb Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 02:08:31 -0700 Subject: [PATCH 075/172] not-working chat interface --- src/repl.py | 58 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/src/repl.py b/src/repl.py index 231cecb1..f62d68e5 100644 --- a/src/repl.py +++ b/src/repl.py @@ -140,7 +140,8 @@ def loop(self): # Main tty loop self.print_mdns_nodes() continue elif cmd.startswith("load"): - self.load_model() + model = "mlx-community/llama-3.3-70b-instruct-4bit" + self.load_model(model) continue elif cmd.startswith(("trace", ".trace")): self.do_trace(cmd.split(" ")) @@ -212,7 +213,8 @@ def do_topo(self, cmd: List[str]) -> None: self.print_mdns_nodes() pass elif cmd[1] in ("auto", "build", "b"): - self.prepare_topo() + model = "mlx-community/llama-3.3-70b-instruct-4bit" + self.prepare_topo(model) pass elif cmd[1] == "setup": pass @@ -520,8 +522,11 @@ async def _await_then_set(): f.set_result(ret) except BaseException as e: f.set_exception(e) - self._api_loop.call_soon_threadsafe(runner) - return f.result(timeout) + try: + self._api_loop.call_soon_threadsafe(runner) + return f.result(timeout) + except Exception as e: + raise # ------- Trace aggregation helpers @@ -709,24 +714,55 @@ def print_topo(self, topo): sys.stdout.write(f"Devices: {topo.devices}\n\n") # TODO: Better print here - def prepare_topo(self): - req = PrepareTopologyRequest(model="Qwen/Qwen3-4B-MLX-4bit") + def prepare_topo(self, model): + req = PrepareTopologyRequest(model=model) try: - topo = self.api_call("_handle_prepare_topology", req, timeout=30) + topo = self.api_call("_handle_prepare_topology", req, timeout=120) except Exception as e: dprint(f"Unable to create topology: {e}\n\n") - return + return False self.state.topo = topo self.print_topo(topo) + return True - def load_model(self): - req = APILoadModelRequest(model="Qwen/Qwen3-4B-MLX-4bit") + def load_model(self, model): + req = APILoadModelRequest(model=model) try: res = self.api_call("_handle_load_model", req, timeout=30) + return True except Exception as e: dprint(f"Failed to load model: {e}\n\n") - return + return False + # ===== Handle chat + + def do_chat(self, cmd): + model = "mlx-community/llama-3.3-70b-instruct-4bit" + if len(cmd) < 2: + if not self.state.model or self.state.model == "": + self.prompt_model() + if not self.state.topology: + if not self._prepare_topo(self.state.model): + raise RuntimeError("Unable to create topology.") + if not self.load_model(self.state.model): + raise RuntimeError("Unable to load model.") + + while True: + prompt = input("\n> ") + prompt = self.format_prompt(prompt) + messages = prompt + req = ChatRequest( + messages=messages, + max_tokens=100, + temperature=0.7, + stream=True, + ) + + self.api_call("_handle_completion", req) + + # Start default chat with selected 
model + pass + pass # ===== Handle shutdown From ac203d1b4bfd6d67ea13260f590403c2347380b1 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 18:25:17 -0800 Subject: [PATCH 076/172] fix indent after rebase --- src/dnet/ring/shard/compute.py | 242 +++++++++------------ src/dnet/ring/shard/node.py | 372 +++++++++++++++++---------------- 2 files changed, 285 insertions(+), 329 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 67f035aa..e0cc1a64 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -276,7 +276,6 @@ def _process_activation(self, activation_msg: ActivationMessage): window_layers, (t_comp_done - t_comp) * 1000.0, ) - """ for lid in window_layers: #self.weight_cache.decrease_reference(lid) @@ -338,37 +337,37 @@ def _process_activation(self, activation_msg: ActivationMessage): self._bound_versions.pop(lid, None) except Exception: pass - else: - if not self._defer_unload: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers( - old - ) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: + else: + if not self._defer_unload: + while len(self._recent_windows) > max( + 1, int(self._resident_windows) + ): + old = self._recent_windows.pop(0) try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, + evicted_cnt = self.weight_cache.evict_layers( + old ) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) except Exception: pass + if self._profile: + try: + logger.info( + "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", + self.node_id, + activation_msg.nonce, + old, + evicted_cnt, + self._resident_windows, + ) + except Exception: + pass except Exception: pass @@ -434,48 +433,6 @@ def _process_activation(self, activation_msg: ActivationMessage): is_final=True, token_id=token_id, ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] - else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - is_final=True, - 
token_id=token_id, - ) else: output_msg = ActivationMessage( nonce=activation_msg.nonce, @@ -490,69 +447,75 @@ def _process_activation(self, activation_msg: ActivationMessage): tensor=x_cast, ) try: - output_msg.tx_enq_perf_t = time.perf_counter() - except Exception: - output_msg.tx_enq_perf_t = 0.0 - # Enqueue to appropriate asyncio TX queue from compute thread - with self.tracer.frame("network.tx", "enque") as f: - output_msg.tx_enq_t = time.perf_counter() - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue - ) - fut = asyncio.run_coroutine_threadsafe( - target_q.put(output_msg), self._loop - ) - fut.result() - else: - raise RuntimeError("Event loop not available for TX queue") - except Exception as e: - logger.error( - "Failed to queue computed activation for sending: %s", e + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + is_final=True, + token_id=token_id, + ) + else: + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, + ) + try: + output_msg.tx_enq_perf_t = time.perf_counter() + except Exception: + output_msg.tx_enq_perf_t = 0.0 + # Enqueue to appropriate asyncio TX queue from compute thread + with self.tracer.frame("network.tx", "enque") as f: + output_msg.tx_enq_t = time.perf_counter() + try: + if self._loop is not None: + target_q = ( + self.activation_token_queue + if output_msg.is_final + else self.activation_computed_queue ) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) - # Optional unload/evict after stage - with self.tracer.frame("compute.thread", "cleanup"): - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - if self._mode != "sliding_fit": - if self._defer_unload: - # Clean up input resources - with self.tracer.frame("compute.thread", "cleanup") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - self.input_pool.release(activation_msg.pool_id) - # After queuing TX, schedule prefetch and eviction in the background - # to avoid stalling the handoff to the next shard. 
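# A minimal, self-contained sketch of the background eviction step described
# above, assuming `evict` and `unload` stand in for weight_cache.evict_layers
# and model.unload_layers (hypothetical callables, not the real API):
from collections import deque
from typing import Callable, Deque, List

def evict_old_windows(
    recent: Deque[List[int]],          # FIFO of recently executed layer windows
    resident_budget: int,              # how many windows may stay resident
    evict: Callable[[List[int]], int],
    unload: Callable[[List[int]], None],
) -> None:
    # Keep at most `resident_budget` windows resident; drop the oldest first so
    # the layers most likely to be reused (the newest window) stay bound.
    while len(recent) > max(1, resident_budget):
        old = recent.popleft()
        evict(old)    # release the cached weight arrays for these layers
        unload(old)   # unbind from the live model; they rebind on next use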
- try: - while len(self._recent_windows) > max(1, int(getattr(self, "_resident_windows", 2))): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - except Exception: - pass - if getattr(self, "_resident_windows", 2) <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.mode, "unload_layers"): - self.model.unload_layers(window_layers) - except Exception: - pass + # Optional unload/evict after stage + with self.tracer.frame("compute.thread", "cleanup"): + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + if self._mode != "sliding_fit": + if self._defer_unload: + # Clean up input resources + with self.tracer.frame("compute.thread", "cleanup") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) try: while len(self._recent_windows) > max( 1, int(self._resident_windows) @@ -569,23 +532,14 @@ def _process_activation(self, activation_msg: ActivationMessage): self._bound_versions.pop(lid, None) except Exception: pass - if self._profile: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) except Exception: pass - if self._resident_windows <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] + if self._resident_windows <= 1: + try: + evicted = self.weight_cache.evict_layers(window_layers) + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(window_layers) # type: ignore[attr-defined] if self._profile: logger.info( "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index f2f985e5..de82cc63 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -610,7 +610,9 @@ async def reset_cache(self) -> None: kv_bits=self.config.kv_cache.bits, kv_group=self.config.kv_cache.group_size, ) - logger.info("Node %s: Cache reset successfully", self.node_id) + logger.info("Node %s: Cache reset successfully", self.node_id) + except Exception as e: + logger.error("Node %s: Error resetting cache: %s", self.node_id, e) # FIXME This seems to still be dead code @@ -630,148 +632,84 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): try: activation = request.activation target_layer = activation.layer_id + 1 + + # Detect new sequence per node: initialize per-nonce KV + if request.nonce != self._active_nonce: + self._active_nonce = request.nonce + try: + payload_bytes = len(activation.data) + except Exception: + payload_bytes = -1 + f.event("process_payload") - # Detect new sequence per node: initialize per-nonce KV - if request.nonce != self._active_nonce: - self._active_nonce = request.nonce - try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - f.event("process_payload") - - if target_layer in self._assigned_set: - # Allocate input pool and copy payload (with optional decompression) - t_alloc = time.perf_counter() - if "|" in activation.dtype: - with self.tracer.frame("grpc.receive", "decompress") as fr: - 
fr.set("req_id", request.nonce) - fr.set("node", self._instance_name) - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return - - with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs - if activation.dtype == "tokens": - with self.tracer.frame("grpc.receive", "token_stream") as fr: + if target_layer in self._assigned_set: + # Allocate input pool and copy payload (with optional decompression) + t_alloc = time.perf_counter() + if "|" in activation.dtype: + with self.tracer.frame("grpc.receive", "decompress") as fr: + fr.set("req_id", request.nonce) + fr.set("node", self._instance_name) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, shape=list(activation.shape), - dtype_with_metadata=activation.dtype) + dtype_with_metadata=activation.dtype, + ) except Exception as e: - logger.error("Decompression failed for nonce %s: %s", request.nonce, e) + logger.error( + "Decompression failed for nonce %s: %s", request.nonce, e + ) return - with self.tracer.frame("network.rx", "alloc.buffer") as fr: - fr.set("req_id", request.nonce) - fr.set("node", self._instance_name) + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape))) - + shape=cast(tuple[int, ...], tuple(deq.shape)), + ) if pool_id is None: - logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) return - buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: flat = deq.reshape(-1) buffer[: flat.size] = flat + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", + self.node_id, + request.nonce, + alloc_copy_ms, + ) # Update activation message with true dtype/shape new_dtype_str = str(deq.dtype) activation_msg = ActivationMessage.from_proto(request, pool_id) activation_msg.dtype = new_dtype_str activation_msg.shape = tuple(deq.shape) - - else: # Special token stream support: dtype='tokens' carries int32 token IDs + else: + # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": - with self.tracer.frame("network.rx", "token_stream") as fr: - fr.set("req_id", request.nonce) - 
fr.set("node", self._instance_name) + with self.tracer.frame("grpc.receive", "token_stream") as fr: try: - tokens = np.frombuffer(request.activation.data, dtype=np.int32) - shp = (int(len(tokens)), ) + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype) except Exception as e: - logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e,) - return - - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp)) - - if pool_id is None: - logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + logger.error("Decompression failed for nonce %s: %s", request.nonce, e) return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - activation_msg = ActivationMessage.from_proto(request, pool_id) - - # Ensure dtype reflects token payload for compute path - activation_msg.dtype = "tokens" - activation_msg.shape = shp - - else: - with self.tracer.frame("network.ex", "default") as fr: - fr.set("node", self._instance_name) + with self.tracer.frame("network.rx", "alloc.buffer") as fr: fr.set("req_id", request.nonce) - # Safety: byte length must match shape*dtype - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize - ) - actual = len(request.activation.data) - except Exception: - pass - + fr.set("node", self._instance_name) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape)) + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape))) if pool_id is None: logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) @@ -779,26 +717,90 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: - data = request.activation.data - input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) - buffer[: len(input_data)] = input_data + flat = deq.reshape(-1) + buffer[: flat.size] = flat + + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + + else: # Special token stream support: dtype='tokens' carries int32 token IDs + if activation.dtype == "tokens": + with self.tracer.frame("network.rx", "token_stream") as fr: + fr.set("req_id", request.nonce) + fr.set("node", self._instance_name) + try: + tokens = np.frombuffer(request.activation.data, dtype=np.int32) + shp = (int(len(tokens)), ) + except Exception as e: + logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e,) + return + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp)) + + if pool_id is None: + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + return + + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + activation_msg = ActivationMessage.from_proto(request, pool_id) + + # Ensure dtype reflects token payload for compute path + activation_msg.dtype = "tokens" + activation_msg.shape = shp - activation_msg = ActivationMessage.from_proto(request, pool_id) 
- activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - - # Queue for processing — non-blocking back-off loop (cancellable) - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - activatino_msg.ex_enq_t = time.perf_counter() - logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) - break - except Full: - await asyncio.sleep(0) - else: - logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce) - self.input_pool.release(pool_id) + else: + with self.tracer.frame("network.ex", "default") as fr: + fr.set("node", self._instance_name) + fr.set("req_id", request.nonce) + # Safety: byte length must match shape*dtype + try: + expected = ( + int(np.prod(activation.shape)) + * np.dtype(dtype_map[activation.dtype]).itemsize + ) + actual = len(request.activation.data) + except Exception: + pass + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mlx_dtype_map[activation.dtype], + shape=cast(tuple[int, ...], activation.shape)) + + if pool_id is None: + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + return + + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + data = request.activation.data + input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) + buffer[: len(input_data)] = input_data + + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + + # Queue for processing — non-blocking back-off loop (cancellable) + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + activatino_msg.ex_enq_t = time.perf_counter() + logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) + break + except Full: + await asyncio.sleep(0) + else: + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce) + self.input_pool.release(pool_id) else: # Forward to next node (not our layer) logger.debug("Forwarding activation (layer %s) to next node, nonce: %s", target_layer, request.nonce) @@ -852,43 +854,26 @@ async def _ingress_worker(self): activation = req.activation target_layer = activation.layer_id + 1 - # Detect new sequence per node: initialize per-nonce KV - if req.nonce != self._active_nonce: - self._active_nonce = req.nonce - try: - payload_bytes = len(activation.data) - except Exception: - logger.error(f"Unable to read length of data for {req.nonce}") - payload_bytes = -1 + # Detect new sequence per node: initialize per-nonce KV + if req.nonce != self._active_nonce: + self._active_nonce = req.nonce + try: + payload_bytes = len(activation.data) + except Exception: + logger.error(f"Unable to read length of data for {req.nonce}") + payload_bytes = -1 - fr.set("req_id", req.nonce) - f.set("target", target_layer) - f.set("payload_bytes", payload_bytes) - f.event("received") + fr.set("req_id", req.nonce) + f.set("target", target_layer) + f.set("payload_bytes", payload_bytes) + f.event("received") - if target_layer in self._assigned_set: - # Heavy prep in executor (alloc/copy/decompress) - with self.tracer.frame("grpc.ingress", "prepare") as fr: - fr.set("node", self._instance_name) - fr.set("nonce", req.nonce) - loop = asyncio.get_running_loop() - try: - activation_msg = await loop.run_in_executor( - self.executor, - self._prepare_activation_message_blocking, - req, - ) - except Exception as e: - logger.error("Activation prepare failed 
for nonce %s: %s", req.nonce, e) - continue - if activation_msg is None: - continue - if self._profile: - activation_msg.recv_perf_t = t_recv - - # Enqueue for compute (cancellable back-off) - with self.tracer.frame("grpc.ingress", "queue") as fr: - while self.running: + if target_layer in self._assigned_set: + # Heavy prep in executor (alloc/copy/decompress) + with self.tracer.frame("network.ingress", "prepare") as fr: + fr.set("node", self._instance_name) + fr.set("nonce", req.nonce) + loop = asyncio.get_running_loop() try: activation_msg = await loop.run_in_executor( self.executor, @@ -900,29 +885,46 @@ async def _ingress_worker(self): continue if activation_msg is None: continue + if self._profile: + activation_msg.recv_perf_t = t_recv - # Enqueue for compute - with self.tracer.frame("network.rx", "enque") as fr: - fr.set("req_id", req.nonce) - fr.set("node", self._instance_name) + # Enqueue for compute (cancellable back-off) + with self.tracer.frame("network.ingress", "queue") as fr: while self.running: try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", - activation_msg.nonce, + activation_msg = await loop.run_in_executor( + self.executor, + self._prepare_activation_message_blocking, + req, ) - break - except Full: - await asyncio.sleep(0) - else: - logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) - try: - if self.input_pool: - # FIXME: !!! - self.input_pool.release(activation_msg.pool_id) - except Exception: - pass + except Exception as e: + logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) + continue + if activation_msg is None: + continue + + # Enqueue for compute + with self.tracer.frame("network.rx", "enque") as fr: + fr.set("req_id", req.nonce) + fr.set("node", self._instance_name) + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, + ) + break + except Full: + await asyncio.sleep(0) + else: + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) + try: + if self.input_pool: + # FIXME: !!! 
+ self.input_pool.release(activation_msg.pool_id) + except Exception: + pass else: # Forward to next node (not our layer) logger.debug( From b95c68f90524084e4d7daf6f49670e0a5dbfa0d6 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:15:05 -0800 Subject: [PATCH 077/172] cleanup weight_cache --- src/dnet/ring/shard/models.py | 2 +- src/dnet/ring/weight_cache.py | 125 +++------------------------------- 2 files changed, 12 insertions(+), 115 deletions(-) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index f54d7954..428b70bd 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,7 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" - api_address: Optional[str] = Field( ..., description="API Address" ) + #api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 2a9a22b0..251834aa 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -64,40 +64,19 @@ def __init__( def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" - # First, fast path under lock for cache hit or in-flight load - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - # refresh LRU timestamp - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - return data with self.tracer.frame("memory.weights", "cache.search") as f: - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp - if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) - logger.debug("Cache hit for layer %s, ref=%d inc=%d", - layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) - return data - - inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock - if inflight is None: - need_evict = len(self.cache) >= self.max_weights - if need_evict: # Prepare eviction decision now to avoid overfilling once loaded - self._evict_lru() # Evict under lock, then proceed to load - fut = Future() # Install a new future marker so others wait - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False if creator: # Perform the blocking load without holding the cache lock with self.tracer.frame("memory.weights", "cache.load") as f: @@ -118,93 +97,11 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) else: self.reference_counts.setdefault(layer_id, 0) - - try: # Resolve future and clear from tracking - fut = 
self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) return data except Exception as e: - with self.lock: # Signal failure to any waiters - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None - - else: # Not the creator: wait for the in-flight load to complete - with self.tracer.frame("memory.weights", "cache.wait") as f: - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) return None - with self.lock: # Return from cache - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: - logger.error("Wait for layer %s load failed: data not in cache", layer_id) - return None - - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) - else: - self.reference_counts.setdefault(layer_id, 0) - # Resolve future and clear from tracking - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.info( - "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", - layer_id, - dt_ms, - (total_bytes / 1_048_576), - ) - return data - except Exception as e: - # Signal failure to any waiters - with self.lock: - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None - else: - # Not the creator: wait for the in-flight load to complete - t0w = time.perf_counter() - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) - return None - wait_ms = (time.perf_counter() - t0w) * 1000.0 - logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) - # Return from cache (now populated) and update ref/LRU - with self.lock: - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: - return None - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - else: - self.reference_counts.setdefault(layer_id, 0) - return data def decrease_reference(self, layer_id: int): """Decrease reference count for layer""" From 464feee5c798250edd80e2931b046c212fe73f17 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:16:04 -0800 Subject: [PATCH 078/172] Revert "cleanup weight_cache" This reverts commit a12eefb6f3807ff9f1812cd755743bc4664a8714. 
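The restored get_weight path is a single-flight load: the first caller
installs a Future and performs the blocking load outside the cache lock,
while concurrent callers wait on that Future instead of loading the same
layer twice. A minimal sketch of the pattern, with `load_fn` standing in
for the real weight loader:

    import threading
    from concurrent.futures import Future

    class SingleFlightCache:
        def __init__(self, load_fn):
            self.load_fn = load_fn      # hypothetical loader: key -> value
            self.lock = threading.Lock()
            self.cache = {}
            self.inflight = {}

        def get(self, key):
            with self.lock:
                if key in self.cache:
                    return self.cache[key]
                fut = self.inflight.get(key)
                creator = fut is None
                if creator:             # install a marker so others wait
                    fut = Future()
                    self.inflight[key] = fut
            if not creator:
                fut.result()            # block until the creator finishes
                with self.lock:
                    return self.cache.get(key)
            try:
                value = self.load_fn(key)   # blocking load, lock not held
                with self.lock:
                    self.cache[key] = value
                    self.inflight.pop(key, None)
                fut.set_result(True)
                return value
            except Exception as e:
                with self.lock:
                    self.inflight.pop(key, None)
                fut.set_exception(e)
                raise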
--- src/dnet/ring/shard/models.py | 2 +- src/dnet/ring/weight_cache.py | 125 +++++++++++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index 428b70bd..f54d7954 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,7 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" - #api_address: Optional[str] = Field( ..., description="API Address" ) + api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 251834aa..2a9a22b0 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -64,19 +64,40 @@ def __init__( def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" + # First, fast path under lock for cache hit or in-flight load + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + # refresh LRU timestamp + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) + return data with self.tracer.frame("memory.weights", "cache.search") as f: - inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock - if inflight is None: - need_evict = len(self.cache) >= self.max_weights - if need_evict: # Prepare eviction decision now to avoid overfilling once loaded - self._evict_lru() # Evict under lock, then proceed to load - fut = Future() # Install a new future marker so others wait - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + logger.debug("Cache hit for layer %s, ref=%d inc=%d", + layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) + return data + + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False if creator: # Perform the blocking load without holding the cache lock with self.tracer.frame("memory.weights", "cache.load") as f: @@ -97,11 +118,93 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) else: self.reference_counts.setdefault(layer_id, 0) + + try: # Resolve future and clear from tracking + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_result(True) + except Exception: + self.loading_futures.pop(layer_id, None) return data except Exception as e: + with self.lock: # Signal failure to any waiters + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is 
not None and not fut.done(): + fut.set_exception(e) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.exception("Failed to load weight %s: %s", layer_id, e) + return None + + else: # Not the creator: wait for the in-flight load to complete + with self.tracer.frame("memory.weights", "cache.wait") as f: + try: + inflight.result() # block until the creator completes + except Exception as e: + logger.error("Wait for layer %s load failed: %s", layer_id, e) return None + with self.lock: # Return from cache + data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] + if data is None: + logger.error("Wait for layer %s load failed: data not in cache", layer_id) + return None + + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + else: + self.reference_counts.setdefault(layer_id, 0) + # Resolve future and clear from tracking + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_result(True) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.info( + "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", + layer_id, + dt_ms, + (total_bytes / 1_048_576), + ) + return data + except Exception as e: + # Signal failure to any waiters + with self.lock: + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_exception(e) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.exception("Failed to load weight %s: %s", layer_id, e) + return None + else: + # Not the creator: wait for the in-flight load to complete + t0w = time.perf_counter() + try: + inflight.result() # block until the creator completes + except Exception as e: + logger.error("Wait for layer %s load failed: %s", layer_id, e) + return None + wait_ms = (time.perf_counter() - t0w) * 1000.0 + logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) + # Return from cache (now populated) and update ref/LRU + with self.lock: + data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] + if data is None: + return None + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) + else: + self.reference_counts.setdefault(layer_id, 0) + return data def decrease_reference(self, layer_id: int): """Decrease reference count for layer""" From 6d42b1c3cf716ad843f9da76642e4ec36e5f42b3 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:22:31 -0800 Subject: [PATCH 079/172] remove double code in weight_cache from rebase --- src/dnet/ring/weight_cache.py | 90 +++++++---------------------------- 1 file changed, 18 insertions(+), 72 deletions(-) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 2a9a22b0..3749522d 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -65,26 +65,16 @@ def __init__( def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" # First, fast path under lock for cache hit or in-flight load - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - # refresh LRU timestamp - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - return data - with self.tracer.frame("memory.weights", "cache.search") as f: - with 
self.lock: - if layer_id in self.cache: + with self.lock: + if layer_id in self.cache: data, _ = self.cache[layer_id] - self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp + # refresh LRU timestamp + self.cache[layer_id] = (data, time.time()) if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) - logger.debug("Cache hit for layer %s, ref=%d inc=%d", - layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) return data inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock @@ -137,74 +127,30 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s self.loading_futures.pop(layer_id, None) logger.exception("Failed to load weight %s: %s", layer_id, e) return None - - else: # Not the creator: wait for the in-flight load to complete + else: + # Not the creator: wait for the in-flight load to complete with self.tracer.frame("memory.weights", "cache.wait") as f: + t0w = time.perf_counter() try: inflight.result() # block until the creator completes except Exception as e: logger.error("Wait for layer %s load failed: %s", layer_id, e) return None - - with self.lock: # Return from cache + wait_ms = (time.perf_counter() - t0w) * 1000.0 + logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) + # Return from cache (now populated) and update ref/LRU + with self.lock: data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] if data is None: - logger.error("Wait for layer %s load failed: data not in cache", layer_id) return None - self.cache[layer_id] = (data, time.time()) if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) else: self.reference_counts.setdefault(layer_id, 0) - # Resolve future and clear from tracking - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.info( - "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", - layer_id, - dt_ms, - (total_bytes / 1_048_576), - ) - return data - except Exception as e: - # Signal failure to any waiters - with self.lock: - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None - else: - # Not the creator: wait for the in-flight load to complete - t0w = time.perf_counter() - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) - return None - wait_ms = (time.perf_counter() - t0w) * 1000.0 - logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) - # Return from cache (now populated) and update ref/LRU - with self.lock: - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: - return None - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - else: - self.reference_counts.setdefault(layer_id, 0) - return data + return data def decrease_reference(self, layer_id: int): """Decrease reference count 
for layer""" From b406efac1d3fcaf5b6b0be0c6d7be230edcfd7f4 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:24:53 -0800 Subject: [PATCH 080/172] comment message field for compatibility --- src/dnet/ring/shard/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index f54d7954..428b70bd 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,7 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" - api_address: Optional[str] = Field( ..., description="API Address" ) + #api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) From 9af0a49d12c9499448eb140478b728ecae0e6c2c Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 3 Nov 2025 01:02:38 -0800 Subject: [PATCH 081/172] fix indent and duplicates from rebase --- src/dnet/ring/shard/compute.py | 481 ++++++++++++++------------------- 1 file changed, 209 insertions(+), 272 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index e0cc1a64..6b64a48a 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -94,15 +94,14 @@ def _process_activation(self, activation_msg: ActivationMessage): with self.tracer.frame("compute.thread", "activations.process") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) + if activation_msg.dtype == "tokens": # embed locally on start shard - logger.debug(f"Embedding tokens.") numel = int(np.prod(activation_msg.shape)) tok_view = input_buffer[:numel].reshape(activation_msg.shape) toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) x = self.model.embed(toks[None]) - # NOTE: Used to track start of request in perf stats - self.tracer.mark("embedding", { + self.tracer.mark("embedding", { # NOTE: Used to track start of request in perf stats "nonce": activation_msg.nonce, "prompt_tokens": toks.size, }) @@ -125,12 +124,13 @@ def _process_activation(self, activation_msg: ActivationMessage): current_layer = activation_msg.layer_id + 1 last_layer = current_layer - 1 while True: + start_time = time.perf_counter() processed = 0 did_early_swap = False - with self.tracer.frame("compute.thread", "weights.prepare") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) + with self.tracer.frame("compute.thread", "weights.prepare") as fr: + fr.set("req_id", activation_msg.nonce) + fr.set("node", self._instance_name) # Determine contiguous local window starting at current_layer window_layers: List[int] = [] @@ -174,23 +174,19 @@ def _process_activation(self, activation_msg: ActivationMessage): # Ensure weights for the window are resident and bind only if arrays changed # if model fits and we've already bound these layers, skip the scan entirely. 
- fast_fit = ( - self._mode == "fit" - and len(self._assigned_sorted) <= self.window_size - ) - skip_scan = fast_fit and all( - (wl in self._bound_versions) for wl in window_layers - ) + fast_fit = (self._mode == "fit" and len(self._assigned_sorted) <= self.window_size) + skip_scan = fast_fit and all( (wl in self._bound_versions) for wl in window_layers) + to_bind: Dict[str, mx.array] = {} if not skip_scan: + t_w_ready = time.perf_counter() for wl in window_layers: weights = self.weight_cache.get_weight(wl) if weights is None: logger.error("Failed to load weights for layer %s", wl) self.input_pool.release(activation_msg.pool_id) return - try: - # Use identity of first array as a cheap version/fingerprint + try: # Use identity of first array as a cheap version/fingerprint first_arr = next(iter(weights.values())) version = id(first_arr) except StopIteration: @@ -211,118 +207,84 @@ def _process_activation(self, activation_msg: ActivationMessage): t_w_ms, ) - # Opportunistically schedule prefetch for the next window to overlap with compute - try: - next_win_pre = self._next_local_layers( - (window_layers[-1] if window_layers else (activation_msg.layer_id)), - self.window_size, - ) - for nl in next_win_pre: - self._prefetch_to_ram(nl) - self._enqueue_weight_prefetch(nl) - except Exception: - pass - # Execute the window with self.tracer.frame("compute.thread", "execute") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) - try: # Prevent prefetch touching during encode/compute to minimize UMA pressure + if to_bind: # Block prefetch-touch during binding and serialize MLX ops self._compute_busy.set() - except Exception: - pass - t_bind = time.perf_counter() - with self._mlx_lock: - self.model.load_weights(list(to_bind.items()), strict=False) - bind_ms = (time.perf_counter() - t_bind) * 1000.0 + t_bind = time.perf_counter() + with self._mlx_lock: + self.model.load_weights(list(to_bind.items()), strict=False) + bind_ms = (time.perf_counter() - t_bind) * 1000.0 + if self._profile: + logger.info( + "[PROFILE][BIND] node=%s nonce=%s layers=%s tensors=%s bind_ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + len(to_bind), + bind_ms, + ) + t_comp = time.perf_counter() + self._compute_busy.set() + for i, lyr in enumerate(window_layers): + with self._mlx_lock: + x = self.model.apply_single_layer(lyr, x, cache=kv) + try: + if str(x.dtype) != str(self._wire_mx_dtype): + x = x.astype(self._wire_mx_dtype) + except Exception: + pass + + last_layer = ( window_layers[-1] if window_layers else activation_msg.layer_id) + mx.eval(x) + if self._profile: + t_comp_done = time.perf_counter() logger.info( - "[PROFILE][BIND] node=%s nonce=%s layers=%s tensors=%s bind_ms=%.3f", + "[PROFILE][WINDOW] node=%s nonce=%s layers=%s compute_ms=%.3f", self.node_id, activation_msg.nonce, window_layers, - len(to_bind), - bind_ms, + (t_comp_done - t_comp) * 1000.0, ) - t_comp = time.perf_counter() - try: - self._compute_busy.set() - except Exception: - pass - for i, lyr in enumerate(window_layers): - with self._mlx_lock: - x = self.model.apply_single_layer(lyr, x, cache=kv) - try: - if str(x.dtype) != str(self._wire_mx_dtype): - x = x.astype(self._wire_mx_dtype) - except Exception: - pass - - last_layer = ( - window_layers[-1] if window_layers else activation_msg.layer_id - ) - try: - mx.eval(x) - except Exception: - pass - if self._profile: - t_comp_done = time.perf_counter() - logger.info( - 
"[PROFILE][WINDOW] node=%s nonce=%s layers=%s compute_ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - (t_comp_done - t_comp) * 1000.0, - ) - for lid in window_layers: - #self.weight_cache.decrease_reference(lid) - pass with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - try: - # Sliding-fit delta swap: maintain a single resident set by evicting + + for lid in window_layers: + self.weight_cache.decrease_reference(lid) + + try: # Sliding-fit delta swap: maintain a single resident set by evicting # only what's needed to fit the next window into the budget. Prefer # keeping the tail of the previous window so we don't thrash weights # that are likely to be reused. if self._mode == "sliding_fit": if int(self._resident_windows) <= 1: - if did_early_swap: - # Early delta-swap already trimmed resident set for this window + if did_early_swap: # Early delta-swap already trimmed resident set for this window pass - elif not self._recent_windows: - # First window in token: seed resident set + elif not self._recent_windows: # First window in token: seed resident set self._recent_windows.append(list(window_layers)) else: prev = self._recent_windows.pop(0) - self._delta_swap_eviction(window_layers, prev, activation_msg, early=False) + self._delta_swap_eviction( + window_layers, prev, activation_msg, early=False + ) budget = max(1, int(self.window_size or 1)) curr = list(window_layers) prev_only = [x for x in prev if x not in curr] keep_quota = max(0, budget - len(curr)) - keep_tail = prev_only[-keep_quota:] if keep_quota > 0 else [] + keep_tail = ( + prev_only[-keep_quota:] if keep_quota > 0 else [] + ) combined = list(keep_tail) + curr self._recent_windows.append(combined) - else: - prev = self._recent_windows.pop(0) - self._delta_swap_eviction( - window_layers, prev, activation_msg, early=False - ) - budget = max(1, int(self.window_size or 1)) - curr = list(window_layers) - prev_only = [x for x in prev if x not in curr] - keep_quota = max(0, budget - len(curr)) - keep_tail = ( - prev_only[-keep_quota:] if keep_quota > 0 else [] - ) - combined = list(keep_tail) + curr - self._recent_windows.append(combined) - else: - # Original eviction policy for other modes + + else: # Original eviction policy for other modes self._recent_windows.append(list(window_layers)) if int(self._resident_windows) <= 1: old = self._recent_windows.pop(0) @@ -330,13 +292,11 @@ def _process_activation(self, activation_msg: ActivationMessage): evicted_cnt = self.weight_cache.evict_layers(old) except Exception: evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass + + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) else: if not self._defer_unload: while len(self._recent_windows) > max( @@ -349,76 +309,85 @@ def _process_activation(self, activation_msg: ActivationMessage): ) except Exception: evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) if 
self._profile: - try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) - except Exception: - pass + logger.info( + "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", + self.node_id, + activation_msg.nonce, + old, + evicted_cnt, + self._resident_windows, + ) except Exception: pass - # If next layer is still local, continue without staging/tx - nxt = last_layer + 1 - if nxt in self._assigned_set: - current_layer = nxt - continue + computation_time = time.perf_counter() - start_time + self._prof.info( + "[PROFILE][COMPUTE] node=%s nonce=%s window_layers=%s total_ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + computation_time * 1000.0, + ) + self._prof.info( + "Completed layers up to %s in %.3fs, nonce: %s, result: %s %s", + last_layer, + computation_time, + activation_msg.nonce, + x.shape, + x.dtype, + ) + + # If next layer is still local, continue without staging/tx + nxt = last_layer + 1 + if nxt in self._assigned_set: + current_layer = nxt + continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - with self.tracer.frame("compute.thread", "execute.enqueue_prefetch") as f: + with self.tracer.frame("compute.thread", "staging") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) - try: - self._compute_busy.clear() - except Exception: - pass - try: - for lid in list(self._prefetch_pending): - self._prefetch_pending.discard(lid) - self._enqueue_weight_prefetch(lid) - except Exception: - pass - # Create and enqueue output message: either forward activations or finalize on end role - with self.tracer.frame("network.tx", "send") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) + t_stage = time.perf_counter() + x_cast = ( x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype)) + self._compute_busy.clear() - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model - with self.tracer.frame("compute.thread", "sampling") as f: - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - #self.tracer.mark("lm_head", {"nonce": actication_msg.nonce}) # NOTE: canonical stats end - - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] - else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return + if self._profile: + logger.info( + "[PROFILE][STAGE-DIRECT] node=%s nonce=%s layer_tail=%s stage_ms=%.3f shape=%s dtype=%s", + self.node_id, + activation_msg.nonce, + last_layer, + (time.perf_counter() - t_stage) * 1000.0, + tuple(x_cast.shape), + str(self._wire_mx_dtype), + ) + + if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "sampling") as fr: + try: + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + self.tracer.mark("lm_head", {"nonce": activation_msg.nonce}) # NOTE: canonical stats end + + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + 
token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return output_msg = ActivationMessage( nonce=activation_msg.nonce, @@ -433,123 +402,91 @@ def _process_activation(self, activation_msg: ActivationMessage): is_final=True, token_id=token_id, ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - is_final=True, - token_id=token_id, - ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - output_msg.tx_enq_perf_t = time.perf_counter() - except Exception: - output_msg.tx_enq_perf_t = 0.0 - # Enqueue to appropriate asyncio TX queue from compute thread - with self.tracer.frame("network.tx", "enque") as f: - output_msg.tx_enq_t = time.perf_counter() - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, ) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) - # Optional unload/evict after stage - with self.tracer.frame("compute.thread", "cleanup"): - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - if self._mode != "sliding_fit": - if self._defer_unload: - # Clean up input resources - with self.tracer.frame("compute.thread", "cleanup") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - try: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - except Exception: - pass + try: + 
output_msg.tx_enq_perf_t = time.perf_counter()
+                    except Exception:
+                        output_msg.tx_enq_perf_t = 0.0
+
+                    # Enqueue to appropriate asyncio TX queue from compute thread
+                    with self.tracer.frame("network.tx", "enque") as fr:
+                        output_msg.tx_enq_t = time.perf_counter()
+                        try:
+                            if self._loop is not None:
+                                target_q = (
+                                    self.activation_token_queue
+                                    if output_msg.is_final
+                                    else self.activation_computed_queue
+                                )
+                                fut = asyncio.run_coroutine_threadsafe(
+                                    target_q.put(output_msg), self._loop
+                                )
+                                fut.result()
+                            else:
+                                raise RuntimeError("Event loop not available for TX queue")
+                        except Exception as e:
+                            logger.error(
+                                "Failed to queue computed activation for sending: %s", e
+                            )
+
+                    # Clean up input resources
+                    with self.tracer.frame("compute.thread", "cleanup") as f:
+                        f.set("req_id", activation_msg.nonce)
+                        f.set("node", self._instance_name)
+                        try:
+                            # Drop the oldest windows until only the configured
+                            # number of resident windows remains
+                            while len(self._recent_windows) > max(
+                                1, int(self._resident_windows)
+                            ):
+                                old = self._recent_windows.pop(0)
+                                try:
+                                    evicted_cnt = self.weight_cache.evict_layers(old)
+                                except Exception:
+                                    evicted_cnt = 0
+                                try:
+                                    if hasattr(self.model, "unload_layers"):
+                                        self.model.unload_layers(old)  # type: ignore[attr-defined]
+                                    for lid in old:
+                                        self._bound_versions.pop(lid, None)
+                                except Exception:
+                                    pass
+                        except Exception:
+                            pass
+
+                        if self._resident_windows <= 1:
+                            try:
+                                evicted = self.weight_cache.evict_layers(window_layers)
+                                if hasattr(self.model, "unload_layers"):
+                                    self.model.unload_layers(window_layers)  # type: ignore[attr-defined]
+                                if self._profile:
+                                    logger.info(
+                                        "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s",
+                                        self.node_id,
+                                        activation_msg.nonce,
+                                        window_layers,
+                                        evicted,
+                                    )
+                            except Exception:
+                                pass
             return
         except Exception as e:
             logger.exception("Error processing activation: %s", e)
+

From 95a42c7e8c5f973ca1f4d5a2828217b7c22eda75 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 3 Nov 2025 02:14:56 -0800
Subject: [PATCH 082/172] small rebase fixes

---
 src/dnet/ring/shard/comms.py    | 232 +++++++++++++++++---------------
 src/dnet/ring/shard/compute.py  |   1 +
 src/dnet/ring/shard/node.py     |  82 +++++------
 src/dnet/ring/shard/servicer.py |   1 +
 4 files changed, 159 insertions(+), 157 deletions(-)

diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py
index 1725fd4f..8d55fe7e 100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -251,8 +251,10 @@ async def _send_activation(self, activation_msg: ActivationMessage):
             )
         except Exception:
             pass
+
         cb = activation_msg.callback_url or ""
         parsed = urlparse(cb) if cb else None
+        t_rpc = time.perf_counter()
         if parsed and parsed.scheme == "grpc":
             addr = parsed.netloc
             if not addr:
@@ -282,15 +284,25 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                         tx_enq_prev_t=time.perf_counter(),
                     )
                     resp = await self.api_stub.SendToken(req)  # type: ignore
+                    rpc_ms = (time.perf_counter() - t_rpc) * 1000.0
                     if not resp.success:
                         logger.error(
                             "API SendToken failed for %s: %s",
                             activation_msg.nonce,
                             resp.message,
                         )
+                    if 
self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) except Exception as e: logger.exception("Error sending token via gRPC: %s", e) else: + logger.error(activation_msg) logger.error( "No valid gRPC callback for token; cannot deliver nonce=%s", activation_msg.nonce, @@ -301,51 +313,49 @@ async def _send_activation(self, activation_msg: ActivationMessage): # FIXME: shaped var is a bit weird (is it np_array or mlx_array), @andthattoo shall check shaped = activation_msg.tensor - with self.tracer.frame("gprc.send_activations.default", "get_buffer") as fr: - fr.set("req_id", activation_msg.nonce) - fr.set("node", self._instance_name) + with self.tracer.frame("network.send_activations.default", "get_buffer") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) if shaped is None: output_buffer = self.output_pool.get_buffer(activation_msg.pool_id) if output_buffer is None: - logger.error( - "Failed to get output buffer %s", activation_msg.pool_id - ) + logger.error("Failed to get output buffer %s", activation_msg.pool_id) return - if self._profile: - logger.info( - "[PROFILE][SER-START] node=%s nonce=%s", - self.node_id, - activation_msg.nonce, - ) - t_ser = time.perf_counter() - t_cast = t_ser - _len_bytes = int(getattr(shaped, "nbytes", 0)) - _do_compress = bool( - self._compress and _len_bytes >= self._compress_min_bytes - ) - if _do_compress: - # Skip compression for decode. - _do_compress = False - try: - wire_np_dtype = dtype_map[self._wire_dtype_str] - except Exception: - wire_np_dtype = np.float16 # reasonable default fallback + if self._profile: + logger.info( + "[PROFILE][SER-START] node=%s nonce=%s", + self.node_id, + activation_msg.nonce, + ) - if isinstance(shaped, np.ndarray): - logger.warning("Activation tensor is a numpy array!!!") - if shaped.dtype != wire_np_dtype: - # FIXME: numpy vs mx array here - shaped = shaped.astype(wire_np_dtype, copy=False) + with self.tracer.frame("network.tx", "cast") as f: + t_ser = time.perf_counter() + t_cast = t_ser + _len_bytes = int(getattr(shaped, "nbytes", 0)) + _do_compress = bool( + self._compress and _len_bytes >= self._compress_min_bytes + ) + if _do_compress: + # Skip compression for decode. 
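+                        # (Presumably because a decode-step payload is a single
+                        # token's activation, small enough that compression
+                        # overhead would outweigh any wire savings.)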
+ _do_compress = False + try: + wire_np_dtype = dtype_map[self._wire_dtype_str] + except Exception: + wire_np_dtype = np.float16 # reasonable default fallback - else: # MLX array -> cast to desired wire dtype - if str(shaped.dtype) != self._wire_dtype_str: - shaped = shaped.astype(self._wire_mx_dtype) + if isinstance(shaped, np.ndarray): + logger.warning("Activation tensor is a numpy array!!!") + if shaped.dtype != wire_np_dtype: + # FIXME: numpy vs mx array here + shaped = shaped.astype(wire_np_dtype, copy=False) - activation_msg.dtype = self._wire_dtype_str - t_cast = time.perf_counter() + else: # MLX array -> cast to desired wire dtype + if str(shaped.dtype) != self._wire_dtype_str: + shaped = shaped.astype(self._wire_mx_dtype) - with self.tracer.frame("grpc", "send_activations.cast_to_dtype") as f: + activation_msg.dtype = self._wire_dtype_str + t_cast = time.perf_counter() if isinstance(shaped, np.ndarray): # Cast to target dtype if shaped.dtype != wire_np_dtype: @@ -358,15 +368,15 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("mxarray.cast") data = tensor_to_bytes(shaped) - activation_msg.dtype = self._wire_dtype_str + activation_msg.dtype = self._wire_dtype_str - nxt = activation_msg.layer_id + 1 - if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): - if self.next_node_stub: + with self.tracer.frame("memory", "prepare.window") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) - with self.tracer.frame("network", "send_activation.next") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) + nxt = activation_msg.layer_id + 1 + if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): + if self.next_node_stub: request = activation_msg.to_proto(data) request.timestamp = utc_epoch_now() if self._mode == "offload" and self.window_size > 0: @@ -476,82 +486,84 @@ async def _send_activation(self, activation_msg: ActivationMessage): ring_timeout, ring_retries, ) - t0 = time.perf_counter() - last_exc: Optional[Exception] = None - for attempt in range(1, ring_retries + 1): - try: - # FIXME: use response here? - _ = await self.next_node_stub.SendActivation( - request, timeout=ring_timeout - ) # type: ignore - break - except grpc.aio.AioRpcError as e: # type: ignore - last_exc = e - code = e.code() - if code in { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.CANCELLED, - grpc.StatusCode.DEADLINE_EXCEEDED, - }: - logger.warning( - "SendActivation attempt %s/%s failed (%s); reconnecting...", - attempt, - ring_retries, - code.name, - ) - await self._reconnect_next_node() - await asyncio.sleep(min(0.25 * attempt, 1.0)) - continue - raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) - else: - logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") + t0 = time.perf_counter() + last_exc: Optional[Exception] = None + for attempt in range(1, ring_retries + 1): + try: + # FIXME: use response here? 
+ _ = await self.next_node_stub.SendActivation( + request, timeout=ring_timeout + ) # type: ignore + break + except grpc.aio.AioRpcError as e: # type: ignore + last_exc = e + code = e.code() + if code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.DEADLINE_EXCEEDED, + }: + logger.warning( + "SendActivation attempt %s/%s failed (%s); reconnecting...", + attempt, + ring_retries, + code.name, + ) + await self._reconnect_next_node() + await asyncio.sleep(min(0.25 * attempt, 1.0)) + continue + raise + else: + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", + self.node_id, + activation_msg.nonce, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, + ) + else: + logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") - # Final layer not annotated with 'is_final' - else: - logger.error( - "Final activation reached send path unexpectedly; sampling should occur on end shard." - ) - # Clear scheduling at request end - # Sequential offload: prefetch state is unused + # Final layer not annotated with 'is_final' + else: + logger.error( + "Final activation reached send path unexpectedly; sampling should occur on end shard." + ) + # Clear scheduling at request end + # Sequential offload: prefetch state is unused - # Optional: explicitly end the per-nonce stream on request completion - # Enable by setting RING_EXPLICIT_EOR=1 when you emit a true end-of-request signal. - try: - if self._explicit_eor: - if ( - hasattr(self, "_streams") - and activation_msg.nonce in self._streams - ): - await self._end_stream(activation_msg.nonce, eor=True) - except Exception: - pass + # Optional: explicitly end the per-nonce stream on request completion + # Enable by setting RING_EXPLICIT_EOR=1 when you emit a true end-of-request signal. + try: + if self._explicit_eor: + if ( + hasattr(self, "_streams") + and activation_msg.nonce in self._streams + ): + await self._end_stream(activation_msg.nonce, eor=True) + except Exception: + pass - # Release resources at end of send - try: - activation_msg.tensor = None - except Exception: - pass - if used_pool: + # Release resources at end of send try: - self.output_pool.release(activation_msg.pool_id) + activation_msg.tensor = None except Exception: pass + if used_pool: + try: + self.output_pool.release(activation_msg.pool_id) + except Exception: + pass except Exception as e: logger.exception("Error sending activation: %s", e) + + async def _connect_next_node(self) -> bool: """Connect to next node in ring. 
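
For reference, the send path above retries only on transient gRPC status codes,
rebuilds the channel between attempts, and backs off linearly with a cap. A
minimal standalone sketch of the same pattern — `call` and `reconnect` are
hypothetical stand-ins for the stub method and `_reconnect_next_node`:

    import asyncio
    import grpc

    TRANSIENT = {
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.DEADLINE_EXCEEDED,
    }

    async def send_with_retry(call, request, *, reconnect, retries=3, timeout=2.0):
        """Retry a unary aio call, reconnecting on transient failures."""
        last_exc = None
        for attempt in range(1, retries + 1):
            try:
                return await call(request, timeout=timeout)
            except grpc.aio.AioRpcError as e:
                last_exc = e
                if e.code() not in TRANSIENT:
                    raise
                await reconnect()                              # rebuild channel/stub
                await asyncio.sleep(min(0.25 * attempt, 1.0))  # capped linear backoff
        raise last_exc
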
diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 6b64a48a..489e6b24 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -74,6 +74,7 @@ def _process_activation(self, activation_msg: ActivationMessage): logger.error("Node %s: Cannot process activation - model not loaded", self.node_id) return + logger.error(f"PROCESS_ACTIVATION: {activation_msg.callback_url}") try: # per-nonce kvcache for concurrent requests with self.tracer.frame("compute.thread", "kvcache.init") as f: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index de82cc63..009223a9 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -618,6 +618,7 @@ async def reset_cache(self) -> None: # FIXME This seems to still be dead code async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or forward.""" + logger.debug("RECEIVE ACTIVATION") if self.input_pool is None: logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id) return @@ -792,6 +793,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): # Queue for processing — non-blocking back-off loop (cancellable) while self.running: try: + logger.error(f"NETWORK RX: {activation_msg.callback_url}") self.activation_recv_queue.put_nowait(activation_msg) activatino_msg.ex_enq_t = time.perf_counter() logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) @@ -818,6 +820,7 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: request.rx_enq_t = rx_t request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request.tx_enq_prev_t + logger.error(f"ADMIT_FRAME: {request.callback_url}") self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") return @@ -839,6 +842,7 @@ async def _ingress_worker(self): req = await self.ingress_q.get() logger.debug(f"[DEQUE]Dequeued activation for processing") except asyncio.CancelledError: + logger.error("Error while waiting ingress worker.") break # Trace processing of request, in-flight and in-wait times @@ -849,30 +853,27 @@ async def _ingress_worker(self): f.set("req_id", req.nonce) try: - await self._connect_next_node() - activation = req.activation target_layer = activation.layer_id + 1 + try: + payload_bytes = len(activation.data) + except Exception: + payload_bytes = -1 + # Detect new sequence per node: initialize per-nonce KV if req.nonce != self._active_nonce: self._active_nonce = req.nonce try: - payload_bytes = len(activation.data) + self._get_or_make_kv(req.nonce) except Exception: - logger.error(f"Unable to read length of data for {req.nonce}") - payload_bytes = -1 - - fr.set("req_id", req.nonce) - f.set("target", target_layer) - f.set("payload_bytes", payload_bytes) - f.event("received") + pass if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) with self.tracer.frame("network.ingress", "prepare") as fr: - fr.set("node", self._instance_name) - fr.set("nonce", req.nonce) + #fr.set("node", self._instance_name) + #fr.set("nonce", req.nonce) loop = asyncio.get_running_loop() try: activation_msg = await loop.run_in_executor( @@ -885,46 +886,32 @@ async def _ingress_worker(self): continue if activation_msg is None: continue - if self._profile: - activation_msg.recv_perf_t = t_recv + #if self._profile: + # activation_msg.recv_perf_t = t_recv # Enqueue for compute 
(cancellable back-off) - with self.tracer.frame("network.ingress", "queue") as fr: + with self.tracer.frame("network.rx", "enque") as fr: + fr.set("req_id", req.nonce) + fr.set("node", self._instance_name) while self.running: try: - activation_msg = await loop.run_in_executor( - self.executor, - self._prepare_activation_message_blocking, - req, + logger.error(f"NETWORK RX: {activation_msg.callback_url}") + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, ) - except Exception as e: - logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) - continue - if activation_msg is None: - continue - - # Enqueue for compute - with self.tracer.frame("network.rx", "enque") as fr: - fr.set("req_id", req.nonce) - fr.set("node", self._instance_name) - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", - activation_msg.nonce, - ) - break - except Full: - await asyncio.sleep(0) - else: - logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) - try: - if self.input_pool: - # FIXME: !!! - self.input_pool.release(activation_msg.pool_id) - except Exception: - pass + break + except Full: + await asyncio.sleep(0) + else: + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) + try: + if self.input_pool: + # FIXME: !!! + self.input_pool.release(activation_msg.pool_id) + except Exception: + logger.error("Unable to release from input pool") else: # Forward to next node (not our layer) logger.debug( @@ -1501,6 +1488,7 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse: device_profile = await self._profile_device( req.repo_id, req.max_batch_exp ) + logger.debug(device_profile) # Overwrite `t_comm` with median latency (subprocess returns a dict) median_latency = calculate_median_latency_seconds(latency_results) diff --git a/src/dnet/ring/shard/servicer.py b/src/dnet/ring/shard/servicer.py index 93e8627e..006f38a8 100644 --- a/src/dnet/ring/shard/servicer.py +++ b/src/dnet/ring/shard/servicer.py @@ -35,6 +35,7 @@ async def SendActivation( request.activation.layer_id, ) + logger.error(f"SERVICER: {request.callback_url}") await self.node.admit_frame(request) return ActivationResponse( From 3da09c916c7107be1fb69dcc3b8d533b9d565275 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 3 Nov 2025 02:15:22 -0800 Subject: [PATCH 083/172] change order of elements so callback_url is position 4 again --- src/dnet/protos/dnet_ring.proto | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index 8009601f..d1b3b33a 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -32,11 +32,11 @@ message ActivationRequest { string nonce = 1; Activation activation = 2; int64 timestamp = 3; - float rx_enq_t = 4; - float tx_enq_prev_t = 5; - float rx_inflight_t = 6; - string node_origin = 7; - string callback_url = 8; + string node_origin = 4; + string callback_url = 5; + float rx_enq_t = 6; + float tx_enq_prev_t = 7; + float rx_inflight_t = 8; } // Response message for activation sending From 532f127d386d244e300751102fc280bd755524f8 Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 10 Nov 2025 10:18:47 +0300 Subject: [PATCH 084/172] fix type issues, fix `is_head` setting --- src/dnet/ring/api/node.py | 128 +++++++++++++++----------------- 
src/dnet/ring/api/utils.py | 18 +---- src/dnet/ring/api/utils_test.py | 23 ++---- 3 files changed, 70 insertions(+), 99 deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index b0f09d14..48e38549 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -6,7 +6,6 @@ import json from dataclasses import asdict from io import StringIO -from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union import httpx @@ -42,10 +41,9 @@ from ...utils.logger import logger from ...utils.banner import print_startup_banner -from ...utils.latency import calculate_median_latency_seconds +from ...utils.latency import LatencyResults, calculate_median_latency_seconds from ...utils.model import ( ModelMetadata, - get_model_metadata, get_model_config_json, resolve_tokenizer_dir, ) @@ -415,12 +413,14 @@ async def _handle_prepare_topology( logger.info("Discovered %d shards: %s", len(shards), list(shards.keys())) - shard_profiles, thunderbolt_conns = await self._collect_shard_profiles( + thunderbolt_conns = discover_all_thunderbolt_connections(shards) + shard_profiles = await self._collect_shard_profiles( shards, req.model, embedding_size, req.max_batch_exp, batch_sizes, + thunderbolt_conns ) optimized_device_name_order = optimize_device_ordering( shard_profiles, thunderbolt_conns @@ -971,7 +971,8 @@ async def _collect_shard_profiles( embedding_size: int, max_batch_exp: int, batch_sizes: List[int], - ) -> Tuple[Dict[str, DeviceProfile], Dict[str, Dict[str, ThunderboltConnection]]]: + thunderbolt_conns: dict[str, dict[str, ThunderboltConnection]] = {}, + ) -> Dict[str, DeviceProfile]: """Collect profile data from all shards. Args: @@ -979,6 +980,8 @@ async def _collect_shard_profiles( repo_id: Model repository ID embedding_size: Model embedding size max_batch_exp: Maximum batch size exponent + batch_sizes: List of batch sizes to profile + all_thunderbolts: Pre-discovered thunderbolt connections per shard Returns: Tuple of (collected shard profiles, thunderbolt connections) @@ -987,8 +990,6 @@ async def _collect_shard_profiles( base_size = embedding_size * 4 # 4*e due to paper payload_sizes = [base_size * batch_size for batch_size in batch_sizes] - this_device = self.discovery.get_own_properties() - logger.info( "Model %s: embedding_size=%d, payload_sizes=%s", repo_id, @@ -996,74 +997,68 @@ async def _collect_shard_profiles( payload_sizes, ) - # Find Thunderbolt connections - all_thunderbolts = discover_all_thunderbolt_connections(shards) - async with httpx.AsyncClient() as client: - # Step 1: Health check all shards in parallel + # health-check all shards in parallel logger.info("Starting health checks for all shards...") health_tasks: list[asyncio._CoroutineLike[httpx.Response]] = [] shard_list: list[tuple[str, DnetDeviceProperties]] = [] for shard_name, shard_props in shards.items(): if shard_props.is_manager: - logger.warning( - "Skipping manager node %s in profile collection", shard_name - ) + logger.warning("Skipping manager node %s in profile collection", shard_name) continue shard_list.append((shard_name, shard_props)) - server_port, server_ip = shard_props.server_port, shard_props.local_ip - health_url = f"http://{server_ip}:{server_port}/health" - health_tasks.append(client.get(health_url, timeout=5.0)) + health_tasks.append(client.get(f"http://{shard_props.local_ip}:{shard_props.server_port}/health", timeout=5.0)) health_results = await asyncio.gather(*health_tasks, return_exceptions=True) - # Filter healthy shards - healthy_shards 
= [] + # filter healthy shards + healthy_shards: list[tuple[str, DnetDeviceProperties]] = [] for (shard_name, shard_props), health_result in zip(shard_list, health_results): if isinstance(health_result, Exception): logger.warning("Health check failed for %s: %s", shard_name, health_result) continue - if health_result.status_code == 200: + elif isinstance(health_result, httpx.Response): + if health_result.status_code == 200: healthy_shards.append((shard_name, shard_props)) logger.info("Health check passed for %s", shard_name) + else: + logger.warning("Health check failed for %s: status %s", shard_name, health_result.status_code) else: - logger.warning("Health check failed for %s: status %s", shard_name, health_result.status_code) + pass + logger.info("Healthy shards: %d/%d", len(healthy_shards), len(shard_list)) - if not healthy_shards: logger.error("No healthy shards found!") - return {}, all_thunderbolts + return {} - # Step 2: Measure latencies on all healthy shards in parallel + # measure latencies on all healthy shards in parallel) logger.info("Measuring latencies for all healthy shards...") - latency_tasks = [] + latency_tasks: list[asyncio._CoroutineLike[httpx.Response]] = [] for shard_name, shard_props in healthy_shards: server_port, server_ip = shard_props.server_port, shard_props.local_ip latency_url = f"http://{server_ip}:{server_port}/measure_latency" latency_request = MeasureLatencyRequest( devices=shards, - thunderbolts=all_thunderbolts.get(shard_name, {}), + thunderbolts=thunderbolt_conns.get(shard_name, {}), payload_sizes=payload_sizes, ) latency_tasks.append( client.post(latency_url, json=latency_request.model_dump(), timeout=1000.0) ) - latency_results = await asyncio.gather(*latency_tasks, return_exceptions=True) - # Store latency data for each shard - shard_latencies = {} + # store latency data for each shard + shard_latencies: dict[str, LatencyResults] = {} final_healthy_shards = [] - for (shard_name, shard_props), latency_result in zip(healthy_shards, latency_results): if isinstance(latency_result, Exception): logger.warning( "Latency measurement failed for %s: %s", shard_name, latency_result ) continue - else: + elif isinstance(latency_result, httpx.Response): if latency_result.status_code == 200: latency_data = MeasureLatencyResponse.model_validate(latency_result.json()) shard_latencies[shard_name] = latency_data.latency @@ -1075,34 +1070,33 @@ async def _collect_shard_profiles( shard_name, latency_result.status_code, ) + else: + pass # unexpected case logger.info("Latencies collected from %d shards", len(shard_latencies)) if not final_healthy_shards: logger.error("No shards with successful latency measurements!") - return {}, all_thunderbolts + return {} - # Step 3: Group healthy shards by local_ip (same device) + # group healthy shards by local_ip (same device), so that we can profile per-device shards_by_device: Dict[str, List[Tuple[str, DnetDeviceProperties]]] = {} for shard_name, shard_props in final_healthy_shards: local_ip = shard_props.local_ip if local_ip not in shards_by_device: shards_by_device[local_ip] = [] shards_by_device[local_ip].append((shard_name, shard_props)) + logger.info("Grouped %d shards into %d devices", len(final_healthy_shards), len(shards_by_device)) - logger.info("Grouped shards into %d devices", len(shards_by_device)) - - # Step 4: Profile devices (parallel per device, sequential per shard within device) + # profile devices (parallel per device, sequential per shard within device) async def profile_device_shards( - device_ip: str, 
device_shards: List[Tuple[str, DnetDeviceProperties]] + device_shards: List[Tuple[str, DnetDeviceProperties]] ) -> List[Tuple[str, DeviceProfile]]: - """Profile all shards on a single device sequentially.""" profiles = [] for shard_name, shard_props in device_shards: try: - server_port, server_ip = shard_props.server_port, shard_props.local_ip - profile_url = f"http://{server_ip}:{server_port}/profile" + profile_url = f"http://{shard_props.local_ip}:{shard_props.server_port}/profile" logger.info( "Calling /profile endpoint for shard %s at %s", @@ -1114,7 +1108,7 @@ async def profile_device_shards( profile_url, json=ShardProfileRequest( repo_id=repo_id, - thunderbolts=all_thunderbolts.get(shard_name, {}), + thunderbolts=thunderbolt_conns.get(shard_name, {}), payload_sizes=payload_sizes, max_batch_exp=max_batch_exp, devices=shards, @@ -1125,11 +1119,6 @@ async def profile_device_shards( if response.status_code == 200: profile_data = ShardProfileResponse.model_validate(response.json()) profile = load_device_profile_from_dict(profile_data.profile) - - # Mark head device (same local IP as API) - if shard_props.local_ip == this_device.local_ip: - profile.is_head = True - profiles.append((shard_name, profile)) logger.info("Successfully collected profile from %s", shard_name) else: @@ -1144,39 +1133,38 @@ async def profile_device_shards( return profiles - # Run profiling for all devices in parallel + # run profiling for all devices in parallel device_tasks = [ - profile_device_shards(device_ip, device_shards) - for device_ip, device_shards in shards_by_device.items() + profile_device_shards(device_shards) + for device_shards in shards_by_device.values() ] device_results = await asyncio.gather(*device_tasks, return_exceptions=True) - # Step 5: Merge latency data into device profiles + # merge latency data into device profiles shard_profiles: Dict[str, DeviceProfile] = {} - for device_result in device_results: if isinstance(device_result, Exception): logger.error("Device profiling failed: %s", device_result) continue - - for shard_name, profile in device_result: - # Set t_comm using median latency - if shard_name in shard_latencies: - median_latency = calculate_median_latency_seconds(shard_latencies[shard_name]) - if median_latency is not None: - profile.t_comm = float(median_latency) - logger.info( - f"Set t_comm for {shard_name} to median latency: {profile.t_comm:.6f}s" - ) - else: - logger.warning( - f"No valid latency measurements for {shard_name}, keeping default t_comm" - ) - - shard_profiles[shard_name] = profile + elif isinstance(device_result, list): + for shard_name, profile in device_result: + # set t_comm using median latency + if shard_name in shard_latencies: + median_latency = calculate_median_latency_seconds(shard_latencies[shard_name]) + if median_latency is not None: + profile.t_comm = float(median_latency) + logger.info( + f"Set t_comm for {shard_name} to median latency: {profile.t_comm:.6f}s" + ) + else: + logger.warning( + f"No valid latency measurements for {shard_name}, keeping default t_comm" + ) + + shard_profiles[shard_name] = profile logger.info("Collected profiles from %d shards", len(shard_profiles)) - return shard_profiles, all_thunderbolts + return shard_profiles # FIXME: move this to elsewhere async def _run_solver( @@ -1203,6 +1191,10 @@ async def _run_solver( if not sorted_shard_profiles: raise ValueError("No valid shard profiles found") + # mark the first device as head, others as non-head + for i, profile in enumerate(sorted_shard_profiles): + profile.is_head = (i 
== 0) + logger.info("Running solver with %d shard profiles", len(sorted_shard_profiles)) solution = halda_solve( diff --git a/src/dnet/ring/api/utils.py b/src/dnet/ring/api/utils.py index ad0fee36..d59387c1 100644 --- a/src/dnet/ring/api/utils.py +++ b/src/dnet/ring/api/utils.py @@ -263,22 +263,10 @@ def optimize_device_ordering( thunderbolt_conns: Thunderbolt connections mapping (device -> {neighbor -> connection_info}) Returns: - Optimized list of device names with head devices first and Thunderbolt neighbors adjacent + Optimized list of device names """ device_names = list(shard_profiles.keys()) - # Find all head devices (multiple shards can run on same machine as API) - head_devices = [] - for device_name, profile_data in shard_profiles.items(): - if profile_data.is_head: - head_devices.append(device_name) - - if not head_devices: - logger.warning("No head device found in profiles, using first device") - head_devices = [device_names[0]] if device_names else [] - else: - logger.info("Found %d head device(s): %s", len(head_devices), head_devices) - # FIXME: shards on the same machine should be adjacent too! # Build adjacency graph of Thunderbolt connections @@ -292,8 +280,8 @@ def optimize_device_ordering( tb_graph[neighbor_name].add(device_name) # Greedy ordering: Start with all head devices, then pick neighbors with most TB connections - ordered = head_devices.copy() - remaining = set(device_names) - set(head_devices) + ordered: list[str] = [] + remaining: set[str] = set(device_names) while remaining: best_candidate = None diff --git a/src/dnet/ring/api/utils_test.py b/src/dnet/ring/api/utils_test.py index b250e29a..1c54665e 100644 --- a/src/dnet/ring/api/utils_test.py +++ b/src/dnet/ring/api/utils_test.py @@ -54,20 +54,15 @@ def test_single_round_postprocess_complex(): def test_optimize_device_ordering(): - from pydantic import BaseModel - - # fake type for the sake of testing - class _DeviceProfileIsHead(BaseModel): - is_head: bool device_profiles: dict[str, DeviceProfile] = { - "dev1": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev2": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev3": _DeviceProfileIsHead(is_head=True), # type: ignore - "dev4": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev5": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev6": _DeviceProfileIsHead(is_head=False), # type: ignore - "dev7": _DeviceProfileIsHead(is_head=False), # type: ignore + "dev1": {}, # type: ignore + "dev2": {}, # type: ignore + "dev3": {}, # type: ignore + "dev4": {}, # type: ignore + "dev5": {}, # type: ignore + "dev6": {}, # type: ignore + "dev7": {}, # type: ignore } thunderbolts: dict[str, dict[str, ThunderboltConnection]] = { "dev3": {"dev1": 1}, # type: ignore @@ -78,10 +73,6 @@ class _DeviceProfileIsHead(BaseModel): optimized_order = optimize_device_ordering(device_profiles, thunderbolts) - # the ordering is not deterministic, but the connectino should be as follows: - # head must be the first - assert optimized_order[0] == "dev3" - # dev1 and dev3 must be next to each other (due to thunderbolt) dev1_index = optimized_order.index("dev1") dev3_index = optimized_order.index("dev3") From 302f4354dff2d100a0e679be08b564230a061cdd Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 10 Nov 2025 13:39:01 +0300 Subject: [PATCH 085/172] dnet-p2p commit update --- lib/dnet-p2p | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/dnet-p2p b/lib/dnet-p2p index 37b63fe1..9f81988c 160000 --- a/lib/dnet-p2p +++ b/lib/dnet-p2p @@ -1 +1 @@ 
-Subproject commit 37b63fe175f7518cfd4e06b56e7fe2437feb47ab +Subproject commit 9f81988c822ee4123595115e87ca8579ed3f5f7d From fdbedc94970cb7df5c3bc7de9d925811615c8844 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 10 Nov 2025 16:14:17 +0300 Subject: [PATCH 086/172] patchwork fix --- src/dnet/ring/api/node.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 56031bf3..e819235f 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -386,6 +386,10 @@ async def _handle_prepare_topology( # Load only config.json to avoid weight downloads on API node cfg = get_model_config_json(req.model) + + if str(cfg.get("model_type", "")).strip().lower() == "gpt_oss" and req.kv_bits != "fp16": + raise ValueError("GPT-OSS models only support kv_bits='fp16'") + num_layers_raw = cfg.get("num_hidden_layers") if not isinstance(num_layers_raw, int): raise ValueError( From 7ad24686102c044866996c08bc9369de2a29ebb6 Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 15 Oct 2025 18:54:49 -0700 Subject: [PATCH 087/172] repl sketch --- src/repl.py | 295 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 src/repl.py diff --git a/src/repl.py b/src/repl.py new file mode 100644 index 00000000..87b1fe77 --- /dev/null +++ b/src/repl.py @@ -0,0 +1,295 @@ + +import os +import sys +import cmd +import argparse +import subprocess +from dataclasses import dataclass + +from src.ring.api import run as run_api_node +from src.ring.shard import run as run_shard_node +from src.util import ( + ModelMetadata, + NodeAddress, + logger, + get_model_metadata, + load_api_layer_weights, + get_safetensor_details, + create_generate_step_for_ring_with_grpc, +) + +# Handle restricted repos +from importlib import import_module +import huggingface_hub as hb +from huggingface_hub import snapshot_download, try_to_load_from_cache +try: + hf_errors = import_module("huggingface_hub.errors") +except ModuleNotFoundError: + hf_errors = import_module("huggingface_hub.utils") +GatedRepoError = getattr(hf_errors, "GatedRepoError", Exception) +HfHubHTTPError = getattr(hf_errors, "HfHubHTTPError", Exception) + +from src.ring.api_node import RingApiNode + +def dprint(msg): + sys.stdout.write(msg) + sys.stdout.flush() + + +@dataclass +class REPLState: + model: str = "NULL" + model_info: ModelMetadata = None, + num_local_nodes: int = 1 + running_port = 50501 + running_httpport = 8091 + api_addr_host: str = "10.0.0.2" # TODO: Don't hardcode + api_addr_port: int = 0 + grpc_listen_port:int = 0 + window_size = 2 # Number of layers per node per visit (also number resident in cache) + +class REPL(cmd.Cmd): + + PS1 = "dnet > " + WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Enter '.help' for usage hints.\n\n" + def __init__(self, model="NULL", nodes=1): + super().__init__() + self.state = REPLState() + self.state.model = model + + self.state.api_addr_port = self.state.running_port + self.state.grpc_listening_port = self.state.running_port + 1 + self.state.running_port += 2 + self.discovery = None + + # TODO: Maybe have a 'start search' 'stop search' cmds to manage discovery + + self.api = None + #self.config_api_node() + #self.start_api_discovery() + + assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10." 
+        self.state.num_local_nodes = nodes
+
+    def loop(self):
+        self.greeting()
+        while True:
+
+            #if self.state.model == "NULL":
+            #    self.prompt_model()
+            #    continue
+
+            dprint(self.PS1)
+            cmd = sys.stdin.readline().strip()
+
+            if cmd == "":
+                self.print_state()
+            if cmd in [".exit", "exit", "quit", "q"]:
+                self.handle_terminate_signal()
+            if cmd in [".help", "help", "h"]:
+                self.print_help()
+            if cmd.startswith((".model", "model", "m")):
+                parts = cmd.split(" ")
+                path = self._handle_model_pull(parts[1])
+                if path:
+                    self.state.model = path
+
+    def greeting(self):
+        sys.stdout.write(self.WELCOME)
+
+    def print_help(self):
+        def _print_hf(cmd, desc, examples=[""]):
+            pcmd = "  " + cmd.ljust(30, '.')
+            dprint(f"{pcmd} {desc}\n")
+            for e in examples:
+                pex = e.rjust(len(e)+35)+"\n" if e != "" else ""
+                dprint(f"{pex}")
+
+        dprint("Command Options:\n")
+        _print_hf("nodes [VALUE]", "Set the number of local worker nodes")
+        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository",
+                  ["Examples > model meta-llama/Meta-Llama-3-8B"])
+        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource.",
+                  ["Examples > limit memory 12000 (MB)",
+                   "         > limit CPU_CORE_COUNT 4",
+                   "         > limit GPU_SM 128"])
+        _print_hf("log [LEVEL]", "Set the logging level.")
+        dprint("\n  Building a topology:\n")
+        _print_hf("search [ON/OFF]", "Toggle mDNS worker node search across the local network.")
+        _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
+        _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
+        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
+        dprint("\n  Building a schedule:\n")
+        _print_hf("sched create", "Automatic search for best schedule given the active topology and the loaded model.")
+        _print_hf("sched assign [LAYER] [NODE]", "Assign the layer with index [LAYER] to [NODE].",
+                  ["Example > sched assign 10 benny_234"])
+        _print_hf("schedule assign [START-END] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
+                  ["Example > sched assign 0-12 benny_234"])
+        dprint("\n  Benchmarking and profiling:\n")
+        _print_hf("profile [REPO]", "Estimate the total FLOPS of the model from [REPO]")
+        _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
+        _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
+        _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
+        _print_hf("bench [NODE0] [NODE1]", "Benchmark the network latency between [NODE0] and [NODE1]")
+        _print_hf("bench ", "Benchmark the system using base library kernels")
+        dprint("\n")
+
+    def print_state(self):
+        dprint("Network state:\n")
+        dprint(f"{('Model'.ljust(20)): >10}: {self.state.model}\n")
+        dprint(f"{('Local workers'.ljust(20)): >10}: {self.state.num_local_nodes}\n")
+
+
+    # ===== Handle Model input and pull from server
+
+    def prompt_model(self):
+        while True:
+            dprint("Target model > ")
+            model = sys.stdin.readline().strip()
+            try:
+                path = self._handle_model_pull(model)
+                return path
+                #self.model_info = ModelMetadata()
+            except Exception as e:
+                dprint(f"Unable to load model {model}. Target needs to be a valid HF repository. Try again: {e}\n")
+
+    # Read HF access token
+    def _resolve_hf_token(self):
+        dprint("Enter the HuggingFace access token > ")
+        tok = sys.stdin.readline().strip()
+        return tok
+
+    # Require a HF access token for restricted repositories
+    # Ask user for HF access token until they have a valid one
+    def _handle_model_pull(self, repo_path):
+        try:
+            path = try_to_load_from_cache(repo_path)
+            if path is None:
+                dprint(f"Model {repo_path} not found in local cache\n")
+                path = get_model_path(repo_path)
+            self.state.model = repo_path
+            return path
+        except GatedRepoError:
+            # Must be caught before HfHubHTTPError, which it subclasses
+            dprint("Restricted model.\n")
+            tok = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
+            while True:
+                tok = self._resolve_hf_token()
+                try:
+                    ret = snapshot_download(repo_id=repo_path, token=tok)
+                    return ret
+                except GatedRepoError as e:
+                    print(e)
+                    continue
+                except Exception as e:
+                    raise RuntimeError(f"Unknown error during HF snapshot_download: {e}")
+        except HfHubHTTPError:
+            dprint(f"Repository {repo_path} not found in Hugging Face registry.\n")
+            return None
+        except Exception:
+            raise RuntimeError(f"Unable to pull model {repo_path} locally")
+
+    def _parse_model_metadata(self, model_path):
+        if isinstance(model_path, tuple):
+            model_path = model_path[0] if model_path else None
+        if model_path is None:
+            raise ValueError(f"Could not resolve model path {model_path}")
+        path = Path(model_path) if not isinstance(model_path, Path) else model_path
+
+        with open(path / "config.json", "r") as f:
+            config = json.load(f)
+
+        weight_files = glob.glob(str(path / "*.safetensors"))
+        weight_info: Dict[int, Dict[str, Any]] = defaultdict(dict)
+        embed_tokens, lm_head, norm = {}, {}, {}
+        for weight in weight_files:
+            details = get_safetensor_details(weight)
+            for key, val in details.items():
+                if m := EMBED_TOKENS_RE.match(key):
+                    embed_tokens[m.group(1)] = val
+                elif m := LM_HEAD_RE.match(key):
+                    lm_head[m.group(1)] = val
+                elif m := NORM_RE.match(key):
+                    norm[m.group(1)] = val
+                elif m := LAYERS_RE.match(key):
+                    layer_idx, suffix = m.groups()
+                    weight_info[int(layer_idx)][suffix] = val
+                else:
+                    raise RuntimeError(f"Unexpected key {key}")
+        num_layers = max(weight_info.keys()) + 1
+        if not (set(weight_info.keys()) == set(range(num_layers))):
+            raise RuntimeError("Inconsistent weights")
+        return ModelMetadata(path, weight_info, embed_tokens, lm_head, norm, config)
+
+    # ===== Handle termination signals
+
+    def handle_terminate_signal(self):
+        # Handle worker/api shutdown
+        dprint("No workers to shut down. 
Terminating.\n") + sys.exit() + + # ===== Handle Shard worker servers + + # TODO: Redirect output logs to different files + def handle_start_worker(self): + bin = os.path.join(".venv", "bin", "piped_mlx_ring_shard") + cmd = ["uv", "run", bin] + cmd.append(f" -m {self.state.model}") + cmd.append(f" -p {self.state.running_port}") + cmd.append(f" -httpport {self.state.running_httpport}") + cmd.append(f" -l [{0}]") + cmd.append(f" --prefetch-window {2}") + proc = subprocess.Popen(cmd) + + self.state.running_port += 1 # increment the running port + self.state.running_httpport += 1 + + # ===== Handle API server + + def handle_device_discovery(self): + from socket import gethostname + from secrets import token_hex + + hostname = gethostname() + instance = f"api-{token_hex(4)}-{hostname}" + lib = DnetP2P("lib/dnet-p2p/lib") + + """ + self.discovery = lib.create_instance( + instance, hostname, + self.state.p2p_addr.host, self.state.p2p_addr_port, + self.state.grpc_listen_port, is_manager=True + ) + self.discovery.start() + """ + + def config_api_node(self): + api_address = NodeAddress(self.state.api_addr_host, self.state.api_addr_port) + self.api = RingApiNode(api_address, shard_address.format(), model_metadata) + + def start_api_discovery(self): + if self.api: + self.api._start_discovery() + + # Calls dsolver and optimizes topology + async def build_topology(self): + if self.api: + topo = await self.api.topology() + return topo + + # ===== Handle shutdown + + def handle_shutdown(self): + os.system("pkill -9 -f piped_mlx") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", "-m", type=str, help="HF Repository of target model") + parser.add_argument("--local-nodes", "-n", type=int, help="Number of local worker nodes") + args = parser.parse_args() + + #workers = args.workers + #model = args.model + + repl = REPL() + repl.loop() From 70bb0c10fff61da86d045a2df1836e0ac48b4fac Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 16 Oct 2025 20:43:57 -0700 Subject: [PATCH 088/172] manage api server, discover and print nodes table --- src/repl.py | 384 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 299 insertions(+), 85 deletions(-) diff --git a/src/repl.py b/src/repl.py index 87b1fe77..0b95a19f 100644 --- a/src/repl.py +++ b/src/repl.py @@ -5,17 +5,22 @@ import argparse import subprocess from dataclasses import dataclass - -from src.ring.api import run as run_api_node -from src.ring.shard import run as run_shard_node -from src.util import ( - ModelMetadata, - NodeAddress, - logger, - get_model_metadata, +from typing import Optional, List, Any + +import asyncio +import inspect +import threading +import concurrent.futures + +from dnet.ring.api import RingApiNode +from dnet.ring.shard import RingShardNode +from dnet.utils.network import NodeAddress +from dnet.utils.logger import logger +from dnet.utils.model import ( + ModelMetadata, + get_model_metadata, load_api_layer_weights, get_safetensor_details, - create_generate_step_for_ring_with_grpc, ) # Handle restricted repos @@ -29,13 +34,11 @@ GatedRepoError = getattr(hf_errors, "GatedRepoError", Exception) HfHubHTTPError = getattr(hf_errors, "HfHubHTTPError", Exception) -from src.ring.api_node import RingApiNode def dprint(msg): sys.stdout.write(msg) sys.stdout.flush() - @dataclass class REPLState: model: str = "NULL" @@ -43,96 +46,175 @@ class REPLState: num_local_nodes: int = 1 running_port = 50501 running_httpport = 8091 - api_addr_host: str = "10.0.0.2" # TODO: Don't hardcode - 
api_addr_port: int = 0 - grpc_listen_port:int = 0 + api_http_port: int = 8080 + api_grpc_port: int = 50500 window_size = 2 # Number of layers per node per visit (also number resident in cache) -class REPL(cmd.Cmd): +class REPL(cmd.Cmd): PS1 = "dnet > " - WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Enter '.help' for usage hints.\n\n" + WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Type 'help' for usage hints.\n\n" + def __init__(self, model="NULL", nodes=1): + assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10." + super().__init__() self.state = REPLState() self.state.model = model - - self.state.api_addr_port = self.state.running_port - self.state.grpc_listening_port = self.state.running_port + 1 self.state.running_port += 2 - self.discovery = None - - # TODO: Maybe have a 'start search' 'stop search' cmds to manage discovery - - self.api = None - #self.config_api_node() - #self.start_api_discovery() - - assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10." self.state.num_local_nodes = nodes - def loop(self): - self.greeting() - while True: + self._node: Optional[RingApiNode] = None + self._api_thread: Optional[threading.Thread] = None + self._api_ready = threading.Event() + self._api_running = threading.Event() + self._api_loop: Optional[asyncio.AbstractEventLoop] = None + self._api_shutdown_e: Optional[asyncio.Event] = None + self._api_exc: Optional[BaseException] = None - #if self.state.model == "NULL": - # self.prompt_model() - # continue + self._api_searching = threading.Event() # Track mDNS searching + def loop(self): # Main tty loop + sys.stdout.write(self.WELCOME) + while True: dprint(self.PS1) cmd = sys.stdin.readline().strip() if cmd == "": self.print_state() - if cmd in [".exit", "exit", "quit", "q"]: + elif cmd in [".exit", "exit", "quit", "q"]: self.handle_terminate_signal() - if cmd in [".help", "help", "h"]: + elif cmd in [".help", "help", "h"]: self.print_help() - if cmd.startswith((".model", "model", "m")): + + elif cmd.startswith(("api", ".api")): + self.do_api(cmd.split(" ")) + continue + elif cmd.startswith("search"): + self.do_search(cmd.split(" ")) + continue + elif cmd.startswith("nodes"): + self.print_mdns_nodes() + continue + elif cmd.startswith(("topo", ".topo")): + self.do_topo(cmd.split(" ")) + continue + elif cmd.startswith((".model", "model", "m")): cmd.split(" ") path = self._handle_model_pull(cmd[1]) if path: self.state.model = path - - def greeting(self): - sys.stdout.write(self.WELCOME) - + + def do_api(self, cmd: List[str]) -> None: + if len(cmd) < 2: + dprint("Invalid API command. Type 'help' for a list of valid commands.\n") + return + if cmd[1] in ["start", "run"]: + http_port, grpc_port = None, None + try: + http_port = cmd[2]; + grpc_port = cmd[3] + except: + pass + self.start_api( + http_port or self.state.api_http_port, + grpc_port or self.state.api_grpc_port + ) + elif cmd[1] == "stop": + self.stop_api() + elif cmd[1] == "status": + dprint("Running\n" if self._api_running else "Stopped.\n") + elif cmd[1] == "log": + dprint("Log print is not yet supported.\n") + else: + dprint("Invalid API command. 
Type 'help' for a list of valid commands.\n")
+            return
+
+    def do_search(self, cmd: List[str]) -> None:
+        if len(cmd) != 2:
+            dprint("mDNS search is " + ("ON\n\n" if self._api_searching.is_set() else "OFF\n\n"))
+            return
+        if cmd[1] == "on":
+            if self._api_searching.is_set():
+                return
+            if not self._api_running.is_set():
+                dprint("Starting API Server thread.\n")
+                self.start_api()
+            self.api_call("_start_discovery", timeout=10)
+            self._api_searching.set()
+            dprint("Starting mDNS search for worker nodes.\n")
+        elif cmd[1] == "off":
+            dprint("Stop discovery not yet implemented in the API node.\n")
+        else:
+            dprint("Invalid search command. Start searching with 'search on'.\n")
+            return
+
+    def do_topo(self, cmd: List[str]) -> None:
+        if len(cmd) < 2:
+            dprint("Invalid topology command. Type 'help' for a list of valid commands.\n")
+            return
+        if cmd[1] == "search":
+            pass
+        elif cmd[1] == "auto":
+            pass
+        elif cmd[1] == "setup":
+            pass
+        elif cmd[1] == "add":
+            pass
+        elif cmd[1] in ["remove", "rm"]:
+            pass
+        return
+
+    # TODO: standardize ANSI escape codes for easy use
     def print_help(self):
         def _print_hf(cmd, desc, examples=[""]):
             pcmd = "  " + cmd.ljust(30, '.')
-            dprint(f"{pcmd} {desc}\n")
+            sys.stdout.write(f"{pcmd} {desc}\n")
             for e in examples:
                 pex = e.rjust(len(e)+35)+"\n" if e != "" else ""
-                dprint(f"{pex}")
+                sys.stdout.write(f"{pex}")
 
-        dprint("Command Options:\n")
-        _print_hf("nodes [VALUE]", "Set the number of local worker nodes")
+        sys.stdout.write("\033[1m\nAvailable commands:\n\033[0m")
+        dprint("\033[1m\n  Common:\n\033[0m")
         _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository",
                   ["Examples > model meta-llama/Meta-Llama-3-8B"])
-        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource.",
-                  ["Examples > limit memory 12000 (MB)",
-                   "         > limit CPU_CORE_COUNT 4",
-                   "         > limit GPU_SM 128"])
+        _print_hf("nodes list ", "List mDNS discovered nodes.")
         _print_hf("log [LEVEL]", "Set the logging level.")
-        dprint("\n  Building a topology:\n")
-        _print_hf("search [ON/OFF]", "Toggle mDNS worker node search across the local network.")
+        dprint("\033[1m\n  API Server Control:\n\033[0m")
+        _print_hf("api start [http_port=8080] [grpc_port=50500]", "Start the API server in a separate thread. Use provided ports if given.")
+        _print_hf("api stop ", "Signal clean shutdown of the API server.")
+        _print_hf("api status ", "Prints the status of the API server.")
+        _print_hf("api log ", "Print latest logs to the current terminal.")
+        dprint("\033[1m\n  Building a topology:\n\033[0m")
+        _print_hf("search ", "Returns the current state of mDNS search.")
+        _print_hf("search [on/off] ", "Toggle mDNS search across the local network.")
+        _print_hf("nodes list ", "List all nodes in the current topology (including local ones).")
+        _print_hf("nodes all ", "List all nodes (including local ones).")
+        _print_hf("nodes ", "List mDNS discovered nodes.")
         _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
         _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
         _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
-        dprint("\n  Building a schedule:\n")
+        sys.stdout.write("\033[1m\n  Building a schedule:\n\033[0m")
         _print_hf("sched create", "Automatic search for best schedule given the active topology and the loaded model.")
         _print_hf("sched assign [LAYER] [NODE]", "Assign the layer with index [LAYER] to [NODE].",
                   ["Example > sched assign 10 benny_234"])
         _print_hf("schedule assign [START-END] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
                   ["Example > sched assign 0-12 benny_234"])
-        dprint("\n  Benchmarking and profiling:\n")
+        sys.stdout.write("\033[1m\n  Benchmarking and profiling:\n\033[0m")
         _print_hf("profile [REPO]", "Estimate the total FLOPS of the model from [REPO]")
         _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
         _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
         _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
         _print_hf("bench [NODE0] [NODE1]", "Benchmark the network latency between [NODE0] and [NODE1]")
         _print_hf("bench ", "Benchmark the system using base library kernels")
-        dprint("\n")
+        sys.stdout.write("\033[1m\n  System control:\n\033[0m")
+        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource.",
+                  ["Examples > limit memory 12000 (MB)",
+                   "         > limit CPU_CORE_COUNT 4",
+                   "         > limit GPU_SM 128"])
+        sys.stdout.write("\n")
+        sys.stdout.flush()
 
     def print_state(self):
         dprint("Network state:\n")
@@ -225,7 +307,10 @@ def _parse_model_metadata(self, model_path):
 
     def handle_terminate_signal(self):
         # Handle worker/api shutdown
-        dprint("No workers to shut down. 
Terminating.\n") sys.exit() # ===== Handle Shard worker servers @@ -246,36 +331,165 @@ def handle_start_worker(self): # ===== Handle API server - def handle_device_discovery(self): - from socket import gethostname - from secrets import token_hex - - hostname = gethostname() - instance = f"api-{token_hex(4)}-{hostname}" - lib = DnetP2P("lib/dnet-p2p/lib") - - """ - self.discovery = lib.create_instance( - instance, hostname, - self.state.p2p_addr.host, self.state.p2p_addr_port, - self.state.grpc_listen_port, is_manager=True - ) - self.discovery.start() - """ - - def config_api_node(self): - api_address = NodeAddress(self.state.api_addr_host, self.state.api_addr_port) - self.api = RingApiNode(api_address, shard_address.format(), model_metadata) - - def start_api_discovery(self): - if self.api: - self.api._start_discovery() - - # Calls dsolver and optimizes topology - async def build_topology(self): - if self.api: - topo = await self.api.topology() - return topo + async def _api_main(self) -> None: # main thread loop + self._api_loop = asyncio.get_running_loop() + self._api_shutdown_e = asyncio.Event() + self._node = RingApiNode( + http_port=self.state.api_http_port, + grpc_port=self.state.api_grpc_port + ) + + try: + await self._node.start(shutdown_trigger=self._api_shutdown_e.wait) + self._api_searching.set() + self._api_running.set() + self._api_ready.set() + await self._api_shutdown_e.wait() + except Exception as e: + self._api_exc = e + self._api_running.set() + self._api_ready.set() + finally: + try: + await self._node.shutdown() + except Exception: + pass + self._api_running.clear() + + def _api_running_loop(self): + try: + asyncio.run(self._api_main()) + except BaseException as e: + self._api_exc = e + self._api_ready.set() + self._api_running.clear() + + def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10): + if self._api_thread and self._api_thread.is_alive(): return + self._api_exc = None + self._api_ready.clear() + self._api_running.clear() + self._api_thread = threading.Thread(target=self._api_running_loop, name="api_server", daemon=True) + self._api_thread.start() + if not self._api_ready.wait(timeout): + raise RuntimeError("API Server Timeout.") + if self._api_exc is not None: + raise RuntimeError(f"API Server failed to start: {e}") + + def stop_api(self, timeout: float = 5.0) -> None: + if not self._api_thread: return + if self._api_loop and self._api_shutdown_e: + self._api_loop.call_soon_threadsafe(self._api_shutdown_e.set) + if self._api_loop and self._node: + f = asyncio.run_coroutine_threadsafe(self._node.shutdown(), self._api_loop) + try: + f.result(timeout=timeout) + except Exception: + pass + self._api_thread.join(timeout=timeout) + self._api_thread = None + self._api_running.clear() + self._api_ready.clear() + + def api_call( # Call an API function from the REPL thread + self, + method: str, + *args: Any, + timeout: float=30.0, + **kwargs: Any + ) -> Any: + if not self._api_loop or not self._node: + raise RuntimeError("API Thread not set up correctly.") + + target = getattr(self._node, method, None) + if target is None: + raise AttributeError(f"RingApiNode has no method {method}") + + # method is async + if inspect.iscoroutinefunction(target): + coroutine = target(*args, **kwargs) + f = asyncio.run_coroutine_threadsafe(coroutine, self._api_loop) + return f.result(timeout) + + # method is sync + f = concurrent.futures.Future() + + # TODO: this is a mess lol + def runner(): + try: + ret = target(*args, **kwargs) + if inspect.isawaitable(ret): + + 
+    def api_call(  # Call an API function from the REPL thread
+        self,
+        method: str,
+        *args: Any,
+        timeout: float=30.0,
+        **kwargs: Any
+    ) -> Any:
+        if not self._api_loop or not self._node:
+            raise RuntimeError("API Thread not set up correctly.")
+
+        target = getattr(self._node, method, None)
+        if target is None:
+            raise AttributeError(f"RingApiNode has no method {method}")
+
+        # method is async
+        if inspect.iscoroutinefunction(target):
+            coroutine = target(*args, **kwargs)
+            f = asyncio.run_coroutine_threadsafe(coroutine, self._api_loop)
+            return f.result(timeout)
+
+        # method is sync
+        f = concurrent.futures.Future()
+
+        # TODO: simplify this sync/awaitable bridging
+        def runner():
+            try:
+                ret = target(*args, **kwargs)
+                if inspect.isawaitable(ret):
+
+                    async def _await_then_set():
+                        try:
+                            val = await ret
+                            f.set_result(val)
+                        except BaseException as e:
+                            f.set_exception(e)
+
+                    asyncio.create_task(_await_then_set())
+                else:
+                    f.set_result(ret)
+            except BaseException as e:
+                f.set_exception(e)
+        self._api_loop.call_soon_threadsafe(runner)
+        return f.result(timeout)
+
+    def _print_nodes_table(self, rows: List[Any]) -> None:
+        headers = ["name", "role", "addr", "http", "grpc", "status", "head"]
+        limits = {"name": 36, "addr": 15}
+        w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers}
+        line = " ".join(h.ljust(w[h]) for h in headers)
+        sys.stdout.write("\n")
+        sys.stdout.write(line + "\n")
+        sys.stdout.write(" ".join("-"*w[h] for h in headers))
+        sys.stdout.write("\n")
+        for r in rows:
+            name = str(r["name"])
+            addr = str(r["addr"])
+            if len(name) > w["name"]: name = name[:w["name"]-3] + "..."
+            if len(addr) > w["addr"]: addr = addr[:w["addr"]-3] + "..."
+            vals = {
+                "name": name,
+                "role": r["role"],
+                "addr": addr,
+                "http": r["http"],
+                "grpc": r["grpc"],
+                "status": "yes" if r["status"] else "no",
+                "head": "head" if r["head"] else "no",
+            }
+            sys.stdout.write(" ".join(str(vals[h]).ljust(w[h]) for h in headers))
+            sys.stdout.write("\n")
+        sys.stdout.write("\n")
+        sys.stdout.flush()
+
+
+    # Print a table of discovered nodes
+    def print_mdns_nodes(self) -> None:
+        try:
+            shards = self.api_call("_get_shards_from_discovery", timeout=10)
+            if not shards:
+                dprint("No worker nodes discovered. Is the API searching?\n")
+                return
+
+            rows = []
+            for name, props in shards.items():
+                addr = getattr(props, "local_ip", getattr(props, "host", ""))
+                http = getattr(props, "server_port", 0)
+                grpc = getattr(props, "shard_port", 0)
+                busy = bool(getattr(props, "is_busy", False))
+                head = addr == "127.0.0.1"  # TODO: compare against this host's real interface IPs
+                rows.append({
+                    "name": name,
+                    "role": "worker",
+                    "addr": addr,
+                    "http": http,
+                    "grpc": grpc,
+                    "status": busy,
+                    "head": head,
+                })
+            self._print_nodes_table(rows)
+        except Exception as e:
+            dprint(f"Could not list nodes: {e}\n")
 
     # ===== Handle shutdown
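The patch below introduces a file-backed Tracer used throughout the shard and
API nodes. A minimal usage sketch, assuming the TraceConfig/Tracer API exactly
as added in that patch (the scope, names, and attribute values here are
illustrative):

    from dnet.perf.trace import TraceConfig, Tracer

    tracer = Tracer(TraceConfig(file="logs/dnet-trace.jsonl", streaming=True))
    tracer.start()

    # Boundary frame: emits a "B" event on entry and an "E" event (with ms) on exit.
    with tracer.frame("repl", "demo.step", {"batch": 1}) as fr:
        fr.event("checkpoint", items=3)  # instant "I" event scoped to the frame

    tracer.mark("demo.done")  # instant event outside any frame
    tracer.stop()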
+""" + +from __future__ import annotations + +import os +import sys +import time +import json +import threading +import contextvars +import cProfile +import pstats +import io +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, List +from contextlib import contextmanager + +from dnet.utils.logger import logger + +@dataclass +class TraceConfig: + file: str = "logs/dnet-trace.jsonl" + streaming: bool = True + include_prefixes: Tuple[str, ...] = ("src/dnet/",) + include_c_calls: bool = False + budget: int = 0 # 0 means unlimited + enabled: bool = True + record_pid_tid: bool = True + + +class Tracer: + def __init__(self, cfg: TraceConfig): + self.cfg = cfg + self._lock = threading.Lock() + self._fh: Optional[io.TextIOBase] = None + self._events: List[Dict[str, Any]] = [] + self._req_id: str = None + self._active = False + + def start(self, *, reset: bool = True) -> None: + self._active = bool(self.cfg.enabled) + if not self._active: + logger.info("Initialized tracer.") + return + if self.cfg.file: + d = os.path.dirname(self.cfg.file) or "." + os.makedirs(d, exist_ok=True) + if reset and os.path.exists(self.cfg.file): + try: + os.remove(self.cfg.file) + except Exception: + logger.warning(f"Unable to remove existing trace file {self.cfg.file}") + if self.cfg.streaming: + with self._lock: + self._fh = open(self.cfg.file, "a", encoding="utf-8") + logger.info(f"Streaming trace to {self.cfg.file}.") + + def stop(self, *, flush_events: bool = True) -> None: + if flush_events: + self.flush() + self._active = False + with self._lock: + if self._fh: + try: + self._fh.flush() + self._fh.close() + except Exception: + logger.warning(f"Unable to flush to file {self.cfg.file}") + self._fh = None + + def set_request_id(self, rid: Optional[str]) -> None: + self._req_id = rid + + def get_request_id(self) -> Optional[str]: + return self._req_id + + # Flush file to disk + def flush(self, *, clear: bool = False) -> None: + if not self._active: + return + with self._lock: + if not self.cfg.streaming and self._events: + with open(self.cfg.file, "a", encoding="utf-8") as f: + for ev in self._events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + if clear: + self._events.clear() + + def snapshot(self, path: str) -> None: + with self._lock: + with open(path, "w", encoding="utf-8") as f: + for ev in self._events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + + # Emit a new frame event + def _emit(self, ev: Dict[str, Any]) -> None: + if not self._active: + return + ev.setdefault("ts_us", time.time_ns() // 1000) + if self._req_id is not None: + ev.setdefault("req_id", self._req_id) + if self.cfg.record_pid_tid: + try: + ev.setdefault("pid", os.getpid()) + ev.setdefault("tid", threading.get_ident()) + except Exception: + logger.warning("Unable to get PID and TID for tracer frame.") + with self._lock: + if self.cfg.streaming and self._fh: + self._fh.write(json.dumps(ev, ensure_ascii=False) + "\n") + self._fh.flush() + else: + self._events.append(ev) + + # Frames + class _NoopFrame: + def __enter__(self): + return self + def __exit__(self, *a): + return False + def event(self, *a, **k): + pass + def set(self, *a, **k): + pass + + class _Frame: + __slots__ = ("t", "name", "attrs", "_t0") + def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]): + self.t = tracer + self.name = name + self.attrs = dict(attrs or {}) + self._t0 = 0.0 + def __enter__(self): + self._t0 = time.perf_counter() + self.t._emit({"type": "B", "name": self.name, "args": 
dict(self.attrs)}) + return self + def __exit__(self, ex_type, ex, tb): + dt_ms = (time.perf_counter() - self._t0) * 1000.0 + self.t._emit({"type": "E", "name": self.name, "args": {"ms": round(dt_ms, 3), "exc": bool(ex)}}) + return False + def event(self, name: str, **attrs): + self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": attrs}) + def set(self, key: str, val: Any): + self.attrs[key] = val + + def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): + if not self._active: + return Tracer._NoopFrame() + return Tracer._Frame(self, f"{scope}.{name}", attrs) + + def mark(self, name: str, **attrs: Any) -> None: + self._emit({"type": "I", "name": name, "args": attrs}) + + # Helpers + @contextmanager + def profile_block(self, outfile: Optional[str] = None, sort: str = "cumtime", limit: int = 40): + pr = cProfile.Profile() + pr.enable() + try: + yield pr + finally: + pr.disable() + s = io.StringIO() + pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sort).print_stats(limit) + out = s.getvalue() + if outfile: + d = os.path.dirname(outfile) or "." + os.makedirs(d, exist_ok=True) + with open(outfile, "w", encoding="utf-8") as f: + f.write(out) + else: + self._emit({"type": "PROFILE", "name": "cprofile", "args": {"sort": sort, "limit": limit, "report": out}}) + + @contextmanager + def callgraph( + self, + include_prefixes: Optional[Tuple[str, ...]] = None, + budget_events: Optional[int] = None, + include_c_calls: Optional[bool] = None, + apply_to_new_threads: bool = False, + ): + """ + Interpreter-level tracing (sys.setprofile) for all Python calls/returns + within the scope. Heavy overhead; best for deep debugging runs. + """ + prefixes = include_prefixes if include_prefixes is not None else self.cfg.include_prefixes + budget = (budget_events if budget_events is not None else self.cfg.budget) or 0 + inc_c = include_c_calls if include_c_calls is not None else self.cfg.include_c_calls + + emitted = 0 + stack: list[Tuple[str, float]] = [] + + def prof(frame, event, arg): + nonlocal emitted + if budget and emitted >= budget: + return + if event in ("call", "return"): + code = frame.f_code + filename = code.co_filename or "" + if prefixes and not any(filename.startswith(p) for p in prefixes): + return + name = code.co_name + key = f"{filename}:{code.co_firstlineno}:{name}" + if event == "call": + stack.append((key, time.perf_counter())) + self._emit({"type": "B", "name": f"py.{name}", "args": {"file": filename, "line": code.co_firstlineno}}) + emitted += 1 + else: + if stack and stack[-1][0] == key: + _, t0 = stack.pop() + dt_ms = (time.perf_counter() - t0) * 1000.0 + self._emit({"type": "E", "name": f"py.{name}", "args": {"ms": round(dt_ms, 3)}}) + emitted += 1 + elif inc_c and event in ("c_call", "c_return"): + func = getattr(arg, "__name__", None) + mod = getattr(arg, "__module__", None) + if not func: + return + if event == "c_call": + self._emit({"type": "B", "name": f"c.{mod}.{func}", "args": {}}) + emitted += 1 + else: + self._emit({"type": "E", "name": f"c.{mod}.{func}", "args": {}}) + emitted += 1 + + prev = sys.getprofile() + sys.setprofile(prof) + prev_thread = None + if apply_to_new_threads: + prev_thread = threading.getprofile() + threading.setprofile(prof) + try: + yield + finally: + sys.setprofile(prev) + if apply_to_new_threads: + threading.setprofile(prev_thread) + diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 2c5626cd..b9226e3b 100644 --- a/src/dnet/ring/shard/compute.py +++ 
b/src/dnet/ring/shard/compute.py @@ -79,52 +79,53 @@ def _process_activation(self, activation_msg: ActivationMessage): try: # per-nonce kvcache for concurrent requests - kv = self._get_or_make_kv(activation_msg.nonce) + with self.tracer.frame("compute.thread", "kvcache.init"): + kv = self._get_or_make_kv(activation_msg.nonce) # Get input activation from pool - input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) - if input_buffer is None: - logger.error("Failed to get input buffer %s", activation_msg.pool_id) - return + with self.tracer.frame("compute.thread", "activations.load"): + input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) + if input_buffer is None: + logger.error("Failed to get input buffer %s", activation_msg.pool_id) + return # Prepare input activation - if activation_msg.dtype == "tokens": - # tokens were staged as int32 in the pool; embed locally on start shard - input_size = int(np.prod(activation_msg.shape)) - tok_view = input_buffer[:input_size].reshape(activation_msg.shape) - # Convert robustly to MLX int32 and embed (batch=1) - toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) - x = self.model.embed(toks[None]) - if x.dtype != self._wire_mx_dtype: - x = x.astype(self._wire_mx_dtype) - else: - # Prepare input activation using MLX view operations only - input_size = int(np.prod(activation_msg.shape)) - x = input_buffer[:input_size].reshape(activation_msg.shape) - # Ensure expected dtype without re-materializing when not needed - try: - if str(x.dtype) != activation_msg.dtype: - x = x.astype(mlx_dtype_map[activation_msg.dtype]) - except Exception: - pass + with self.tracer.frame("compute.thread", "activations.process"): + if activation_msg.dtype == "tokens": + # tokens were staged as int32 in the pool; embed locally on start shard + input_size = int(np.prod(activation_msg.shape)) + tok_view = input_buffer[:input_size].reshape(activation_msg.shape) + # Convert robustly to MLX int32 and embed (batch=1) + toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) + x = self.model.embed(toks[None]) + if x.dtype != self._wire_mx_dtype: + x = x.astype(self._wire_mx_dtype) + else: + # Prepare input activation using MLX view operations only + input_size = int(np.prod(activation_msg.shape)) + x = input_buffer[:input_size].reshape(activation_msg.shape) + # Ensure expected dtype without re-materializing when not needed + try: + if str(x.dtype) != activation_msg.dtype: + x = x.astype(mlx_dtype_map[activation_msg.dtype]) + except Exception: + pass # Compute windows until boundary (stay local as long as possible) current_layer = activation_msg.layer_id + 1 last_layer = current_layer - 1 while True: - start_time = time.perf_counter() processed = 0 did_early_swap = False - # Determine contiguous local window starting at current_layer - window_layers: List[int] = [] - _tmp_layer = current_layer - while processed < self.window_size and ( - _tmp_layer in self._assigned_set - ): - window_layers.append(_tmp_layer) - _tmp_layer += 1 - processed += 1 + with self.tracer.frame("compute.thread", "weights.prepare"): + # Determine contiguous local window starting at current_layer + window_layers: List[int] = [] + _tmp_layer = current_layer + while processed < self.window_size and (_tmp_layer in self._assigned_set): + window_layers.append(_tmp_layer) + _tmp_layer += 1 + processed += 1 if self._mode == "offload" and window_layers: prep = self._prepared_by_nonce.get(activation_msg.nonce) @@ -198,10 +199,23 @@ def _process_activation(self, 
activation_msg: ActivationMessage): t_w_ms, ) - bind_ms = 0.0 - if to_bind: - # Block prefetch-touch during binding and serialize MLX ops + # Opportunistically schedule prefetch for the next window to overlap with compute try: + next_win_pre = self._next_local_layers( + (window_layers[-1] if window_layers else (activation_msg.layer_id)), + self.window_size, + ) + for nl in next_win_pre: + self._prefetch_to_ram(nl) + self._enqueue_weight_prefetch(nl) + except Exception: + pass + + # Execute the window + with self.tracer.frame("compute.thread", "execute"): + self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) + + try: # Prevent prefetch touching during encode/compute to minimize UMA pressure self._compute_busy.set() except Exception: pass @@ -248,23 +262,35 @@ def _process_activation(self, activation_msg: ActivationMessage): window_layers, (t_comp_done - t_comp) * 1000.0, ) + """ for lid in window_layers: self.weight_cache.decrease_reference(lid) - try: - # Sliding-fit delta swap: maintain a single resident set by evicting - # only what's needed to fit the next window into the budget. Prefer - # keeping the tail of the previous window so we don't thrash weights - # that are likely to be reused. - if self._mode == "sliding_fit": - if int(self._resident_windows) <= 1: - if did_early_swap: - # Early delta-swap already trimmed resident set for this window - pass - elif not self._recent_windows: - # First window in token: seed resident set - self._recent_windows.append(list(window_layers)) + with self.tracer.frame("compute.thread", "execute.evict_and_unload"): + try: + # Sliding-fit delta swap: maintain a single resident set by evicting + # only what's needed to fit the next window into the budget. Prefer + # keeping the tail of the previous window so we don't thrash weights + # that are likely to be reused. 
+ if self._mode == "sliding_fit": + if int(self._resident_windows) <= 1: + if did_early_swap: + # Early delta-swap already trimmed resident set for this window + pass + elif not self._recent_windows: + # First window in token: seed resident set + self._recent_windows.append(list(window_layers)) + else: + prev = self._recent_windows.pop(0) + self._delta_swap_eviction(window_layers, prev, activation_msg, early=False) + budget = max(1, int(self.window_size or 1)) + curr = list(window_layers) + prev_only = [x for x in prev if x not in curr] + keep_quota = max(0, budget - len(curr)) + keep_tail = prev_only[-keep_quota:] if keep_quota > 0 else [] + combined = list(keep_tail) + curr + self._recent_windows.append(combined) else: prev = self._recent_windows.pop(0) self._delta_swap_eviction( @@ -280,34 +306,19 @@ def _process_activation(self, activation_msg: ActivationMessage): combined = list(keep_tail) + curr self._recent_windows.append(combined) else: - # resident_windows>1 not expected in sliding_fit; fall back to seeding + # Original eviction policy for other modes self._recent_windows.append(list(window_layers)) - else: - # Original eviction policy for other modes - self._recent_windows.append(list(window_layers)) - if int(self._resident_windows) <= 1: - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: + if int(self._resident_windows) <= 1: + old = self._recent_windows.pop(0) try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) except Exception: pass else: @@ -341,25 +352,8 @@ def _process_activation(self, activation_msg: ActivationMessage): ) except Exception: pass - except Exception: - pass - - computation_time = time.perf_counter() - start_time - self._prof.info( - "[PROFILE][COMPUTE] node=%s nonce=%s window_layers=%s total_ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - computation_time * 1000.0, - ) - self._prof.info( - "Completed layers up to %s in %.3fs, nonce: %s, result: %s %s", - last_layer, - computation_time, - activation_msg.nonce, - x.shape, - x.dtype, - ) + except Exception: + pass # If next layer is still local, continue without staging/tx nxt = last_layer + 1 @@ -368,33 +362,64 @@ def _process_activation(self, activation_msg: ActivationMessage): continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - t_stage = time.perf_counter() - x_cast = ( - x - if x.dtype == self._wire_mx_dtype - else x.astype(self._wire_mx_dtype) - ) - try: - self._compute_busy.clear() - except Exception: - pass - - if self._profile: + with self.tracer.frame("compute.thread", "execute.enqueue_prefetch"): + x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) try: - logger.info( - "[PROFILE][STAGE-DIRECT] node=%s nonce=%s layer_tail=%s stage_ms=%.3f shape=%s dtype=%s", - self.node_id, - activation_msg.nonce, - last_layer, - (time.perf_counter() 
- t_stage) * 1000.0, - tuple(x_cast.shape), - str(self._wire_mx_dtype), - ) + self._compute_busy.clear() + except Exception: + pass + try: + for lid in list(self._prefetch_pending): + self._prefetch_pending.discard(lid) + self._enqueue_weight_prefetch(lid) except Exception: pass - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "mdns.send"): + nxt = last_layer + 1 + if nxt >= self.model_metadata.num_layers: # End of model + try: + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + is_final=True, + token_id=token_id, + ) + else: + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, + ) try: with self._mlx_lock: y = self.model.normalize(x_cast) @@ -464,41 +489,42 @@ def _process_activation(self, activation_msg: ActivationMessage): self.input_pool.release(activation_msg.pool_id) # Optional unload/evict after stage - if self._mode != "sliding_fit": - if self._defer_unload: - try: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) - except Exception: - pass + with self.tracer.frame("compute.thread", "cleanup"): + if self._mode != "sliding_fit": + if self._defer_unload: + try: + while len(self._recent_windows) > max( + 1, int(self._resident_windows) + ): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) + except Exception: + pass + if self._profile: + logger.info( + "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", + self.node_id, + activation_msg.nonce, + old, + evicted_cnt, + self._resident_windows, + ) + except Exception: + pass - if self._resident_windows <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) 
# type: ignore[attr-defined]
                     if self._profile:
                         logger.info(
                             "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s",
diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index 919d1efb..ffb34ac3 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -63,6 +63,8 @@
 from .comms import CommsMixin
 from ..weight_cache import WeightCache
 
+from dnet.perf.trace import TraceConfig, Tracer
+
 
 class RingShardNode(ComputeMixin, PrefetchMixin, CommsMixin):
     """Single shard node in the distributed inference ring with dynamic model loading."""
@@ -201,6 +203,19 @@ def __init__(
         if self._profile:
             logger.info("[PROFILE] enabled on shard node %s", self.node_id)
 
+        # Debug tracing
+        cfg = TraceConfig(
+            file="./trace.json",
+            streaming=True,
+            include_prefixes = ("src/dnet/",),
+            include_c_calls = False,
+            budget = 10000,
+            enabled = True,
+            record_pid_tid = True,
+        )
+        self.tracer = Tracer(cfg)
+        self.tracer.start()
+
         # Per-nonce KV caches (concurrent requests)
         self._kv_by_nonce: Dict[str, list] = {}
         self._kv_last_seen: Dict[str, float] = {}
@@ -219,11 +234,8 @@ def __init__(
         )
 
     async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse:
-        """Load model with specified layers."""
-        try:
-            start_time = time.perf_counter()
-
-            # Check if already loaded with same configuration
+        """Load model with specified layers. """
+        try:  # Check if already loaded with same configuration
             if (
                 self.model is not None
                 and self.model_path == req.model_path
@@ -241,22 +253,22 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse
             )
 
             # If model loaded with different config, unload first
-            if self.model is not None and (
-                self.model_path != req.model_path or self.assigned_layers != req.layers
+            if (self.model is not None
+                and (self.model_path != req.model_path or self.assigned_layers != req.layers)
             ):
-                logger.info(
-                    "Node %s: Unloading current model to load new configuration",
-                    self.node_id,
-                )
-                await self.unload_model()
+                logger.info("Node %s: Unloading current model to load new configuration", self.node_id)
+                with self.tracer.frame("memory", "model.unload"):
+                    await self.unload_model()
 
             # Load model metadata
-            self.model_metadata = get_model_metadata(req.model_path)
+            with self.tracer.frame("memory", "model.load_metadata"):
+                self.model_metadata = get_model_metadata(req.model_path)
             self.assigned_layers = req.layers
             self._assigned_sorted = sorted(self.assigned_layers)
             self._assigned_set = set(self._assigned_sorted)
             self.model_path = req.model_path
 
+            # Decide mode dynamically from assignment + requested window
             requested_w = int(max(1, int(req.window_size)))
             local_count = max(1, len(self.assigned_layers))
@@ -352,45 +364,49 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse
             )
 
             # Initialize weight cache
-            self.weight_cache = WeightCache(
-                self.assigned_layers,
-                self.model_metadata,
-                window_size=self.window_size,
-                prefetch_threads=self._prefetch_threads,
-                resident_windows=self._resident_windows,
-                use_mxload_fastpath=self.config.mxload_fastpath,
-                prefetch_mode=self.config.prefetch_mode,
-            )
+            with self.tracer.frame("memory", "weight_cache.init"):
+                self.weight_cache = WeightCache(
+                    self.assigned_layers,
+                    self.model_metadata,
+                    window_size=self.window_size,
prefetch_threads=self._prefetch_threads, + resident_windows=self._resident_windows, + use_mxload_fastpath=self.config.mxload_fastpath, + prefetch_mode=self.config.prefetch_mode, + ) # Load the model - self.model = get_ring_model( - self.model_metadata.model_type, - self.model_metadata.model_config, - assigned_layers=self.assigned_layers, - is_api_layer=False, - ) - try: - applied = bool( - self.model.apply_quantization_from_config( # type: ignore[attr-defined] - self.model_metadata.model_config, - model_metadata=self.model_metadata, - ) - ) - logger.info( - "[QUANT] applied=%s for model=%s", - applied, + with self.tracer.frame("memory", "model.load"): + self.model = get_ring_model( self.model_metadata.model_type, + self.model_metadata.model_config, + assigned_layers=self.assigned_layers, + is_api_layer=False, ) + try: + applied = bool( + self.model.apply_quantization_from_config( # type: ignore[attr-defined] + self.model_metadata.model_config, + model_metadata=self.model_metadata, + ) + ) + logger.info( + "[QUANT] applied=%s for model=%s", + applied, + self.model_metadata.model_type, + ) - except RuntimeError as e: - logger.warning("[QUANT] apply failed: %s", e) - self.model.eval() - self.cache = make_cache( - self.model, - kv_mode=self.config.kv_cache.mode, - kv_bits=self.config.kv_cache.bits, - kv_group=self.config.kv_cache.group_size, - ) + except RuntimeError as e: + logger.warning("[QUANT] apply failed: %s", e) + self.model.eval() + + with self.tracer.frame("memory", "make_cache"): + self.cache = make_cache( + self.model, + kv_mode=self.config.kv_cache.mode, + kv_bits=self.config.kv_cache.bits, + kv_group=self.config.kv_cache.group_size, + ) try: has_start = 0 in self.assigned_layers @@ -425,7 +441,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.api_callback_address = req.api_callback_address if self.next_node: - await self._connect_next_node() + with self.tracer.frame("network", "connect.next_node"): + await self._connect_next_node() else: logger.warning("Node %s: No next node configured", self.node_id) @@ -492,7 +509,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse success=True, message="Model loaded successfully", layers_loaded=req.layers, - load_time_ms=load_time_ms, + load_time_ms=0.0, ) except Exception as e: @@ -538,23 +555,24 @@ async def unload_model(self) -> ShardUnloadModelResponse: self._assigned_set = set() # Clear memory pools - if self.weight_cache: - # Stop any in-flight prefetch and close layer manager resources - try: - self.weight_cache.cancel_all_prefetch() - except Exception: - pass - # Clear all cached weights - for layer_id in list(self._bound_versions.keys()): + with self.tracer.frame("memory", "cache.evict"): + if self.weight_cache: + # Stop any in-flight prefetch and close layer manager resources try: - self.weight_cache.evict_layer(layer_id) + self.weight_cache.cancel_all_prefetch() except Exception: pass - try: - self.weight_cache.layer_manager.close() - except Exception: - pass - self.weight_cache = None + # Clear all cached weights + for layer_id in list(self._bound_versions.keys()): + try: + self.weight_cache.evict_layer(layer_id) + except Exception: + pass + try: + self.weight_cache.layer_manager.close() + except Exception: + pass + self.weight_cache = None self.input_pool = None self.output_pool = None @@ -588,12 +606,13 @@ async def reset_cache(self) -> None: return try: - self.cache = make_cache( - self.model, # type: ignore[arg-type] - kv_mode=self.config.kv_cache.mode, 
- kv_bits=self.config.kv_cache.bits, - kv_group=self.config.kv_cache.group_size, - ) + with self.tracer.frame("memory", "cache.reset"): + self.cache = make_cache( + self.model, # type: ignore[arg-type] + kv_mode=self.config.kv_cache.mode, + kv_bits=self.config.kv_cache.bits, + kv_group=self.config.kv_cache.group_size, + ) logger.info("Node %s: Cache reset successfully", self.node_id) except Exception as e: logger.error("Node %s: Error resetting cache: %s", self.node_id, e) @@ -1095,7 +1114,8 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - self._process_activation(activation_msg) + with self.tracer.frame("compute", "forward"): + self._process_activation(activation_msg) except Empty: continue From c2b5e6f978ae408f22e0bc1a75b1502efa831cc9 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 00:30:55 -0700 Subject: [PATCH 090/172] aggregate trace buffers back to api --- src/dnet/perf/__init__.py | 2 + src/dnet/perf/trace.py | 222 +++++++----- src/dnet/perf/utils/__init__.py | 2 + src/dnet/perf/utils/aggregator.py | 152 ++++++++ src/dnet/ring/api/models.py | 28 ++ src/dnet/ring/api/node.py | 34 ++ src/dnet/ring/shard/compute.py | 183 +++++----- src/dnet/ring/shard/models.py | 30 ++ src/dnet/ring/shard/startup.py | 563 ++++++++++++++++++++++++++++++ 9 files changed, 1042 insertions(+), 174 deletions(-) create mode 100644 src/dnet/perf/utils/__init__.py create mode 100644 src/dnet/perf/utils/aggregator.py create mode 100644 src/dnet/ring/shard/startup.py diff --git a/src/dnet/perf/__init__.py b/src/dnet/perf/__init__.py index e69de29b..330893e8 100644 --- a/src/dnet/perf/__init__.py +++ b/src/dnet/perf/__init__.py @@ -0,0 +1,2 @@ + +from .trace import TraceConfig, Tracer diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index ad7738d8..92119d79 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -1,32 +1,22 @@ -""" -Object-oriented tracing utilities for dnet. - -This module provides a Tracer class configured explicitly from the REPL (or code), -without relying on environment variables or module-level globals. It supports: - -- Boundary frames via tracer.frame(scope, name, attrs) -- Deep sys.setprofile callgraph via tracer.callgraph(...) -- Aggregated call stats via tracer.profile_block(...) - -All events are written as JSON Lines to a file (TraceConfig.file), suitable -for simple REPL visualization and easy sharing. 
-""" from __future__ import annotations import os +import io import sys import time import json -import threading -import contextvars -import cProfile import pstats -import io +import cProfile +import threading +import queue + from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, List from contextlib import contextmanager +import httpx + from dnet.utils.logger import logger @dataclass @@ -37,35 +27,126 @@ class TraceConfig: include_c_calls: bool = False budget: int = 0 # 0 means unlimited enabled: bool = True + node_id: Optional[str] = None record_pid_tid: bool = True + aggregate: bool = False + aggregate_url: Optional[str] = None + agg_max_events: int = 1000 + +class _NoopFrame: + def __enter__(self): + return self + def __exit__(self, *a): + return False + def event(self, *a, **k): + pass + def set(self, *a, **k): + pass +class _Frame: + __slots__ = ("t", "name", "attrs", "_t0") + def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]): + self.t = tracer + self.name = name + self.attrs = dict(attrs or {}) + self._t0 = 0.0 + def __enter__(self): + self._t0 = time.perf_counter() + self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) + return self + def __exit__(self, ex_type, ex, tb): + dt_ms = (time.perf_counter() - self._t0) * 1000.0 + self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex)}) + self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) + return False + def event(self, name: str, **attrs): + out = dict(attrs or {}) + out.setdefault("t_rel_ms", (time.perf_counter() - self._t0) * 1000.0) + self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": out}) + def set(self, key: str, val: Any): + self.attrs[key] = val class Tracer: - def __init__(self, cfg: TraceConfig): - self.cfg = cfg + def __init__(self, config: TraceConfig): + self.config = config self._lock = threading.Lock() self._fh: Optional[io.TextIOBase] = None self._events: List[Dict[str, Any]] = [] self._req_id: str = None self._active = False + self._agg_enabled: bool = False + self._agg_max_events: int = int(self.config.agg_max_events or 1000) + self._agg_q: queue.Queue[dict] = queue.Queue(maxsize=256) + self._agg_thread: Optional[threading.Thread] = None + + if self.config.aggregate: + self.start_aggregator() + + # Aggregator worker thread + def start_aggregator(self) -> None: + self._agg_enabled = True + self._agg_max_events = max(10, int(self.config.agg_max_events or 1000)) # 10 min, 1000 default + if not self._agg_thread or not self._agg_thread.is_alive(): + self._agg_thread = threading.Thread(target=self._agg_exec, name="trace-agg", daemon=True) + self._agg_thread.start() + + def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None: + self._agg_enabled = False + if flush and self._events: + try: + self._agg_q.put_nowait({ + "run_id": (self._req_id or "run"), + "node_id": (self.config.node_id or "node"), + "events": list(self._events), }) + except queue.Full: + logger.warning(f"Trace aggragator queue is full.") + self._events.clear() + if self._agg_thread and self._agg_thread.is_alive(): + self._agg_thread.join(timeout) + self._agg_thread = None + + def _agg_exec(self) -> None: + url = self.config.aggregate_url or "" + assert url != "" + client = httpx.Client(timeout=5.0) + try: + while self._agg_enabled and not self._agg_q.empty(): + try: + batch = self._agg_q.get(timeout=0.2) + except queue.Empty: + continue + try: + client.post(url, json=batch) + except Exception: + logger.warning(f"Unable to POST 
+    def _agg_exec(self) -> None:
+        url = self.config.aggregate_url or ""
+        assert url != ""
+        client = httpx.Client(timeout=5.0)
+        try:
+            # Keep draining while enabled, then flush whatever is still queued.
+            while self._agg_enabled or not self._agg_q.empty():
+                try:
+                    batch = self._agg_q.get(timeout=0.2)
+                except queue.Empty:
+                    continue
+                try:
+                    client.post(url, json=batch)
+                except Exception:
+                    logger.warning(f"Unable to POST trace aggregation data to {url}")
+                finally:
+                    self._agg_q.task_done()
+        finally:
+            try:
+                client.close()
+            except Exception:
+                logger.warning("Unable to close httpx client.")
+
     def start(self, *, reset: bool = True) -> None:
-        self._active = bool(self.cfg.enabled)
+        self._active = bool(self.config.enabled)
         if not self._active:
             logger.info("Tracer disabled; no events will be recorded.")
             return
-        if self.cfg.file:
-            d = os.path.dirname(self.cfg.file) or "."
-            os.makedirs(d, exist_ok=True)
-            if reset and os.path.exists(self.cfg.file):
+        if self.config.file:
+            d = os.path.dirname(self.config.file) or "."
+            os.makedirs(d, exist_ok=True)
+            if reset and os.path.exists(self.config.file):
                 try:
-                    os.remove(self.cfg.file)
+                    os.remove(self.config.file)
                 except Exception:
-                    logger.warning(f"Unable to remove existing trace file {self.cfg.file}")
-        if self.cfg.streaming:
+                    logger.warning(f"Unable to remove existing trace file {self.config.file}")
+        if self.config.streaming:
             with self._lock:
-                self._fh = open(self.cfg.file, "a", encoding="utf-8")
-            logger.info(f"Streaming trace to {self.cfg.file}.")
+                self._fh = open(self.config.file, "a", encoding="utf-8")
+            logger.info(f"Streaming trace to {self.config.file}.")
+        if self.config.aggregate and self.config.aggregate_url and self.config.node_id:
+            self.start_aggregator()
 
     def stop(self, *, flush_events: bool = True) -> None:
         if flush_events:
@@ -77,89 +158,65 @@ def stop(self, *, flush_events: bool = True) -> None:
                     self._fh.flush()
                     self._fh.close()
                 except Exception:
-                    logger.warning(f"Unable to flush to file {self.cfg.file}")
+                    logger.warning(f"Unable to flush to file {self.config.file}")
                 self._fh = None
 
-    def set_request_id(self, rid: Optional[str]) -> None:
-        self._req_id = rid
-
-    def get_request_id(self) -> Optional[str]:
-        return self._req_id
-
     # Flush file to disk
     def flush(self, *, clear: bool = False) -> None:
-        if not self._active:
-            return
+        if not self._active: return
        with self._lock:
-            if not self.cfg.streaming and self._events:
-                with open(self.cfg.file, "a", encoding="utf-8") as f:
+            if not self.config.streaming and self._events:
+                with open(self.config.file, "a", encoding="utf-8") as f:
                    for ev in self._events:
                        f.write(json.dumps(ev, ensure_ascii=False) + "\n")
            if clear:
                self._events.clear()
-
+
+    # Dump buffered (non-streaming) events to a separate file
     def snapshot(self, path: str) -> None:
        with self._lock:
            with open(path, "w", encoding="utf-8") as f:
                for ev in self._events:
                    f.write(json.dumps(ev, ensure_ascii=False) + "\n")
 
-    # Emit a new frame event
+    # Emit a new frame
     def _emit(self, ev: Dict[str, Any]) -> None:
-        if not self._active:
-            return
-        ev.setdefault("ts_us", time.time_ns() // 1000)
+        if not self._active: return
+        ev.setdefault("ts", time.perf_counter())
         if self._req_id is not None:
             ev.setdefault("req_id", self._req_id)
-        if self.cfg.record_pid_tid:
+        if self.config.record_pid_tid:
             try:
                 ev.setdefault("pid", os.getpid())
                 ev.setdefault("tid", threading.get_ident())
             except Exception:
                 logger.warning("Unable to get PID and TID for tracer frame.")
+
         with self._lock:
-            if self.cfg.streaming and self._fh:
+            if self.config.streaming and self._fh:
                 self._fh.write(json.dumps(ev, ensure_ascii=False) + "\n")
                 self._fh.flush()
             else:
                 self._events.append(ev)
 
+        if self._agg_enabled:
+            if len(self._events) < self._agg_max_events: return
+            batch = { "run_id": (self._req_id or "run"),
+                      "node_id": (self.config.node_id or "node"),
+                      "events": list(self._events)}
+            try:
+                self._agg_q.put_nowait(batch)
+            except queue.Full:
+                logger.warning(f"Aggregator queue is full. Dropping {len(batch['events'])} frames.")
+            self._events.clear()
-    # Frames
-    class _NoopFrame:
-        def __enter__(self):
-            return self
-        def __exit__(self, *a):
-            return False
-        def event(self, *a, **k):
-            pass
-        def set(self, *a, **k):
-            pass
-
-    class _Frame:
-        __slots__ = ("t", "name", "attrs", "_t0")
-        def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]):
-            self.t = tracer
-            self.name = name
-            self.attrs = dict(attrs or {})
-            self._t0 = 0.0
-        def __enter__(self):
-            self._t0 = time.perf_counter()
-            self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)})
-            return self
-        def __exit__(self, ex_type, ex, tb):
-            dt_ms = (time.perf_counter() - self._t0) * 1000.0
-            self.t._emit({"type": "E", "name": self.name, "args": {"ms": round(dt_ms, 3), "exc": bool(ex)}})
-            return False
-        def event(self, name: str, **attrs):
-            self.t._emit({"type": "I", "name": f"{self.name}.{name}", "args": attrs})
-        def set(self, key: str, val: Any):
-            self.attrs[key] = val
 
+    # Frames
     def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None):
         if not self._active:
-            return Tracer._NoopFrame()
-        return Tracer._Frame(self, f"{scope}.{name}", attrs)
+            return _NoopFrame()
+        return _Frame(self, f"{scope}.{name}", attrs)
 
+    # Mark an event outside of a frame
     def mark(self, name: str, **attrs: Any) -> None:
         self._emit({"type": "I", "name": name, "args": attrs})
 
@@ -195,9 +252,9 @@ def callgraph(
         Interpreter-level tracing (sys.setprofile) for all Python calls/returns
         within the scope. Heavy overhead; best for deep debugging runs.
         """
-        prefixes = include_prefixes if include_prefixes is not None else self.cfg.include_prefixes
-        budget = (budget_events if budget_events is not None else self.cfg.budget) or 0
-        inc_c = include_c_calls if include_c_calls is not None else self.cfg.include_c_calls
+        prefixes = include_prefixes if include_prefixes is not None else self.config.include_prefixes
+        budget = (budget_events if budget_events is not None else self.config.budget) or 0
+        inc_c = include_c_calls if include_c_calls is not None else self.config.include_c_calls
@@ -247,4 +304,3 @@ def prof(frame, event, arg):
         sys.setprofile(prev)
         if apply_to_new_threads:
             threading.setprofile(prev_thread)
-
diff --git a/src/dnet/perf/utils/__init__.py b/src/dnet/perf/utils/__init__.py
new file mode 100644
index 00000000..9c310fd5
--- /dev/null
+++ b/src/dnet/perf/utils/__init__.py
@@ -0,0 +1,2 @@
+
+from .aggregator import TraceAggregator
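The aggregator module added below pairs with the API server's trace ingest
hook from this same patch. A minimal wiring sketch, assuming the
TraceAggregator and set_trace_ingest_callback APIs as added here (node is an
already-constructed RingApiNode and the run id is illustrative):

    from dnet.perf.utils import TraceAggregator

    aggregator = TraceAggregator()

    # Every TraceIngestBatch POSTed to /trace/ingest is forwarded to the
    # aggregator instead of being appended to a temp file.
    node.set_trace_ingest_callback(aggregator.enqueue)

    # Later, summarize per-frame self-times for a given bench run id.
    for row in aggregator.annotate("bench-2025-10-19"):
        print(row["name"], round(row["self_ms"], 3), row["count"])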
diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py
new file mode 100644
index 00000000..f9bbdef6
--- /dev/null
+++ b/src/dnet/perf/utils/aggregator.py
@@ -0,0 +1,152 @@
+
+from __future__ import annotations
+
+import threading
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Tuple, Optional, DefaultDict
+from collections import defaultdict, deque
+
+from dnet.utils.logger import logger
+
+Key = Tuple[str, Optional[int], Optional[int], str]  # (node_id, pid, tid, req_id)
+
+@dataclass
+class _OpenFrame:
+    name: str
+    t0: int
+    child: int = 0
+    children: List[Dict[str, Any]] = field(default_factory=list)
+
+# Pair B/E frames into call trees and accumulate per-name self-times
+@dataclass
+class RunAggregator:
+    sums_by_name: Dict[str, float] = field(default_factory=dict)
+    counts_by_name: Dict[str, int] = field(default_factory=dict)
+    last_batch_seq: Dict[str, int] = field(default_factory=dict)
+
+    stacks: Dict[Key, List[_OpenFrame]] = field(default_factory=dict)
+    drops: int = 0
+    roots_by_req: DefaultDict[str, List[Dict[str, Any]]] = field(default_factory=lambda: defaultdict(list))
+
+    def _key(self, node_id: str, pid: Optional[int], tid: Optional[int], req_id: Optional[str]) -> Key:
+        return (node_id, pid, tid, req_id or "")
+
+    def _push(self, key: Key, f: _OpenFrame) -> None:
+        self.stacks.setdefault(key, []).append(f)
+
+    def _pop(self, key: Key) -> Optional[_OpenFrame]:
+        st = self.stacks.get(key)
+        if not st: return None
+        return st.pop()
+
+    def _peek(self, key: Key) -> Optional[_OpenFrame]:
+        st = self.stacks.get(key)
+        return st[-1] if st else None
+
+    def _acc_annotate(self, name: str, self_ms: float) -> None:
+        self.sums_by_name[name] = self.sums_by_name.get(name, 0.0) + self_ms
+        self.counts_by_name[name] = self.counts_by_name.get(name, 0) + 1
+
+    def ingest_event(self, node_id: str, ev: Dict[str, Any]) -> None:
+        if not ev.get("name"):
+            logger.error(f"Received trace frame without a name (ts={ev.get('ts')})")
+            return
+        # Normalize ts to microseconds (accept float seconds or int microseconds)
+        ts_raw = ev.get("ts")
+        ts_us = 0
+        try:
+            if isinstance(ts_raw, float):
+                ts_us = int(ts_raw * 1_000_000)
+            elif isinstance(ts_raw, int):
+                ts_us = ts_raw
+            else:
+                ts_us = int(ts_raw or 0)
+        except Exception:
+            ts_us = 0
+        req_id = ev.get("req_id") or ""
+        key = self._key(node_id, ev.get("pid"), ev.get("tid"), req_id)
+        if ev.get("type") == "B":
+            self._push(key, _OpenFrame(name=ev.get("name"), t0=ts_us))
+        elif ev.get("type") == "E":
+            fr = self._pop(key)
+            if not fr: return
+            dur_us = max(0, ts_us - fr.t0)
+            self_us = max(0, dur_us - fr.child)
+            self_ms = self_us / 1000.0
+            self._acc_annotate(fr.name, self_ms)
+            parent = self._peek(key)
+            completed = {
+                "name": fr.name,
+                "ts": fr.t0,
+                "dur_ms": dur_us / 1000.0,
+                "self_ms": self_ms,
+                "children": fr.children,
+                "pid": ev.get("pid"),
+                "tid": ev.get("tid"),
+                "req_id": req_id,
+                "node_id": node_id,
+            }
+            if parent:
+                parent.child += dur_us
+                parent.children.append(completed)
+            else:
+                self.roots_by_req[req_id or ""].append(completed)
+        else:
+            # TODO: process other event types ("I" marks, profile reports)
+            pass
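To make the pairing logic concrete, a sketch of what ingest_event produces for
one balanced B/E pair (the node id, names, and timestamps are invented; float
second timestamps are normalized to microseconds as in the code above):

    agg = RunAggregator()
    agg.ingest_event("shard-0", {"type": "B", "name": "compute.thread.execute",
                                 "ts": 1.000, "pid": 1, "tid": 2, "req_id": "r1"})
    agg.ingest_event("shard-0", {"type": "E", "name": "compute.thread.execute",
                                 "ts": 1.050, "pid": 1, "tid": 2, "req_id": "r1"})

    # With no enclosing frame, the completed span lands in roots_by_req["r1"]:
    # {"name": "compute.thread.execute", "dur_ms": 50.0, "self_ms": 50.0, ...}
    print(agg.roots_by_req["r1"][0]["dur_ms"])  # 50.0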
+class TraceAggregator:
+    def __init__(self) -> None:
+        self._req: Dict[str, RunAggregator] = {}
+        self._lock = threading.Lock()
+
+    def enqueue(self, batch: Dict[str, Any]) -> None:
+        run_id = batch.get("run_id")
+        node_id = batch.get("node_id")
+        if not run_id or not node_id:
+            return
+        events = batch.get("events") or []
+        batch_seq = int(batch.get("batch_seq") or 0)
+        with self._lock:
+            agg = self._req.setdefault(run_id, RunAggregator())
+            last = agg.last_batch_seq.get(node_id)
+            if (last is not None) and (batch_seq != last + 1):
+                agg.drops += abs(batch_seq - (last + 1))
+            agg.last_batch_seq[node_id] = batch_seq
+            for ev in events:
+                try:
+                    agg.ingest_event(node_id, ev)
+                except Exception:
+                    continue
+
+    def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]:
+        with self._lock:
+            agg = self._req.get(run_id)
+            if not agg:
+                return []
+            if not mapping:
+                rows = [
+                    {"name": k, "self_ms": v, "total_ms": v, "count": repeats or agg.counts_by_name.get(k, 0), "max_ms": None}
+                    for k, v in agg.sums_by_name.items()
+                ]
+            else:
+                sums: Dict[str, float] = {}
+                counts: Dict[str, int] = {}
+                for raw, val in agg.sums_by_name.items():
+                    disp = mapping.get(raw, raw)
+                    sums[disp] = sums.get(disp, 0.0) + val
+                    counts[disp] = counts.get(disp, 0) + agg.counts_by_name.get(raw, 0)
+                rows = [
+                    {"name": k, "self_ms": v, "total_ms": v, "count": repeats or counts.get(k, 0), "max_ms": None}
+                    for k, v in sums.items()
+                ]
+            rows.sort(key=lambda r: r["self_ms"], reverse=True)
+            return rows
+
+    def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]:
+        with self._lock:
+            agg = self._req.get(run_id)
+            if not agg:
+                return []
+            return list(agg.roots_by_req.get(req_id or "", []))
diff --git a/src/dnet/ring/api/models.py b/src/dnet/ring/api/models.py
index 7199142a..fc84162f 100644
--- a/src/dnet/ring/api/models.py
+++ b/src/dnet/ring/api/models.py
@@ -8,6 +8,7 @@
 
 from ..common import LayerAssignment
 
+from dnet.perf.trace import _Frame
 
 class RoleMapping(BaseModel):
     """Role mapping for chat formatting."""
@@ -403,3 +404,30 @@ class UnloadModelResponse(BaseModel):
     message: Optional[str] = Field(
         default=None, description="Overall status or error message"
     )
+
+# Tracer ingest
+
+class TraceEvent(BaseModel):
+    type: Literal["B", "E", "I"] = Field(..., description="Event type/phase")
+    name: str = Field(..., description="Span/mark name")
+    ts: int = Field(..., description="Timestamp in microseconds")
+    args: Dict[str, Any] = Field(default_factory=dict)
+    req_id: Optional[str] = None
+    pid: Optional[int] = None
+    tid: Optional[int] = None
+
+class TraceIngestBatch(BaseModel):
+    run_id: str = Field(..., description="Bench run identifier")
+    node_id: str = Field(..., description="Shard/service identity")
+    batch_seq: int = Field(..., description="Monotonic batch sequence per node")
+    events: List[TraceEvent] = Field(default_factory=list)
+    dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
+    max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
+    last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
+    schema_version: int = Field(default=1)
+
+class TraceIngestResponse(BaseModel):
+    ok: bool = True
+    accepted: int = 0
+    batch_seq: Optional[int] = None
+    message: Optional[str] = None
diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index e819235f..d1e07f8e 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -79,6 +79,8 @@
     ShardLoadModelRequest,
     ShardLoadModelResponse,
     ShardProfileResponse,
+    TraceIngestBatch,
+    TraceIngestResponse,
 )
 from ..data_types import StopCondition
 from .servicer import ShardApiServicer
@@ -371,6 +373,34 @@ async def completions(req: CompletionRequestModel):  # type: ignore
             )
             return await self._handle_text_completion(req)
 
+        # Ingest trace buffers and forward to REPL
+        @self.app.post("/trace/ingest")
+        async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse:  # type: ignore
+            logger.debug("Received trace buffer.")
+            try:
+                cb = getattr(self, "_trace_ingest_cb", None)
+                if cb is not None:
+                    cb(batch.model_dump())
+                    return TraceIngestResponse(ok=True, accepted=len(batch.events), batch_seq=batch.batch_seq)
+
+                run_dir = Path("logs/trace/ingest") / batch.run_id
+                fpath = run_dir / f"{batch.node_id}.jsonl"
+                try:
+                    run_dir.mkdir(parents=True, exist_ok=True)
+                    with fpath.open("a", encoding="utf-8") as f:
+                        f.write(batch.model_dump_json() + "\n")
+                except Exception:
+                    logger.warning(f"Unable to write trace ingest buffer to temp file {fpath}")
+                return TraceIngestResponse(
+                    ok=True,
+                    accepted=len(batch.events),
+                    batch_seq=batch.batch_seq,
+                    message="no aggregator; appended"
+                )
+            except Exception as e:
+                logger.warning(f"Unable to ingest trace buffer: {e}")
+                return TraceIngestResponse(ok=False, accepted=0, message=str(e))
+
 
     async def _handle_prepare_topology(
         self, req: PrepareTopologyRequest
     ) -> TopologyInfo:
@@ -1700,3 +1730,7 @@ async def shutdown(self) 
-> None: logger.warning("Discovery service was not running") logger.info("API server shutdown complete") + + # REPL helper to install a trace ingestion callback + def set_trace_ingest_callback(self, cb: Optional[Callable[[Dict[str, Any]], None]]) -> None: + self._trace_ingest_cb = cb diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index b9226e3b..d25705fe 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -90,26 +90,26 @@ def _process_activation(self, activation_msg: ActivationMessage): return # Prepare input activation - with self.tracer.frame("compute.thread", "activations.process"): - if activation_msg.dtype == "tokens": - # tokens were staged as int32 in the pool; embed locally on start shard - input_size = int(np.prod(activation_msg.shape)) - tok_view = input_buffer[:input_size].reshape(activation_msg.shape) - # Convert robustly to MLX int32 and embed (batch=1) + with self.tracer.frame("compute.thread", "activations.process") as f: + if activation_msg.dtype == "tokens": # embed locally on start shard + f.event("embed_tokens") + numel = int(np.prod(activation_msg.shape)) + tok_view = input_buffer[:numel].reshape(activation_msg.shape) toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) x = self.model.embed(toks[None]) if x.dtype != self._wire_mx_dtype: x = x.astype(self._wire_mx_dtype) - else: - # Prepare input activation using MLX view operations only - input_size = int(np.prod(activation_msg.shape)) - x = input_buffer[:input_size].reshape(activation_msg.shape) - # Ensure expected dtype without re-materializing when not needed - try: + + else: # Prepare input activation using MLX view operations only + f.set("activation_dtype", activation_msg.dtype) + numel = int(np.prod(activation_msg.shape)) + x = input_buffer[:numel].reshape(activation_msg.shape) + + try: # Ensure expected dtype if str(x.dtype) != activation_msg.dtype: x = x.astype(mlx_dtype_map[activation_msg.dtype]) except Exception: - pass + logger.warning(f"Unable to update activation dtype") # Compute windows until boundary (stay local as long as possible) current_layer = activation_msg.layer_id + 1 @@ -118,7 +118,8 @@ def _process_activation(self, activation_msg: ActivationMessage): processed = 0 did_early_swap = False - with self.tracer.frame("compute.thread", "weights.prepare"): + with self.tracer.frame("compute.thread", "weights.prepare") as f: + # Determine contiguous local window starting at current_layer window_layers: List[int] = [] _tmp_layer = current_layer @@ -127,89 +128,89 @@ def _process_activation(self, activation_msg: ActivationMessage): _tmp_layer += 1 processed += 1 - if self._mode == "offload" and window_layers: - prep = self._prepared_by_nonce.get(activation_msg.nonce) - if prep is not None: - layers, fut = prep - if layers == window_layers and fut is not None: - try: - fut.result(timeout=30) - except Exception: - pass + if self._mode == "offload" and window_layers: + prep = self._prepared_by_nonce.get(activation_msg.nonce) + if prep is not None: + layers, fut = prep + if layers == window_layers and fut is not None: + try: + fut.result(timeout=30) + except Exception: + pass - # In sliding_fit with a single resident window, proactively evict only the - # non-needed head from the current resident set before loading new weights. - # This prevents LRU from evicting the useful tail during materialization. 
- if ( - self._mode == "sliding_fit" - and int(self._resident_windows) <= 1 - and window_layers - ): - try: - resident = [] + # In sliding_fit with a single resident window, proactively evict only the + # non-needed head from the current resident set before loading new weights. + # This prevents LRU from evicting the useful tail during materialization. + if ( + self._mode == "sliding_fit" + and int(self._resident_windows) <= 1 + and window_layers + ): try: - resident = self.weight_cache.get_resident_layers() # type: ignore[union-attr] - except Exception: resident = [] - evicted_cnt = self._delta_swap_eviction( - window_layers, resident, activation_msg, early=True - ) - if evicted_cnt > 0: - did_early_swap = True - except Exception: - pass + try: + resident = self.weight_cache.get_resident_layers() # type: ignore[union-attr] + except Exception: + resident = [] + evicted_cnt = self._delta_swap_eviction( + window_layers, resident, activation_msg, early=True + ) + if evicted_cnt > 0: + did_early_swap = True + except Exception: + pass - # Ensure weights for the window are resident and bind only if arrays changed - # if model fits and we've already bound these layers, skip the scan entirely. - fast_fit = ( - self._mode == "fit" - and len(self._assigned_sorted) <= self.window_size - ) - skip_scan = fast_fit and all( - (wl in self._bound_versions) for wl in window_layers - ) - to_bind: Dict[str, mx.array] = {} - if not skip_scan: - t_w_ready = time.perf_counter() - for wl in window_layers: - weights = self.weight_cache.get_weight(wl) - if weights is None: - logger.error("Failed to load weights for layer %s", wl) - self.input_pool.release(activation_msg.pool_id) - return + # Ensure weights for the window are resident and bind only if arrays changed + # if model fits and we've already bound these layers, skip the scan entirely. 
+        fast_fit = (
+            self._mode == "fit"
+            and len(self._assigned_sorted) <= self.window_size
+        )
+        skip_scan = fast_fit and all(
+            (wl in self._bound_versions) for wl in window_layers
+        )
+        to_bind: Dict[str, mx.array] = {}
+        if not skip_scan:
+            t_w_ready = time.perf_counter()
+            for wl in window_layers:
+                weights = self.weight_cache.get_weight(wl)
+                if weights is None:
+                    logger.error("Failed to load weights for layer %s", wl)
+                    self.input_pool.release(activation_msg.pool_id)
+                    return
+                try:
+                    # Use identity of first array as a cheap version/fingerprint
+                    first_arr = next(iter(weights.values()))
+                    version = id(first_arr)
+                except StopIteration:
+                    version = -1
+                if self._bound_versions.get(wl) != version:
+                    for k, v in weights.items():
+                        to_bind[k] = v
+                    self._bound_versions[wl] = version
+            if self._profile:
+                t_w_ms = (time.perf_counter() - t_w_ready) * 1000.0
+                # Only log when non-trivial or binding happened to reduce overhead/noise
+                if to_bind or t_w_ms > 0.5:
+                    logger.info(
+                        "[PROFILE][WAIT-WEIGHTS] node=%s nonce=%s layers=%s ms=%.3f",
+                        self.node_id,
+                        activation_msg.nonce,
+                        window_layers,
+                        t_w_ms,
+                    )
+
+        # Opportunistically schedule prefetch for the next window to overlap with compute
         try:
-            # Use identity of first array as a cheap version/fingerprint
-            first_arr = next(iter(weights.values()))
-            version = id(first_arr)
-        except StopIteration:
-            version = -1
-        if self._bound_versions.get(wl) != version:
-            for k, v in weights.items():
-                to_bind[k] = v
-            self._bound_versions[wl] = version
-        if self._profile:
-            t_w_ms = (time.perf_counter() - t_w_ready) * 1000.0
-            # Only log when non-trivial or binding happened to reduce overhead/noise
-            if to_bind or t_w_ms > 0.5:
-                logger.info(
-                    "[PROFILE][WAIT-WEIGHTS] node=%s nonce=%s layers=%s ms=%.3f",
-                    self.node_id,
-                    activation_msg.nonce,
-                    window_layers,
-                    t_w_ms,
-                )
-
-        # Opportunistically schedule prefetch for the next window to overlap with compute
-        try:
-            next_win_pre = self._next_local_layers(
-                (window_layers[-1] if window_layers else (activation_msg.layer_id)),
-                self.window_size,
-            )
-            for nl in next_win_pre:
-                self._prefetch_to_ram(nl)
-                self._enqueue_weight_prefetch(nl)
-        except Exception:
-            pass
+            next_win_pre = self._next_local_layers(
+                (window_layers[-1] if window_layers else (activation_msg.layer_id)),
+                self.window_size,
+            )
+            for nl in next_win_pre:
+                self._prefetch_to_ram(nl)
+                self._enqueue_weight_prefetch(nl)
+        except Exception:
+            pass

         # Execute the window
         with self.tracer.frame("compute.thread", "execute"):
diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py
index bb26c4d5..c2f10c60 100644
--- a/src/dnet/ring/shard/models.py
+++ b/src/dnet/ring/shard/models.py
@@ -109,3 +109,33 @@ class HealthResponse(BaseModel):
     grpc_port: int = Field(..., description="gRPC server port")
     http_port: int = Field(..., description="HTTP server port")
     instance: Optional[str] = Field(default=None, description="Shard name")
+
+
+# Tracer ingest
+
+class TraceEvent(BaseModel):
+    type: Literal["B", "E", "I"] = Field(..., description="Event type/phase")
+    name: str = Field(..., description="Span/mark name")
+    ts: int = Field(..., description="Timestamp in microseconds")
+    args: Dict[str, Any] = Field(default_factory=dict)
+    req_id: Optional[str] = None
+    pid: Optional[int] = None
+    tid: Optional[int] = None
+
+
+class TraceIngestBatch(BaseModel):
+    run_id: str = Field(..., description="Bench run identifier")
+    node_id: str = Field(..., description="Shard/service identity")
+    batch_seq: int = Field(..., description="Monotonic batch sequence per node")
+    events: List[TraceEvent] = Field(default_factory=list)
+    dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
+    max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
+    last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
+    schema_version: int = Field(default=1)
+
+
+class TraceIngestResponse(BaseModel):
+    ok: bool = True
+    accepted: int = 0
+    batch_seq: Optional[int] = None
+    message: Optional[str] = None
diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py
new file mode 100644
index 00000000..ccfb9f44
--- /dev/null
+++ b/src/dnet/ring/shard/startup.py
@@ -0,0 +1,563 @@
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, Dict, List, Mapping
+import threading
+from socket import gethostname
+from secrets import token_hex
+
+import mlx.core as mx
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from grpc import aio as aio_grpc
+
+from hypercorn import Config
+import hypercorn.asyncio as aio_hypercorn
+from dnet_p2p.thunderbolt import ThunderboltConnection
+from dnet_p2p import (
+    DnetDeviceProperties,
+    discover_thunderbolt_connection,
+)
+
+from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server
+from .servicer import ShardServicer
+from ...utils.logger import logger
+from ...utils.serialization import tensor_to_bytes
+from ...utils.latency import (
+    DeviceLatencyResult,
+    LatencyMeasurement,
+    LatencyResults,
+    calculate_median_latency_seconds,
+)
+from .models import (
+    HealthResponse,
+    ShardLoadModelRequest,
+    ShardLoadModelResponse,
+    ShardProfileRequest,
+    ShardProfileResponse,
+    ShardUnloadModelResponse,
+)
+from ...protos import dnet_ring_pb2
+
+
+class StartupMixin:
+    async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()):
+        self.running = True
+        try:  # Capture the main event loop for cross-thread scheduling
+            self._loop = asyncio.get_running_loop()
+        except Exception:
+            self._loop = None
+        await self._start_grpc_server()
+        await self._start_http_server(shutdown_trigger)
+        await asyncio.sleep(0.2)
+
+        self.background_tasks = [
+            asyncio.create_task(self._ingress_worker()),
+            asyncio.create_task(self._prefetch_worker()),
+            asyncio.create_task(self._send_worker()),
+        ]
+        # Start idle sweeper to close silent streams
+        try:
+            if getattr(self, "_streaming_enabled", False) and hasattr(
+                self, "_stream_sweeper"
+            ):
+                self.background_tasks.append(
+                    asyncio.create_task(self._stream_sweeper())
+                )
+        except Exception:
+            pass
+
+        self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True)
+        self.compute_thread.start()
+
+        self._start_discovery()
+        logger.info(
+            "Shard node %s started on gRPC port %s HTTP port %s",
+            self.node_id,
+            self.grpc_port,
+            self.http_port,
+        )
+
+    def _start_discovery(self) -> None:
+        """Start mDNS discovery service."""
+        hostname = gethostname()
+        # TODO: optionally take shard name from CLI
+        instance = f"shard-{token_hex(4)}-{hostname}"
+        self.discovery.create_instance(
+            instance,
+            hostname,
+            "0.0.0.0",  # Binds to all addresses
+            self.http_port,  # HTTP port
+            self.grpc_port,  # gRPC port
+            is_manager=False,  # Shard is never a manager
+        )
+        self.discovery.start()
+        logger.info(
+            "Discovery service started for shard node %s with name %s",
+            self.node_id,
+            self.discovery.fullname(),
+        )
+
+    async def _start_grpc_server(self) -> None:
+        """Start gRPC server."""
+        self.server = aio_grpc.server()
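+        # grpc.aio server instance; the ring servicer and an insecure (plaintext)
+        # port are attached below before the server is started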
+
+        # Add the ring servicer; shard acts as client for ShardApiService (to API)
+        servicer = ShardServicer(self)  # type: ignore # FIXME: !!!
+        add_DnetRingServiceServicer_to_server(servicer, self.server)
+
+        listen_addr = f"[::]:{self.grpc_port}"
+        self.server.add_insecure_port(listen_addr)
+        await self.server.start()
+        logger.info(
+            "Shard node %s gRPC server started on %s", self.node_id, listen_addr
+        )
+        try:
+            await asyncio.get_running_loop().run_in_executor(
+                self.executor, self._warmup_serialization
+            )
+            logger.info("Warmup serialization completed")
+        except Exception as e:
+            logger.warning("Warmup serialization failed: %s", e)
+
+    def _warmup_serialization(self):
+        try:
+            dummy = mx.random.normal((1024, 1024), dtype=mx.float32)
+            dummy16 = dummy.astype(self._wire_mx_dtype)
+            _ = tensor_to_bytes(dummy16)
+        except Exception:
+            pass
+
+    def _warmup_shard(self):
+        logger.info(
+            "[WARMUP] Starting shard warmup with window size %s", self.window_size
+        )
+        batch_size, seq_len = 1, 1
+        hidden_size = self.model_metadata.model_config.get("hidden_size", 2560)
+        x = mx.zeros((batch_size, seq_len, hidden_size), dtype=mx.bfloat16)
+        start_time = time.perf_counter()
+        try:
+            default_n = max(1, int(getattr(self, "_resident_windows", 1)))
+        except Exception:
+            default_n = 1
+        try:
+            max_windows = max(
+                1,
+                int(
+                    getattr(self, "config", None).warmup_windows
+                    if getattr(self, "config", None)
+                    else default_n
+                ),
+            )
+        except Exception:
+            max_windows = default_n
+        windows: list[list[int]] = []
+        for window_start in range(0, len(self._assigned_sorted), self.window_size):
+            window_end = min(
+                window_start + self.window_size, len(self._assigned_sorted)
+            )
+            windows.append(self._assigned_sorted[window_start:window_end])
+        for wi, window_layers in enumerate(windows[:max_windows]):
+            weights_to_bind = {}
+            for layer_id in window_layers:
+                weights = self.weight_cache.get_weight(layer_id)
+                if weights:
+                    for k, v in weights.items():
+                        weights_to_bind[k] = v
+            if weights_to_bind:
+                self.model.load_weights(list(weights_to_bind.items()), strict=False)
+            try:
+                for layer_id in window_layers:
+                    x = self.model.apply_single_layer(layer_id, x, cache=None)
+                _s = mx.sum(x)
+                mx.eval(_s)
+            except Exception:
+                pass
+            try:
+                for lid in window_layers:
+                    self.weight_cache.decrease_reference(lid)
+            except Exception:
+                pass
+            if not self._warmup_keep_flag:
+                try:
+                    if hasattr(self.model, "unload_layers"):
+                        self.model.unload_layers(window_layers)  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+                try:
+                    self.weight_cache.evict_layers(window_layers)
+                except Exception:
+                    pass
+        total_time = (time.perf_counter() - start_time) * 1000
+        self._warmup_completed = True
+        logger.info(
+            "[WARMUP] Shard warmup completed in %.2fms; windows=%s kept=%s",
+            total_time,
+            min(len(windows), max_windows),
+            int(self._warmup_keep_flag),
+        )
+
+    async def _start_http_server(self, shutdown_trigger: Any) -> None:
+        """Start HTTP server.
+
+        Args:
+            shutdown_trigger: Shutdown trigger function
+        """
+        await self._setup_routes()
+
+        # Start HTTP server in background
+        config = Config.from_mapping(
+            bind=f"0.0.0.0:{self.http_port}",
+            log_level="info",
+            log_config=None,
+            use_reloader=False,
+            h2c=False,
+        )
+
+        # Start the server as a background task
+        self.http_server = asyncio.create_task(
+            aio_hypercorn.serve(self.app, config, shutdown_trigger=shutdown_trigger)  # type: ignore
+        )
+        logger.info(
+            "Shard node %s HTTP server started on port %s", self.node_id, self.http_port
+        )
+
+    async def _setup_routes(self) -> None:
+        """Setup HTTP routes."""
+
+        @self.app.get("/health")
+        async def health() -> HealthResponse:
+            try:
+                instance = self.discovery.instance_name()
+            except Exception:
+                instance = None
+            return HealthResponse(
+                status="ok",
+                node_id=self.node_id,
+                running=self.running,
+                model_loaded=self._check_model_loaded(),
+                model_path=self.model_path,
+                assigned_layers=self.assigned_layers,
+                queue_size=self.activation_recv_queue.qsize(),
+                grpc_port=self.grpc_port,
+                http_port=self.http_port,
+                instance=instance,
+            )
+
+        @self.app.post("/profile")
+        async def profile(req: ShardProfileRequest) -> ShardProfileResponse:
+            logger.info("Received /profile request")
+            try:
+                # Measure latencies
+                latency_results = await self._measure_latency_to_devices(
+                    req.devices, req.thunderbolts, req.payload_sizes
+                )
+
+                # Profile device using dperf
+                device_profile = await self._profile_device(
+                    req.repo_id, req.max_batch_exp
+                )
+
+                # Overwrite `t_comm` with median latency (subprocess returns a dict)
+                median_latency = calculate_median_latency_seconds(latency_results)
+                if median_latency is not None:
+                    device_profile["t_comm"] = float(median_latency)
+                    logger.info(
+                        f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s"
+                    )
+                else:
+                    logger.warning(
+                        "No valid latency measurements, keeping default t_comm"
+                    )
+
+                # Return the dict payload directly
+                return ShardProfileResponse(
+                    profile=device_profile,
+                    latency=latency_results,
+                )
+            except Exception as e:
+                logger.error(f"Error in /profile endpoint: {e}")
+                raise
+
+        @self.app.post("/load_model")
+        async def load_model_endpoint(
+            req: ShardLoadModelRequest,
+        ) -> ShardLoadModelResponse:
+            """Load model with specified layers."""
+            try:
+                logger.info(
+                    f"HTTP /load_model: model={req.model_path}, layers={req.layers}, "
+                    f"next_node={req.next_node or 'none'}, window_size={req.window_size}, "
+                    f"total_layers={req.total_layers}, api_callback={req.api_callback_address or 'none'}"
+                )
+                result = await self.load_model(req)
+                return result
+
+            except Exception as e:
+                logger.error(f"Error in /load_model endpoint: {e}")
+                return ShardLoadModelResponse(
+                    success=False,
+                    message=f"Error: {str(e)}",
+                    layers_loaded=[],
+                    load_time_ms=0.0,
+                )
+
+        @self.app.post("/unload_model")
+        async def unload_model_endpoint() -> ShardUnloadModelResponse:
+            """Unload current model."""
+            try:
+                logger.info("HTTP /unload_model")
+                result = await self.unload_model()
+                return result
+
+            except Exception as e:
+                logger.error(f"Error in /unload_model endpoint: {e}")
+                return ShardUnloadModelResponse(
+                    success=False,
+                    message=f"Error: {str(e)}",
+                )
+
+        @self.app.post("/warm")
+        async def warm(request: Request) -> JSONResponse:
+            try:
+                body = await request.json()
+                start = int(body.get("start", -1))
+                window = int(body.get("window", self.window_size))
+                if start < 0:
+                    return JSONResponse(
+                        status_code=400, content={"error": "missing/invalid start"}
+                    )
+                start_idx = 0
+                for i, lyr in enumerate(self._assigned_sorted):
+                    if lyr >= start:
+                        start_idx = i
+                        break
+                else:
+                    return JSONResponse(content={"prefetched": []})
+                window_layers = self._assigned_sorted[
+                    start_idx : start_idx + max(1, window)
+                ]
+                for wl in window_layers:
+                    self._prefetch_to_ram(wl)
+                    self._enqueue_weight_prefetch(wl)
+                return JSONResponse(content={"prefetched": window_layers})
+            except Exception as e:
+                logger.error("/warm failed: %s", e)
+                return JSONResponse(status_code=500, content={"error": str(e)})
+
+    async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict:
+        """Profile device using dperf in a subprocess and return a dict.
+
+        Args:
+            repo_id: Hugging Face repository ID
+            max_batch_exp: Maximum batch size exponent (2^max_batch_exp)
+
+        Returns:
+            Device profile information as a plain dict
+        """
+        from ...utils.profile_subproc import profile_device_via_subprocess
+
+        profile_dict = profile_device_via_subprocess(
+            repo_id, max_batch_exp=max_batch_exp, debug=0
+        )
+        logger.info("Device profiling completed for node %s", self.node_id)
+        return profile_dict
+
+    async def _connect_next_node(self) -> bool:
+        """Connect to next node in ring.
+
+        Returns:
+            True if connected or no next node, False on failure
+        """
+        if not self.next_node:
+            logger.info(f"Shard node {self.node_id} is the final shard (no next node)")
+            return True
+
+        if self.next_node_channel:
+            logger.debug(f"Shard node {self.node_id} already connected to next node.")
+            return True
+
+        # Default address so the except-branch below can always log one
+        address = f"{self.next_node.local_ip}:{self.next_node.shard_port}"
+        try:
+            # use thunderbolt here if available
+            this_properties = self.discovery.get_own_properties()
+            thunderbolt_conn = discover_thunderbolt_connection(
+                this_properties,
+                self.next_node,
+            )
+            next_ip = (
+                thunderbolt_conn.ip_addr
+                if thunderbolt_conn
+                else self.next_node.local_ip
+            )
+            address = f"{next_ip}:{self.next_node.shard_port}"
+            logger.info(
+                f"Shard node {this_properties.instance} connecting to next node {self.next_node.instance} at {address}"
+            )
+
+            self.next_node_channel = aio_grpc.insecure_channel(address)
+            from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub
+
+            self.next_node_stub = DnetRingServiceStub(self.next_node_channel)
+            return True
+        except Exception as e:
+            logger.warning(
+                f"Shard node {self.node_id} failed to connect to next node {address}: {e}"
+            )
+            self.next_node_channel = None
+            self.next_node_stub = None
+            return False
+
+    async def _reconnect_next_node(self) -> bool:
+        try:
+            if self.next_node_channel:
+                await self.next_node_channel.close()
+        except Exception:
+            pass
+        self.next_node_channel = None
+        self.next_node_stub = None
+        return await self._connect_next_node()
+
+    async def _health_check(self):
+        try:
+            health_request = dnet_ring_pb2.HealthRequest(requester_id=str(self.node_id))
+            response = await self.next_node_stub.HealthCheck(health_request)  # type: ignore
+            logger.info(
+                "Shard node %s successfully pinged: %s, healthy: %s",
+                self.node_id,
+                response.node_id,
+                response.healthy,
+            )
+            return True
+        except Exception as e:
+            logger.warning(
+                "Shard node %s failed to ping next node %s: %s",
+                self.node_id,
+                self.next_node_address,
+                e,
+            )
+            return False
+
+    async def _measure_latency_to_devices(
+        self,
+        devices: Mapping[str, DnetDeviceProperties],
+        thunderbolts: Mapping[str, ThunderboltConnection],
+        payload_sizes: List[int],
+    ) -> LatencyResults:
+        """Measure latency to all devices except self.
+
+        Args:
+            devices: Device information mapping
+            thunderbolts: Thunderbolt connection information
+            payload_sizes: List of payload sizes to test
+
+        Returns:
+            Latency measurement results
+        """
+        latency_results_dict: Dict[str, DeviceLatencyResult] = {}
+
+        for service_name, device_info in devices.items():
+            # Skip measuring latency to ourselves
+            if service_name.startswith(self.discovery.instance_name()):
+                logger.debug("Skipping latency measurement to self: %s", service_name)
+                continue
+
+            # Skip measuring latency to API (manager) devices
+            if device_info.is_manager:
+                logger.debug(
+                    "Skipping latency measurement to manager/API: %s", service_name
+                )
+                continue
+
+            try:
+                shard_port = device_info.shard_port
+
+                # Check for Thunderbolt connection
+                if service_name in thunderbolts:
+                    tb_data = thunderbolts[service_name]
+                    service_ip = tb_data.ip_addr
+                    logger.info(
+                        "Using Thunderbolt for %s at %s, connected to instance %s",
+                        service_name,
+                        service_ip,
+                        tb_data.instance,
+                    )
+                else:
+                    # No Thunderbolt, use WiFi
+                    service_ip = device_info.local_ip
+
+                if not shard_port or not service_ip:
+                    logger.warning(
+                        "No shard_port or local_ip for device %s", service_name
+                    )
+                    continue
+
+                # Connect to target shard's gRPC server
+                target_address = f"{service_ip}:{shard_port}"
+                channel = aio_grpc.insecure_channel(target_address)
+                from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub
+
+                stub = DnetRingServiceStub(channel)
+
+                # Measure latency for each payload size
+                latency_measurements: List[LatencyMeasurement] = []
+                for payload_size in payload_sizes:
+                    # Create dummy payload
+                    dummy_data = b"x" * payload_size
+
+                    start_time = time.perf_counter()
+                    timestamp_ms = int(time.time() * 1000)
+
+                    request = dnet_ring_pb2.LatencyMeasureRequest(
+                        requester_id=str(self.node_id),
+                        payload_size=payload_size,
+                        dummy_data=dummy_data,
+                        timestamp=timestamp_ms,
+                    )
+
+                    response = await stub.MeasureLatency(request)  # type: ignore
+                    end_time = time.perf_counter()
+
+                    if response.success:
+                        latency_ms = (end_time - start_time) * 1000
+                        latency_measurements.append(
+                            LatencyMeasurement(
+                                payload_size=payload_size,
+                                latency_ms=round(latency_ms, 2),
+                                success=True,
+                                error=None,
+                            )
+                        )
+                    else:
+                        latency_measurements.append(
+                            LatencyMeasurement(
+                                payload_size=payload_size,
+                                success=False,
+                                error=response.message,
+                                latency_ms=0,
+                            )
+                        )
+
+                # Store results
+                result = DeviceLatencyResult(
+                    target_node_id=response.node_id if response.success else None,
+                    measurements=latency_measurements,
+                    success=True,
+                    error=None,
+                )
+                latency_results_dict[service_name] = result
+
+                # Close channel
+                await channel.close()
+
+            except Exception as e:
+                logger.error("Error measuring latency to %s: %s", service_name, e)
+                result = DeviceLatencyResult(
+                    target_node_id=None,
+                    success=False,
+                    error=str(e),
+                    measurements=[],
+                )
+                latency_results_dict[service_name] = result
+
+        return LatencyResults(results=latency_results_dict)

From 81eeb44ea95684aee1f36b9589896c41b9d1d94d Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 20 Oct 2025 03:16:43 -0700
Subject: [PATCH 091/172] Receive correct data format, dump to temp log file
 without REPL registered callback

---
 src/dnet/perf/trace.py         | 27 ++++++++++++++++++++-------
 src/dnet/ring/api/models.py    | 13 ++++++-------
 src/dnet/ring/api/node.py      |  6 +++---
 src/dnet/ring/shard/models.py  | 17 +++++++----------
 src/dnet/ring/shard/node.py    |  8 +++++---
 src/dnet/ring/shard/startup.py | 19 ++++++++-----------
 6 files changed, 49 insertions(+), 41 deletions(-)

diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
index 92119d79..78a92a65 100644
--- a/src/dnet/perf/trace.py
+++ b/src/dnet/perf/trace.py
@@ -31,7 +31,7 @@ class TraceConfig:
     record_pid_tid: bool = True
     aggregate: bool = False
     aggregate_url: Optional[str] = None
-    agg_max_events: int = 1000
+    agg_max_events: int = 300
 
 class _NoopFrame:
     def __enter__(self):
@@ -107,17 +107,21 @@ def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None:
         self._agg_thread = None
 
     def _agg_exec(self) -> None:
-        url = self.config.aggregate_url or ""
-        assert url != ""
+        assert self.config.aggregate_url != ""
+        url = "http://" + self.config.aggregate_url + "/trace/ingest"
         client = httpx.Client(timeout=5.0)
         try:
-            while self._agg_enabled and not self._agg_q.empty():
+            logger.debug(f"Aggregation worker thread {self._agg_enabled}, {self._agg_q.empty()}")
+            while self._agg_enabled or not self._agg_q.empty():
                 try:
                     batch = self._agg_q.get(timeout=0.2)
                 except queue.Empty:
                     continue
+                logger.info(f"Sending trace buffer to API : {url}")
                 try:
-                    client.post(url, json=batch)
+                    res = client.post(url, json=batch)
+                    if res.status_code != 200:
+                        logger.error(f"Aggregator POST failed {res.status_code}: {res.text}")
                 except Exception:
                     logger.warning(f"Unable to POST trace aggregation data to {url}")
                 finally:
                     self._agg_q.task_done()
@@ -128,6 +132,14 @@ def _agg_exec(self) -> None:
             except Exception:
                 logger.warning("Unable to close httpx client.")
 
+    # We don't have the API addr at init time
+    def update_api_addr(self, addr):
+        self.config.aggregate_url = addr
+        logger.debug(f"Updated API Address: {self.config.aggregate_url}")
+
+    def update_config(self, config):
+        pass
+
     def start(self, *, reset: bool = True) -> None:
         self._active = bool(self.config.enabled)
         if not self._active:
@@ -201,8 +213,9 @@ def _emit(self, ev: Dict[str, Any]) -> None:
         if self._agg_enabled:
             if len(self._events) < self._agg_max_events:
                 return
-            batch = { "run_id": self._agg_run_id,
-                      "node_id": self._agg_node_id or self.config.node_id,
+            logger.debug(f"Queuing tracer frame batch of {len(self._events)}")
+            batch = { "run_id": (self._req_id or "NONE"),
+                      "node_id": (self.config.node_id or "UNKNOWN_NODE"),
                       "events": list(self._events)}
             try:
                 self._agg_q.put_nowait(batch)
diff --git a/src/dnet/ring/api/models.py b/src/dnet/ring/api/models.py
index fc84162f..c9f4b45b 100644
--- a/src/dnet/ring/api/models.py
+++ b/src/dnet/ring/api/models.py
@@ -408,9 +408,9 @@ class UnloadModelResponse(BaseModel):
 # Tracer ingest
 
 class TraceEvent(BaseModel):
-    type: Literal["B", "E", "I"] = Field(..., description="Event type/phase")
+    type: str = Field(..., description="Event type/phase")
     name: str = Field(..., description="Span/mark name")
-    ts: int = Field(..., description="Timestamp in microseconds")
+    ts: float = Field(..., description="Timestamp in microseconds")
     args: Dict[str, Any] = Field(default_factory=dict)
     req_id: Optional[str] = None
     pid: Optional[int] = None
@@ -419,12 +419,11 @@ class TraceEvent(BaseModel):
 class TraceIngestBatch(BaseModel):
     run_id: str = Field(..., description="Bench run identifier")
     node_id: str = Field(..., description="Shard/service identity")
-    batch_seq: int = Field(..., description="Monotonic batch sequence per node")
     events: List[TraceEvent] = Field(default_factory=list)
-    dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
-    max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
-    last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
-    schema_version: int = Field(default=1)
+    #dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
+    #max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
+    #last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
+    #schema_version: int = Field(default=1)
 
 class TraceIngestResponse(BaseModel):
     ok: bool = True
     accepted: int = 0
     batch_seq: Optional[int] = None
     message: Optional[str] = None
diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index d1e07f8e..dfb47858 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -376,14 +376,15 @@ async def completions(req: CompletionRequestModel):  # type: ignore
         # Ingest trace buffers and forward to REPL
         @self.app.post("/trace/ingest")
         async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse:  # type: ignore
-            logger.debug(f"Received trace buffer.")
             try:
                 if self._trace_ingest_cb is not None:
+                    logger.debug("Forwarding trace batch to REPL.")
                     self._trace_ingest_cb(batch.model_dump())
-                    return TraceIngestResponse(ok=True, accepted=len(batch.events), batch_seq=batch.batch_seq)
+                    return TraceIngestResponse(ok=True, accepted=len(batch.events))
 
                 try:
                     run_dir = Path("logs/trace/ingest") / batch.run_id
+                    logger.debug(f"callback not available. Dumping trace buffer to file {run_dir}")
                     run_dir.mkdir(parents=True, exist_ok=True)
                     fpath = run_dir / f"{batch.node_id}.jsonl"
                     with fpath.open("a", encoding="utf-8") as f:
@@ -393,7 +394,6 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse:  # type:
                     return TraceIngestResponse(
                         ok=True,
                         accepted=len(batch.events),
-                        batch_seq=batch.batch_seq,
                         message="no aggregator; appended"
                     )
                 except Exception as e:
diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py
index c2f10c60..a5d495ff 100644
--- a/src/dnet/ring/shard/models.py
+++ b/src/dnet/ring/shard/models.py
@@ -53,6 +53,7 @@ class ShardUnloadModelResponse(BaseModel):
 
 class ShardProfileRequest(BaseModel):
     """Request to profile device and measure latencies."""
 
+    api_address: Optional[str] = Field( ..., description="API Address" )
     devices: Dict[str, DnetDeviceProperties] = Field(
         ..., description="Device information mapping"
     )
@@ -114,28 +115,24 @@ class HealthResponse(BaseModel):
 
 # Tracer ingest
 
 class TraceEvent(BaseModel):
-    type: Literal["B", "E", "I"] = Field(..., description="Event type/phase")
+    type: str = Field(..., description="Event type/phase")
     name: str = Field(..., description="Span/mark name")
-    ts: int = Field(..., description="Timestamp in microseconds")
+    ts: float = Field(..., description="Timestamp in microseconds")
     args: Dict[str, Any] = Field(default_factory=dict)
     req_id: Optional[str] = None
     pid: Optional[int] = None
     tid: Optional[int] = None
 
-
 class TraceIngestBatch(BaseModel):
     run_id: str = Field(..., description="Bench run identifier")
     node_id: str = Field(..., description="Shard/service identity")
-    batch_seq: int = Field(..., description="Monotonic batch sequence per node")
     events: List[TraceEvent] = Field(default_factory=list)
-    dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
-    max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
-    last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
-    schema_version: int = Field(default=1)
-
+    #dropped: Optional[int] = Field(default=0, description="Events dropped on sender")
+    #max_ts: Optional[int] = Field(default=None, description="Max ts_us in this batch")
+    #last: Optional[bool] = Field(default=False, description="Sender indicates end-of-run")
+    #schema_version: int = Field(default=1)
 
 class TraceIngestResponse(BaseModel):
     ok: bool = True
     accepted: int = 0
-    batch_seq: Optional[int] = None
     message: Optional[str] = None
diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index ffb34ac3..3583fea1 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -63,7 +63,7 @@
 from .comms import CommsMixin
 from ..weight_cache import WeightCache
 
-from dnet.perf.trace import TraceConfig, Tracer
+from dnet.perf import TraceConfig, Tracer
 
 class RingShardNode(ComputeMixin, PrefetchMixin, CommsMixin):
@@ -206,12 +206,14 @@ def __init__(
         # Debug tracing
         cfg = TraceConfig(
             file="./trace.json",
-            streaming=True,
+            streaming=False,
             include_prefixes = ("src/dnet/"),
             include_c_calls = False,
             budget = 10000,
             enabled = True,
             record_pid_tid = True,
+            aggregate=False,
+            aggregate_url=None, # FIXME: This is set when we get a /profile req
         )
         self.tracer = Tracer(cfg)
         self.tracer.start()
@@ -234,7 +236,7 @@ def __init__(
         )
 
     async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse:
-        """Load model with specified layers. """
+        """Load model with specified layers"""
         try:
             # Check if already loaded with same configuration
             if (
                 self.model is not None
diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py
index ccfb9f44..7aaf8e73 100644
--- a/src/dnet/ring/shard/startup.py
+++ b/src/dnet/ring/shard/startup.py
@@ -249,15 +249,14 @@ async def health() -> HealthResponse:
         async def profile(req: ShardProfileRequest) -> ShardProfileResponse:
             logger.info("Received /profile request")
             try:
-                # Measure latencies
-                latency_results = await self._measure_latency_to_devices(
-                    req.devices, req.thunderbolts, req.payload_sizes
-                )
+                # Since this is the first request we get from API grab the address and store it
+                # TODO: Have a handshake request before this one where we share addresses and state
+                self.api_address = req.api_address
+                self.tracer.update_api_addr(self.api_address)
+                self.tracer.start_aggregator()
 
-                # Profile device using dperf
-                device_profile = await self._profile_device(
-                    req.repo_id, req.max_batch_exp
-                )
+                latency_results = await self._measure_latency_to_devices( req.devices, req.thunderbolts, req.payload_sizes)
+                device_profile = await self._profile_device( req.repo_id, req.max_batch_exp)
 
                 # Overwrite `t_comm` with median latency (subprocess returns a dict)
                 median_latency = calculate_median_latency_seconds(latency_results)
@@ -267,9 +266,7 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse:
                         f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s"
                     )
                 else:
-                    logger.warning(
-                        "No valid latency measurements, keeping default t_comm"
-                    )
+                    logger.warning( "No valid latency measurements, keeping default t_comm")
 
                 # Return the dict payload directly
                 return ShardProfileResponse(

From 47e7860b6ed9900ef4baa265322843cdfbf09dd4 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 20 Oct 2025 04:08:11 -0700
Subject: [PATCH 092/172] register repl callback

---
 src/dnet/ring/api/node.py | 1 +
 src/repl.py               | 9 ++++++++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py
index dfb47858..a132a5b8 100644
--- a/src/dnet/ring/api/node.py
+++ b/src/dnet/ring/api/node.py
@@ -1733,4 +1733,5 @@ async def shutdown(self) -> None:
 
     # REPL helper to install a trace ingestion callback
     def set_trace_ingest_callback(self, cb: Optional[Callable[[Dict[str, Any]], None]]) -> None:
+        logger.debug(f"Registered tracer ingest callback.")
ingest callback.") self._trace_ingest_cb = cb diff --git a/src/repl.py b/src/repl.py index 0b95a19f..f6483c05 100644 --- a/src/repl.py +++ b/src/repl.py @@ -120,6 +120,7 @@ def do_api(self, cmd: List[str]) -> None: http_port or self.state.api_http_port, grpc_port or self.state.api_grpc_port ) + self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0) elif cmd[1] == "stop": self.stop_api() elif cmd[1] == "status": @@ -331,6 +332,11 @@ def handle_start_worker(self): # ===== Handle API server + # Tracer frames ingest callback + def __trace_cb(self, data): + dprint(str(data)) + pass + async def _api_main(self) -> None: # main thread loop self._api_loop = asyncio.get_running_loop() self._api_shutdown_e = asyncio.Event() @@ -374,7 +380,7 @@ def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10): if not self._api_ready.wait(timeout): raise RuntimeError("API Server Timeout.") if self._api_exc is not None: - raise RuntimeError(f"API Server failed to start: {e}") + raise RuntimeError(f"API Server failed to start") def stop_api(self, timeout: float = 5.0) -> None: if not self._api_thread: return @@ -459,6 +465,7 @@ def _print_nodes_table(self, rows: List[Any]) -> None: "head": "head" if r["head"] else "no", } sys.stdout.write(" ".join(str(vals[h]).ljust(w[h]) for h in headers)) + sys.stdout.write("\n") sys.stdout.write("\n\n") sys.stdout.flush() From b74bb278897207197d1260d598cb5acde5f6f8eb Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 20 Oct 2025 14:09:09 -0700 Subject: [PATCH 093/172] add llama3 model script --- src/dnet/ring/model/__init__.py | 5 + src/dnet/ring/model/llama3.py | 181 ++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 src/dnet/ring/model/llama3.py diff --git a/src/dnet/ring/model/__init__.py b/src/dnet/ring/model/__init__.py index 4f432a7e..57ccdc32 100644 --- a/src/dnet/ring/model/__init__.py +++ b/src/dnet/ring/model/__init__.py @@ -8,6 +8,11 @@ from .llama import LlamaRingModel from .gpt_oss import GptOssRingModel from .qwen3 import Qwen3RingModel +from .llama3 import Llama3RingModel +from .llama4 import Llama4RingModel +from .gpt_oss import GptOssRingModel +from .glm import GLMRingModel +from .glm4 import GLM4RingModel def get_ring_model( diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py new file mode 100644 index 00000000..0583ac9c --- /dev/null +++ b/src/dnet/ring/model/llama3.py @@ -0,0 +1,181 @@ +from typing import Any, Dict, List, Optional, Tuple + +import mlx.nn as nn +import mlx.core as mx +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.llama import ModelArgs, TransformerBlock + +from .base import BaseRingModel + +import logging +logger = logging.getLogger(__name__) + + +class Llama3RingModel(BaseRingModel): + model_type = "llama" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = [], + is_api_layer: bool = False + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError(f"API Service doesn't handle layers") + + self.config = ModelArgs.from_dict(model_config) + self.is_api_layer = is_api_layer + + self._converted_to_quantized = False + self.runtime_cache: Optional[List[Any]] = None + + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.norm = nn.RMSNorm(self.config.hidden_size, self.config.rms_norm_eps) + self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + self.layers: List[nn.Module] = [] + 
+        self.abs_to_local: Dict[int, int] = {}
+
+        for i, l in enumerate(sorted(assigned_layers or [])):
+            self.layers.append(TransformerBlock(self.config))
+            self.abs_to_local[l] = i
+        logger.debug(f"Created {len(self.layers)} Transformer layers")
+        logger.debug(f"abs_to_local mapping: {self.abs_to_local}")
+
+    @property
+    def decoding_layers(self):
+        return self.layers
+
+    @property
+    def head_dim(self) -> Tuple[int, int]:
+        return self.config.head_dim
+
+    @property
+    def n_kv_heads(self) -> int:
+        return self.config.num_key_value_heads
+
+    @property
+    def num_layers(self) -> int:
+        return len(self.layers)
+
+    def set_runtime_cache(self, cache: Optional[List[Any]]) -> None:
+        self._runtime_cache = cache
+
+    def class_predicate(p, m):
+        return hasattr(m, "to_quantized")
+
+    def embed(self, x: mx.array):
+        return self.embed_tokens(x) if self.is_api_layer else x
+
+    def normalize(self, x: mx.array):
+        return self.norm(x) if self.is_api_layer else x
+
+    def lm_project(self, x: mx.array):
+        return self.lm_head(x) if self.is_api_layer else x
+
+    def quantize_layers(self):
+        self.quantization = None
+        if hasattr(self.config, "quantization"):
+            self.quantization = getattr(self.config, "quantization")
+        elif hasattr(self.config, "quantization_config"):
+            self.quantization = getattr(self.config, "quantization_config")
+
+        if self.quantization is not None:
+            bits = int(self.quantization.get("bits", 8))
+            group = int(self.quantization.get("group_size", 64))
+            if self.is_api_layer:
+                try:
+                    from mlx.nn.layers.quantized import QuantizedEmbedding
+                    self.embed_tokens = QuantizedEmbedding(self.config.vocab_size,
+                                                           self.config.hidden_size,
+                                                           group_size=group, bits=bits)
+
+                    logger.debug(f"API Service initialized to QuantizedEmbedding:"
+                                 f"{self.config.vocab_size}, hidden={self.config.hidden_size}"
+                                 f"group_size={group}, bits={bits}")
+                except Exception as e:
+                    logger.warning(f"Unable to initialize QuantizedEmbedding: {e}")
+
+            else:
+                try:
+                    nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate)
+                    logger.debug(f"Quantized the model: bits={bits}, group_size={group}")
+                    self._converted_to_quantized = True
+                except:
+                    self._converted_to_quantized = False
+
+    def forward(
+        self,
+        x: mx.array,
+        cache: Optional[List[Any]] = None
+    ):
+        mask = create_attention_mask(x, cache)
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for i, l in enumerate(self.layers):
+            x = l(x, mask, cache[i] if i < len(cache) else None)
+
+        return x
+
+    # TODO: Original implementation is sliding window. Bench to see if it's faster or just do sparse
+    def apply_single_layer(
+        self,
+        layer_idx: int,
+        x: mx.array,
+        cache: Optional[List[Any]] = None
+    ):
+        if layer_idx not in self.abs_to_local:
+            raise RuntimeError(f"Attempted execution of foreign layer {layer_idx}")
+        mask = create_attention_mask(x, cache)
+        local_idx = self.abs_to_local[layer_idx]
+        logger.debug(f"apply_single_layer: layer:{layer_idx}, local_idx:{local_idx}, input_shape:{x.shape}")
+
+        layer = self.layers[local_idx]
+        # Guard against cache=None (e.g. warmup calls) before indexing into it
+        layer_cache = cache[local_idx] if cache is not None and local_idx < len(cache) else None
+        ret = layer(x, mask, layer_cache)
+        logger.debug(f"Executed layer:{layer_idx} with output shape: {ret.shape}")
+        return ret
+
+    def load_weights(self, weights, strict=False):
+        weight_keys = [k for k, _ in weights]
+        has_scales = any(".scales" in k for k in weight_keys)
+        has_biases = any(".biases" in k for k in weight_keys)
+
+        if has_scales and has_biases:
+            if not self._converted_to_quantized:
+                self.quantize_layers()
+
+        shard_weights = {}
+        for k, v in weights:
+            if k.startswith("model.layers.") or k.startswith("layers."):
+                p = k.split(".")
+                idx_pos = 2 if p[0] == "model" else 1
+                try:
+                    idx = int(p[idx_pos])
+                except Exception as e:
+                    logger.warning(f"Unable to read weight positions: {e}")
+                    continue
+                if idx not in self.abs_to_local:
+                    continue
+                local_idx = self.abs_to_local[idx]
+                p[idx_pos] = str(local_idx)
+                if p[0] == "model":
+                    p = p[1:]
+                new_key = ".".join(p)
+                logger.debug(f"Mapping weight {k} -> {new_key}")
+                shard_weights[new_key] = v
+
+            elif self.is_api_layer:
+                if (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")):
+                    shard_weights[k] = v
+
+        if shard_weights:
+            try:
+                super().load_weights(list(shard_weights.items()), strict=strict)
+                logger.debug(f"Loaded {len(shard_weights.keys())} weights into model")
+            except Exception as e:
+                logger.error(f"Failed to load weights: {e}")
+                logger.error(f"Weight keys: {list(shard_weights.keys())}")
+                raise

From e6242d844c66b15eb79255749bf0be9d21cdcd0a Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 20 Oct 2025 14:10:32 -0700
Subject: [PATCH 094/172] comment out unavailable models

---
 src/dnet/ring/model/__init__.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/dnet/ring/model/__init__.py b/src/dnet/ring/model/__init__.py
index 57ccdc32..93d31dd8 100644
--- a/src/dnet/ring/model/__init__.py
+++ b/src/dnet/ring/model/__init__.py
@@ -9,10 +9,10 @@
 from .gpt_oss import GptOssRingModel
 from .qwen3 import Qwen3RingModel
 from .llama3 import Llama3RingModel
-from .llama4 import Llama4RingModel
-from .gpt_oss import GptOssRingModel
-from .glm import GLMRingModel
-from .glm4 import GLM4RingModel
+#from .llama4 import Llama4RingModel
+#from .gpt_oss import GptOssRingModel
+#from .glm import GLMRingModel
+#from .glm4 import GLM4RingModel
 
 def get_ring_model(

From cbf479d6bf78fd65fce0582663a49b33caf10fea Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 20 Oct 2025 15:53:07 -0700
Subject: [PATCH 095/172] embed correctly

---
 src/dnet/ring/model/llama3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py
index 0583ac9c..cdec5669 100644
--- a/src/dnet/ring/model/llama3.py
+++ b/src/dnet/ring/model/llama3.py
@@ -67,7 +67,7 @@ def class_predicate(p, m):
         return hasattr(m, "to_quantized")
 
     def embed(self, x: mx.array):
-        return self.embed_tokens(x) if self.is_api_layer else x
+        return self.embed_tokens(x)
 
     def normalize(self, x: mx.array):
         return self.norm(x) if self.is_api_layer else x

From 104365519011aec74f31d797d3da0d75d6bbbf3c Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 20 Oct 2025 16:09:09 -0700
Subject: [PATCH 096/172] don't filter weights based on is_api_layer

---
 src/dnet/ring/model/llama3.py | 45 ++++++++++++++++++-------------------
 1 file changed, 21 insertions(+), 24 deletions(-)

diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py
index cdec5669..231f02b5 100644
--- a/src/dnet/ring/model/llama3.py
+++ b/src/dnet/ring/model/llama3.py
@@ -70,10 +70,10 @@ def embed(self, x: mx.array):
         return self.embed_tokens(x)
 
     def normalize(self, x: mx.array):
-        return self.norm(x) if self.is_api_layer else x
+        return self.norm(x)
 
     def lm_project(self, x: mx.array):
-        return self.lm_head(x) if self.is_api_layer else x
+        return self.lm_head(x)
 
     def quantize_layers(self):
         self.quantization = None
@@ -85,26 +85,24 @@ def quantize_layers(self):
         if self.quantization is not None:
             bits = int(self.quantization.get("bits", 8))
             group = int(self.quantization.get("group_size", 64))
-            if self.is_api_layer:
-                try:
-                    from mlx.nn.layers.quantized import QuantizedEmbedding
-                    self.embed_tokens = QuantizedEmbedding(self.config.vocab_size,
-                                                           self.config.hidden_size,
-                                                           group_size=group, bits=bits)
-
-                    logger.debug(f"API Service initialized to QuantizedEmbedding:"
-                                 f"{self.config.vocab_size}, hidden={self.config.hidden_size}"
-                                 f"group_size={group}, bits={bits}")
-                except Exception as e:
-                    logger.warning(f"Unable to initialize QuantizedEmbedding: {e}")
+            try:
+                from mlx.nn.layers.quantized import QuantizedEmbedding
+                self.embed_tokens = QuantizedEmbedding(self.config.vocab_size,
+                                                       self.config.hidden_size,
+                                                       group_size=group, bits=bits)
+
+                logger.debug(f"API Service initialized to QuantizedEmbedding:"
+                             f"{self.config.vocab_size}, hidden={self.config.hidden_size}"
+                             f"group_size={group}, bits={bits}")
+            except Exception as e:
+                logger.warning(f"Unable to initialize QuantizedEmbedding: {e}")
 
-            else:
-                try:
-                    nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate)
-                    logger.debug(f"Quantized the model: bits={bits}, group_size={group}")
-                    self._converted_to_quantized = True
-                except:
-                    self._converted_to_quantized = False
+            try:
+                nn.quantize(self, bits=bits, group_size=group, class_predicate=Llama3RingModel.class_predicate)
+                logger.debug(f"Quantized the model: bits={bits}, group_size={group}")
+                self._converted_to_quantized = True
+            except:
+                self._converted_to_quantized = False
 
     def forward(
         self,
@@ -167,9 +165,8 @@ def load_weights(self, weights, strict=False):
                 logger.debug(f"Mapping weight {k} -> {new_key}")
                 shard_weights[new_key] = v
 
-            elif self.is_api_layer:
-                if (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")):
-                    shard_weights[k] = v
+            elif (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")):
+                shard_weights[k] = v
 
         if shard_weights:
             try:

From 670deede2bf6d3382d8ab3c55589f88ee3298fa5 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Tue, 21 Oct 2025 03:23:41 -0700
Subject: [PATCH 097/172] tracer config request, aggregate and separate api
 logger for less repl noise

---
 src/dnet/perf/trace.py            |   9 +-
 src/dnet/perf/utils/__init__.py   |   3 +-
 src/dnet/perf/utils/aggregator.py |   1 +
 src/dnet/ring/api/api_logging.py  |  51 +++++++++
 src/dnet/ring/api/node.py         |  58 ++++++++++-
 src/dnet/ring/api/servicer.py     |   3 +-
 src/dnet/ring/api/utils.py        |   3 +-
 src/dnet/ring/shard/models.py     |  20 +++-
 src/dnet/ring/shard/startup.py    |  37 +++++--
 src/repl.py                       | 167 ++++++++++++++++++++++++++----
 10 files changed, 310 insertions(+), 42 deletions(-)
 create mode 100644 src/dnet/ring/api/api_logging.py

diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
index 78a92a65..e4a92cfb 100644
--- a/src/dnet/perf/trace.py
+++ b/src/dnet/perf/trace.py
@@ -108,7 +108,6 @@ def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None:
 
     def _agg_exec(self) -> None:
         assert self.config.aggregate_url != ""
-        url = "http://" + self.config.aggregate_url + "/trace/ingest"
         client = httpx.Client(timeout=5.0)
         try:
             logger.debug(f"Aggregation worker thread {self._agg_enabled}, {self._agg_q.empty()}")
@@ -117,13 +116,13 @@ def _agg_exec(self) -> None:
                     batch = self._agg_q.get(timeout=0.2)
                 except queue.Empty:
                     continue
-                logger.info(f"Sending trace buffer to API : {url}")
+                logger.info(f"Sending trace buffer to API : {self.config.aggregate_url}")
                 try:
-                    res = client.post(url, json=batch)
+                    res = client.post(self.config.aggregate_url, json=batch)
                     if res.status_code != 200:
                         logger.error(f"Aggregator POST failed {res.status_code}: {res.text}")
-                except Exception:
-                    logger.warning(f"Unable to POST trace aggregation data to {url}")
+                except Exception as e:
+                    logger.warning(f"Unable to POST trace aggregation data to {self.config.aggregate_url}: {e}")
                 finally:
                     self._agg_q.task_done()
             finally:
diff --git a/src/dnet/perf/utils/__init__.py b/src/dnet/perf/utils/__init__.py
index 9c310fd5..7228627d 100644
--- a/src/dnet/perf/utils/__init__.py
+++ b/src/dnet/perf/utils/__init__.py
@@ -1,2 +1 @@
-
-from aggregator import TraceAggregator
+from .aggregator import TraceAggregator
diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py
index f9bbdef6..2736a25b 100644
--- a/src/dnet/perf/utils/aggregator.py
+++ b/src/dnet/perf/utils/aggregator.py
@@ -104,6 +104,7 @@ def __init__(self) -> None:
     def enqueue(self, batch: Dict[str, Any]) -> None:
         run_id = batch.get("run_id")
         node_id = batch.get("node_id")
+        logger.debug(f"Enqueuing trace buffer from {run_id}, {node_id}")
         if not run_id or not node_id:
             return
         events = batch.get("events") or []
diff --git a/src/dnet/ring/api/api_logging.py b/src/dnet/ring/api/api_logging.py
new file mode 100644
index 00000000..d90c526c
--- /dev/null
+++ b/src/dnet/ring/api/api_logging.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import os
+import logging
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+
+
+_CONFIGURED_FLAG = "_dnet_api_logger_configured"
+
+
+def get_api_logger() -> logging.Logger:
+    """Return a process‑local logger for the API server.
+
+    - Does not propagate to the root logger (so it won't spam the REPL TTY).
+    - Writes to logs/api.log with rotation.
+    - Level is controlled by DNET_API_LOG (default: INFO).
+ """ + log = logging.getLogger("dnet.api") + if getattr(log, _CONFIGURED_FLAG, False): + return log + + # Level from env, fallback INFO + level_name = (os.getenv("DNET_API_LOG", "INFO") or "INFO").strip().upper() + level = getattr(logging, level_name, logging.INFO) + log.setLevel(level) + + # Do not bubble to root (console) + log.propagate = False + + # Ensure logs directory + try: + Path("logs").mkdir(parents=True, exist_ok=True) + except Exception: + pass + + # Attach a rotating file handler + try: + fh = RotatingFileHandler("logs/api.log", maxBytes=10_000_000, backupCount=5) + fmt = logging.Formatter( + "%(asctime)s %(levelname)s [%(threadName)s] %(name)s: %(message)s" + ) + fh.setFormatter(fmt) + log.addHandler(fh) + except Exception: + # As a last resort, attach a NullHandler to avoid 'No handler' warnings + log.addHandler(logging.NullHandler()) + + setattr(log, _CONFIGURED_FLAG, True) + return log + diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index a132a5b8..18cf4059 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -39,6 +39,7 @@ add_ShardApiServiceServicer_to_server, ) +from .api_logging import get_api_logger from ...utils.logger import logger from ...utils.banner import print_startup_banner from ...utils.latency import LatencyResults, calculate_median_latency_seconds @@ -81,6 +82,8 @@ ShardProfileResponse, TraceIngestBatch, TraceIngestResponse, + TraceConfigRequest, + TraceConfigResponse, ) from ..data_types import StopCondition from .servicer import ShardApiServicer @@ -104,6 +107,9 @@ async def azip(*async_iterables): break +logger = get_api_logger() + + class RingApiNode: """API node for distributed inference ring with dynamic topology.""" @@ -158,6 +164,24 @@ async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: shutdown_trigger: Shutdown trigger function """ self.running = True + # Reduce third‑party library noise in this process (keeps REPL TTY clean) + try: + import logging as _logging + for name in ( + "grpc", + "grpc._cython", + "asyncio", + "hpack", + "h2", + "hypercorn", + "hypercorn.error", + "hypercorn.access", + ): + lg = _logging.getLogger(name) + lg.setLevel(_logging.CRITICAL) + lg.propagate = False + except Exception: + pass await self._start_grpc_server() await self._start_discovery() @@ -206,7 +230,7 @@ async def _start_http_server(self, shutdown_trigger: Any) -> None: config = Config.from_mapping( bind=f"0.0.0.0:{self.http_port}", - log_level="info", + log_level="error", # keep HTTP server quiet on console log_config=None, use_reloader=False, h2c=True, @@ -400,6 +424,38 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: logger.warning(f"Unable to ingest trace buffer: {e}") return TraceIngestResponse(ok=False, accepted=0, message=str(e)) + async def _forward_trace_config(self, cfg: Any) -> bool: + shards = self._get_shards_from_discovery() + this = self.discovery.get_own_properties() + api_endpoint = f"http://{this.local_ip}:{this.server_port}/trace/ingest" + payload = TraceConfigRequest( + file=cfg.file, + streaming=cfg.streaming, + include_prefixes=cfg.include_prefixes, + include_c_calls=cfg.include_c_calls, + budget=cfg.budget, + enabled=cfg.enabled, + node_id=None, + record_pid_tid=cfg.record_pid_tid, + aggregate=cfg.aggregate, + aggregate_url=api_endpoint, + agg_max_events=cfg.agg_max_events + ) + + async with httpx.AsyncClient(timeout=5.0) as client: + for name, props in shards.items(): + url = f"http://{props.local_ip}:{props.server_port}/trace" + 
logger.debug(f"Forwarding trace config to {url}") + payload.node_id = name + try: + res = await client.post(url, json=dict(payload)) + if res.status_code != 200: + logger.warning(f"Failed to POST tracer config to node {name}.") + except Exception as e: + logger.warning(f"Failed to POST tracer config: {e}") + return False + return True + async def _handle_prepare_topology( self, req: PrepareTopologyRequest diff --git a/src/dnet/ring/api/servicer.py b/src/dnet/ring/api/servicer.py index 31b53d35..a1024822 100644 --- a/src/dnet/ring/api/servicer.py +++ b/src/dnet/ring/api/servicer.py @@ -6,7 +6,8 @@ from ...protos import shard_api_comm_pb2 as pb2 from ...protos import shard_api_comm_pb2_grpc as pb2_grpc -from ...utils.logger import logger +from .api_logging import get_api_logger +logger = get_api_logger() if TYPE_CHECKING: pass diff --git a/src/dnet/ring/api/utils.py b/src/dnet/ring/api/utils.py index d59387c1..002a3ea9 100644 --- a/src/dnet/ring/api/utils.py +++ b/src/dnet/ring/api/utils.py @@ -16,7 +16,8 @@ from .models import ChatParams - +from .api_logging import get_api_logger +logger = get_api_logger() def create_generate_step_for_ring_with_grpc( stub: DnetRingServiceStub, diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index a5d495ff..6a5efcc5 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -1,6 +1,6 @@ """Shard models for dnet ring topology endpoints.""" -from typing import Any, Dict, List, Optional, Literal +from typing import Any, Dict, List, Optional, Literal, Tuple from pydantic import BaseModel, Field from dnet_p2p import DnetDeviceProperties, ThunderboltConnection @@ -112,8 +112,24 @@ class HealthResponse(BaseModel): instance: Optional[str] = Field(default=None, description="Shard name") -# Tracer ingest +# Tracer +class TraceConfigRequest(BaseModel): + file: str = Field(..., description="File for trace streaming") + streaming: bool = Field(..., description="Toggle for trace streaming to file") + include_prefixes: List[str] = Field(default=("src/dnet/"), description="") + include_c_calls: bool = Field(default=False, description="") + budget: int = Field(default=0, description="Max amount of traces in memory") + enabled: bool = Field(default=False, description="Start or stop tracing") + node_id: Optional[str] = Field(default="NONE", description="") + record_pid_tid: bool = Field(default=True, descriptino="") + aggregate: bool = Field(default=True, description="") + aggregate_url: Optional[str] = Field(default=None, description="") + agg_max_events: int = Field(default=300, description="Threshold for sending frames to API") + +class TraceConfigResponse(BaseModel): + ok: bool = True + class TraceEvent(BaseModel): type: str = Field(..., description="Event type/phase") name: str = Field(..., description="Span/mark name") diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py index 7aaf8e73..6790be64 100644 --- a/src/dnet/ring/shard/startup.py +++ b/src/dnet/ring/shard/startup.py @@ -20,6 +20,8 @@ discover_thunderbolt_connection, ) +from dnet.perf.trace import TraceConfig + from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server from .servicer import ShardServicer from ...utils.logger import logger @@ -37,6 +39,8 @@ ShardProfileRequest, ShardProfileResponse, ShardUnloadModelResponse, + TraceConfigRequest, + TraceConfigResponse, ) from ...protos import dnet_ring_pb2 @@ -247,14 +251,7 @@ async def health() -> HealthResponse: @self.app.post("/profile") async def profile(req: 
-            logger.info("Received /profile request")
             try:
-                # Since this is the first request we get from API grab the address and store it
-                # TODO: Have a handshake request before this one where we share addresses and state
-                self.api_address = req.api_address
-                self.tracer.update_api_addr(self.api_address)
-                self.tracer.start_aggregator()
-
                 latency_results = await self._measure_latency_to_devices( req.devices, req.thunderbolts, req.payload_sizes)
                 device_profile = await self._profile_device( req.repo_id, req.max_batch_exp)
@@ -277,6 +274,32 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse:
                 logger.error(f"Error in /profile endpoint: {e}")
                 raise
 
+        @self.app.post("/trace")
+        async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse:
+            try:
+                cfg = TraceConfig(
+                    file=req.file,
+                    streaming=req.streaming,
+                    include_prefixes=req.include_prefixes,
+                    include_c_calls=req.include_c_calls,
+                    budget=req.budget,
+                    enabled=req.enabled,
+                    node_id=req.node_id,
+                    record_pid_tid=req.record_pid_tid,
+                    aggregate=req.aggregate,
+                    aggregate_url=req.aggregate_url,
+                    agg_max_events=req.agg_max_events,
+                )
+                self.tracer.config = cfg
+                logger.info("Updated tracer config.")
+                self.api_address = cfg.aggregate_url
+                self.tracer.start_aggregator()
+                logger.debug(cfg)
+                return TraceConfigResponse(ok=True)
+            except Exception as e:
+                logger.error(f"Unable to setup tracing on shard: {e}")
+                return TraceConfigResponse(ok=False)
+
         @self.app.post("/load_model")
         async def load_model_endpoint(
             req: ShardLoadModelRequest,
diff --git a/src/repl.py b/src/repl.py
index f6483c05..592424a6 100644
--- a/src/repl.py
+++ b/src/repl.py
@@ -1,11 +1,13 @@
 import os
 import sys
+import logging
 import cmd
+import time
 import argparse
 import subprocess
 
 from dataclasses import dataclass
-from typing import Optional, List, Any
+from typing import Optional, List, Any, Dict
 import asyncio
 import inspect
@@ -15,7 +17,8 @@
 from dnet.ring.api import RingApiNode
 from dnet.ring.shard import RingShardNode
 from dnet.utils.network import NodeAddress
-from dnet.utils.logger import logger
+#from dnet.utils.logger import logger
+from dnet.ring.api.api_logging import get_api_logger
 from dnet.utils.model import (
     ModelMetadata,
     get_model_metadata,
@@ -23,6 +26,12 @@
     get_safetensor_details,
 )
 
+logger = get_api_logger()
+
+from dnet.perf.trace import TraceConfig, Tracer
+from dnet.perf.utils import TraceAggregator
+#from dnet.perf.bench import
+
 # Handle restricted repos
 from importlib import import_module
 import huggingface_hub as hb
@@ -57,22 +66,39 @@ class REPL(cmd.Cmd):
 
     def __init__(self, model="NULL", nodes=1):
         assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10."
-        super().__init__()
+
+        # State
         self.state = REPLState()
         self.state.model = model
         self.state.running_port += 2
         self.state.num_local_nodes = nodes
 
+        # API Thread
         self._node: Optional[RingApiNode] = None
         self._api_thread: Optional[threading.Thread] = None
         self._api_ready = threading.Event()
         self._api_running = threading.Event()
+        self._api_searching = threading.Event()  # Track mDNS searching
         self._api_loop: Optional[asyncio.AbstractEventLoop] = None
         self._api_shutdown_e: Optional[asyncio.Event] = None
         self._api_exc: Optional[BaseException] = None
-        self._api_searching = threading.Event()  # Track mDNS searching
 
+        # Tracing
+        self._trace_cfg = TraceConfig(
+            enabled=False,
+            streaming=False,
+            budget=3000,
+            aggregate=True,
+            agg_max_events=50,
+        )
+
+        self._tracing = threading.Event()
+        self._tracer = None
+        self._trace_file = f"trace-{time.strftime('%Y%m%d-%H%M%S')}"
+        self._trace_cursor = 0  # keep track of registered events in buffer
+        self._trace_agg = TraceAggregator()
+
     def loop(self):  # Main tty loop
         sys.stdout.write(self.WELCOME)
@@ -82,7 +108,7 @@ def loop(self):  # Main tty loop
 
             if cmd == "":
                 self.print_state()
-            elif cmd in [".exit", "exit", "quit", "q"]:
+            elif cmd in [".exit", "exit", "quit"]:
                 self.handle_terminate_signal()
             elif cmd in [".help", "help", "h"]:
                 self.print_help()
@@ -96,10 +122,13 @@ def loop(self):  # Main tty loop
             elif cmd.startswith("nodes"):
                 self.print_mdns_nodes()
                 continue
+            elif cmd.startswith(("trace", ".trace")):
+                self.do_trace(cmd.split(" "))
+                continue
             elif cmd.startswith(("topo", ".topo")):
                 self.do_topo(cmd.split(" "))
                 continue
-            elif cmd.startswith((".model", "model", "m")):
+            elif cmd.startswith((".model", "model", "m ")):
                 cmd.split(" ")
                 path = self._handle_model_pull(cmd[1])
                 if path:
@@ -121,6 +150,7 @@ def do_api(self, cmd: List[str]) -> None:
                     grpc_port or self.state.api_grpc_port
                 )
                 self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0)
+
             elif cmd[1] == "stop":
                 self.stop_api()
             elif cmd[1] == "status":
@@ -182,12 +212,12 @@ def _print_hf(cmd, desc, examples=[""]):
                   ["Examples > model meta-llama/Meta-Llama-3-8B"])
         _print_hf("nodes list ", "List mDNS discovered nodes.")
         _print_hf("log [LEVEL]", "Set the logging level.")
-        dprint("\033[1m\n API Server Control:\n\033[0m")
+        dprint("\033[1m\n Controlling the API Server:\n\033[0m")
         _print_hf("api start [http_port=8080] [grpc_port=50500]", "Start the API server in a separate thread. Use provided ports if given.")
         _print_hf("api stop            ", "Signal clean shutdown of the API server.")
         _print_hf("api status          ", "Prints the status of the API server.")
         _print_hf("api log             ", "Print latest logs to the current terminal.")
-        dprint("\033[1m\n Building a topology:\n\033[0m")
+        dprint("\033[1m\n Topology construction:\n\033[0m")
         _print_hf("search              ", "Returns the current state of mDNS search.")
         _print_hf("search [on/off]     ", "Toggle mDNS search across the local network.")
         _print_hf("nodes list          ", "List all nodes in the current topology (including local ones).")
@@ -195,15 +225,19 @@ def _print_hf(cmd, desc, examples=[""]):
         _print_hf("nodes               ", "List mDNS discovered nodes.")
         _print_hf("topo [AUTO/SETUP]   ", "Toggle between automatic and manual topology creation.")
         _print_hf("topo add [NODE]     ", "Add [NODE] to the topology.")
-        _print_hf("topo remove [NODE]", "Add [NODE] to the topology.")
-        sys.stdout.write("\033[1m\n Building a schedule:\n\033[0m")
-        _print_hf("sched create", "Automatic search for best schedule given the active topology and the loaded model.")
-        _print_hf("sched assign [LAYER] [NODE]", "Assign the layer with index [LAYER] to [NODE].",
-                  ["Example > sched assign 10 benny_234"])
-        _print_hf("schedule assign [START-END] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
-                  ["Example > sched assign 0-12 benny_234"])
-        sys.stdout.write("\033[1m\n Benchmarking and profiling:\n\033[0m")
-        _print_hf("profile [REPO]", "Estimate the total FLOPS of the model from [REPO]")
+        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
+        sys.stdout.write("\033[1m\n Scheduling:\n\033[0m")
+        _print_hf("sched auto          ", "Automatic search for best schedule given the active topology and the loaded model.")
+        _print_hf("sched assign [INDEX] [NODE]", "Assign the layer with index [INDEX], or an inclusive range [START-END], to [NODE].",
+                  ["Example: > sched assign 10 benny_234",
+                   "         > sched assign 0-12 benny_234"])
+        sys.stdout.write("\033[1m\n Benchmarking, Tracing and Profiling:\n\033[0m")
+        _print_hf("trace [ON|OFF][PATH][SYSTEM] ", "Trace [SYSTEM] and output to file at [PATH].")
+        _print_hf("trace status        ", "See the status of the trace, e.g. the number of frames captured.")
+        _print_hf("trace focus [SUBSYSTEM] ", "Focus the trace on [SUBSYSTEM]. Do 'trace focus' for a list of available subsystems.")
Do 'trace focus' for a list of available subsystems.")
+    _print_hf("trace stream [ON|OFF] ", "Stream the trace spans to current terminal.")
+    _print_hf("trace set [BUDGET] ", "Set the maximum number of recorded events.")
+    _print_hf("profile [REPO]  ", "Estimate the total FLOPS of the model from [REPO]")
     _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
     _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
     _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
@@ -332,11 +366,6 @@ def handle_start_worker(self):
 
     # ===== Handle API server
 
-    # Tracer frames ingest callback
-    def __trace_cb(self, data):
-        dprint(str(data))
-        pass
-
     async def _api_main(self) -> None:  # main thread loop
         self._api_loop = asyncio.get_running_loop()
         self._api_shutdown_e = asyncio.Event()
@@ -380,7 +409,33 @@ def start_api(self, http_port: int=8080, grpc_port: int=50500, timeout=10):
         if not self._api_ready.wait(timeout):
             raise RuntimeError("API Server Timeout.")
         if self._api_exc is not None:
-            raise RuntimeError(f"API Server failed to start")
+            raise RuntimeError(f"API Server failed to start: {self._api_exc}")
+        # Register REPL aggregator callback on the API node
+        try:
+            self.api_call("set_trace_ingest_callback", self._trace_agg.enqueue, timeout=5)
+        except Exception:
+            pass
+
+        # Silence API server logs on the REPL console: drop records emitted from the API thread
+        try:
+            class _DropApiOnConsole(logging.Filter):
+                def filter(self, record: logging.LogRecord) -> bool:
+                    # Only drop records coming from the API thread so other threads keep logging
+                    tname = getattr(record, "threadName", "") or ""
+                    return tname != "api_server"
+
+            root = logging.getLogger()
+            for h in list(root.handlers):
+                if isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stdout, sys.stderr):
+                    if not any(isinstance(f, _DropApiOnConsole) for f in getattr(h, "filters", [])):
+                        h.addFilter(_DropApiOnConsole())
+
+            # Also quiet Hypercorn logs explicitly (HTTP server used by API)
+            logging.getLogger("hypercorn").setLevel(logging.CRITICAL)
+            logging.getLogger("hypercorn.error").setLevel(logging.CRITICAL)
+            logging.getLogger("hypercorn.access").setLevel(logging.CRITICAL)
+        except Exception:
+            pass
 
     def stop_api(self, timeout: float = 5.0) -> None:
         if not self._api_thread: return
@@ -441,6 +496,72 @@ async def _await_then_set():
         self._api_loop.call_soon_threadsafe(runner)
         return f.result(timeout)
 
+    # ------- Trace aggregation helpers
+
+    def do_trace(self, cmd):
+        if len(cmd) < 2:
+            dprint(f"Tracing is currently {'ON' if self._trace_cfg.enabled else 'OFF'}\n")
+        elif cmd[1] in ("on", "ON"):
+            self._trace_cfg.enabled = True
+            if self._api_running.is_set():
+                self.api_call("_forward_trace_config", self._trace_cfg)  # Send trace config to all shards
+            dprint("Tracing is now ON\n")
+        elif cmd[1] in ("off", "OFF"):
+            self._trace_cfg.enabled = False
+            if self._api_running.is_set():
+                self.api_call("_forward_trace_config", self._trace_cfg)  # Send trace config to all shards
+            dprint("Tracing is now OFF\n")
+        elif cmd[1] == "focus":
+            #self.api_call("_forward_trace_config", self._trace_cfg)  # Send trace config to all shards
+            dprint("Subsystems not yet implemented.\n")
+        elif cmd[1] == "stream":
+            if len(cmd) == 2:
+                dprint(f"Trace is {'streaming to file: ' + str(self._trace_cfg.file) if self._trace_cfg.streaming else 'not streaming.'}\n")
+            elif cmd[2] == "on":
+                self._trace_cfg.streaming = True
+                dprint(f"Streaming trace frames to {self._trace_cfg.file}\n")
+            elif 
cmd[2] == "off": + self._trace_cfg.streaming = False + dprint("Trace streaming is OFF.\n") + elif cmd[1] == "set": + if len(cmd) == 2: + dprint("Use: trace set [BUDGET], eg. 2000\n") + else: + dprint("Not implemented yet\n") + # FIXME: Implement + elif cmd[1] == "status": + dprint(f"Frames: {len(self._trace_agg._req)}\n") + + elif cmd[1] == "annotate": + self.print_trace_annotate("NONE") + + # Trace callback registered with API Thread + def __trace_cb(self, data): + self._trace_agg.enqueue(data) + + def __print_tr(self, symbol, ms, counts): + sym = " " + symbol.ljust(40, ' ') + pms = f"{ms:.10}".ljust(10, ' ') + cns = f"{counts}".ljust(4, ' ') + sys.stdout.write(f"{sym} {pms} {cns}\n") + + def print_trace_annotate( + self, + run_id: str = "run", + mapping: Optional[Dict[str, str]] = None, + repeats: int = 0, + ) -> List[Dict[str, Any]]: + names = " "*17 + "symbol" + " "*21 + "ms" + " "*4 + "counts" + dots = " " + "."*41 + " " + "."*10 + " " + "."*4 + dprint(f"{names}\n{dots}\n\n") + sums = self._trace_agg._req[run_id].sums_by_name + cnts = self._trace_agg._req[run_id].counts_by_name + for n, d in sums.items(): + self.__print_tr(n, d, cnts[n]) + + def get_trace_roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: + return self._trace_agg.roots(run_id, req_id) + def _print_nodes_table(self, rows: List[Any]) -> None: headers = ["name", "role", "addr", "http", "grpc", "status", "head"] limits = {"name": 36, "addr": 15} From 5f604a621b7c3e04b98153638a4514ea3e9ee0e1 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 03:24:48 -0700 Subject: [PATCH 098/172] mlx bug in lm_head, transposes weights for no reason, manually compute matmul --- src/dnet/ring/model/llama3.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index 231f02b5..81626d51 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -33,7 +33,9 @@ def __init__( self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.norm = nn.RMSNorm(self.config.hidden_size, self.config.rms_norm_eps) - self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + if not self.config.tie_word_embeddings: + self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) self.layers: List[nn.Module] = [] self.abs_to_local: Dict[int, int] = {} @@ -41,8 +43,9 @@ def __init__( for i, l in enumerate(sorted(assigned_layers or [])): self.layers.append(TransformerBlock(self.config)) self.abs_to_local[l] = i + logger.debug(f"Created {len(self.layers)} Transformer layers") - logger.debug(f"abs_to_local mapping: {self.abs_to_local}") + #logger.debug(f"abs_to_local mapping: {self.abs_to_local}") @property def decoding_layers(self): @@ -72,8 +75,14 @@ def embed(self, x: mx.array): def normalize(self, x: mx.array): return self.norm(x) + # FIXME: Weird MLX bug, lm_head weights are transposed internally for no reason def lm_project(self, x: mx.array): - return self.lm_head(x) + if self.config.tie_word_embeddings: + return self.embed_tokens.as_linear(x) + try: + return self.lm_head(x) + except Exception as e: + return mx.matmul(x, self.lm_head.weight) def quantize_layers(self): self.quantization = None @@ -127,7 +136,19 @@ def apply_single_layer( ): if layer_idx not in self.abs_to_local: raise RuntimeError(f"Attempted execution of foreign layer {layer_idx}") - mask = create_attention_mask(x, cache) + + mask = None + sqlen = 
int(x.shape[1]) + if sqlen > 1: + cached = getattr(self, "_cached_mask_len", None) + cached_mask = getattr(self, "_cached_mask", None) + if cached is None or cached != sqlen or not cached_mask: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_len = sqlen + else: + mask = cached_mask + local_idx = self.abs_to_local[layer_idx] logger.debug(f"apply_single_layer: layer:{layer_idx}, local_idx:{local_idx}, input_shape:{x.shape}") @@ -165,7 +186,9 @@ def load_weights(self, weights, strict=False): logger.debug(f"Mapping weight {k} -> {new_key}") shard_weights[new_key] = v - elif (k.startswith("embed_tokens") or k.startswith("lm_head") or k.startswith("norm")): + elif k.startswith("lm_head"): + shard_weights[k] = v + elif (k.startswith("embed_tokens") or k.startswith("norm")): shard_weights[k] = v if shard_weights: From aa3c1cc60286aeaaea4109b141e82954df50a5c9 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 03:32:59 -0700 Subject: [PATCH 099/172] wrap in trace frames --- src/dnet/ring/shard/comms.py | 395 ++++++++++++++++++----------------- 1 file changed, 199 insertions(+), 196 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 9ca00406..8390457d 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -227,76 +227,79 @@ async def _send_activation(self, activation_msg: ActivationMessage): return try: if activation_msg.is_final: - try: - if self._mode == "offload" and self.window_size > 0: - first_window = self._assigned_sorted[: self.window_size] - if first_window: - loop = asyncio.get_running_loop() - fut = loop.run_in_executor( - self.executor, - self._prepare_window_blocking, - list(first_window), + with self.tracer.frame("grpc", "send_activation.final") as f: + try: + if self._mode == "offload" and self.window_size > 0: + first_window = self._assigned_sorted[: self.window_size] + if first_window: + loop = asyncio.get_running_loop() + fut = loop.run_in_executor( + self.executor, + self._prepare_window_blocking, + list(first_window), + ) + self._prepared_by_nonce[activation_msg.nonce] = ( + list(first_window), + fut, + ) + except Exception: + pass + cb = activation_msg.callback_url or "" + parsed = urlparse(cb) if cb else None + t_rpc = time.perf_counter() + if parsed and parsed.scheme == "grpc": + addr = parsed.netloc + if not addr: + logger.error("Invalid gRPC callback URL for token: %s", cb) + return + # Ensure API channel/stub + if (self.api_channel is None) or (addr != self.api_address): + # close existing channel if any + try: + if self.api_channel is not None: + await self.api_channel.close() + except Exception: + pass + + self.api_address = addr + self.api_channel = aio_grpc.insecure_channel( + addr, options=GRPC_AIO_OPTIONS ) - self._prepared_by_nonce[activation_msg.nonce] = ( - list(first_window), - fut, + self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( + self.api_channel ) - except Exception: - pass - cb = activation_msg.callback_url or "" - parsed = urlparse(cb) if cb else None - t_rpc = time.perf_counter() - if parsed and parsed.scheme == "grpc": - addr = parsed.netloc - if not addr: - logger.error("Invalid gRPC callback URL for token: %s", cb) - return - # Ensure API channel/stub - if (self.api_channel is None) or (addr != self.api_address): - # close existing channel if any + f.event("reset_api") + try: - if self.api_channel is not None: - await self.api_channel.close() - except Exception: - pass - - self.api_address = addr - self.api_channel = 
aio_grpc.insecure_channel( - addr, options=GRPC_AIO_OPTIONS - ) - self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( - self.api_channel - ) - try: - req = shard_api_comm_pb2.TokenRequest( - nonce=activation_msg.nonce, - token_id=int(getattr(activation_msg, "token_id", -1)), - timestamp=utc_epoch_now(), - ) - resp = await self.api_stub.SendToken(req) # type: ignore - rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 - if not resp.success: - logger.error( - "API SendToken failed for %s: %s", - activation_msg.nonce, - resp.message, - ) - if self._profile: - logger.info( - "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", - self.node_id, - activation_msg.nonce, - int(getattr(activation_msg, "token_id", -1)), - rpc_ms, + req = shard_api_comm_pb2.TokenRequest( + nonce=activation_msg.nonce, + token_id=int(getattr(activation_msg, "token_id", -1)), + timestamp=utc_epoch_now(), ) - except Exception as e: - logger.exception("Error sending token via gRPC: %s", e) - else: - logger.error( - "No valid gRPC callback for token; cannot deliver nonce=%s", - activation_msg.nonce, - ) - return + resp = await self.api_stub.SendToken(req) # type: ignore + rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 + if not resp.success: + logger.error( + "API SendToken failed for %s: %s", + activation_msg.nonce, + resp.message, + ) + if self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) + except Exception as e: + logger.exception("Error sending token via gRPC: %s", e) + else: + logger.error( + "No valid gRPC callback for token; cannot deliver nonce=%s", + activation_msg.nonce, + ) + return used_pool = False @@ -309,9 +312,6 @@ async def _send_activation(self, activation_msg: ActivationMessage): "Failed to get output buffer %s", activation_msg.pool_id ) return - data_size = int(np.prod(activation_msg.shape)) - shaped = output_buffer[:data_size].reshape(activation_msg.shape) - used_pool = True if self._profile: logger.info( @@ -346,156 +346,159 @@ async def _send_activation(self, activation_msg: ActivationMessage): activation_msg.dtype = self._wire_dtype_str t_cast = time.perf_counter() - if isinstance(shaped, np.ndarray): - data = shaped.tobytes(order="C") - else: - data = tensor_to_bytes(shaped) + with self.tracer.frame("grpc", "send_activations.cast_to_dtype") as f: + + if isinstance(shaped, np.ndarray): # Cast to target dtype + if shaped.dtype != wire_np_dtype: + shaped = shaped.astype(wire_np_dtype, copy=False) + f.event("ndarray.cast") + data = shaped.tobytes(order="C") + else: + if str(shaped.dtype) != self._wire_dtype_str: # MLX array + shaped = shaped.astype(self._wire_mx_dtype) + f.event("mxarray.cast") + data = tensor_to_bytes(shaped) - ser_ms = (time.perf_counter() - t_ser) * 1000.0 - cast_ms = (t_cast - t_ser) * 1000.0 + activation_msg.dtype = self._wire_dtype_str nxt = activation_msg.layer_id + 1 - if (nxt < self.model_metadata.num_layers) and ( - nxt not in self._assigned_set - ): + if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): if self.next_node_stub: - request = activation_msg.to_proto(data) - request.timestamp = utc_epoch_now() - if self._mode == "offload" and self.window_size > 0: - next_window = self._next_local_layers( - activation_msg.layer_id, self.window_size - ) - loop = asyncio.get_running_loop() - if next_window: - fut = loop.run_in_executor( - self.executor, - self._prepare_window_blocking, - 
list(next_window), - ) - self._prepared_by_nonce[activation_msg.nonce] = ( - list(next_window), - fut, + with self.tracer.frame("grpc", "send_activation.next") as f: + request = activation_msg.to_proto(data) + request.timestamp = utc_epoch_now() + if self._mode == "offload" and self.window_size > 0: + next_window = self._next_local_layers( + activation_msg.layer_id, self.window_size ) - else: - first_window = self._assigned_sorted[: self.window_size] - if first_window: + loop = asyncio.get_running_loop() + if next_window: fut = loop.run_in_executor( self.executor, self._prepare_window_blocking, - list(first_window), + list(next_window), ) self._prepared_by_nonce[activation_msg.nonce] = ( - list(first_window), + list(next_window), fut, ) - stream_used = False - ctx = await self._ensure_stream(activation_msg.nonce) - if ( - ctx - and ctx.open - and not ctx.disabled - and hasattr(dnet_ring_pb2, "ActivationFrame") - ): - try: - ctx.last_seq += 1 - frame = dnet_ring_pb2.ActivationFrame( - request=request, - seq=ctx.last_seq, - end_of_request=False, - ) - await ctx.queue.put(frame) - ctx.last_activity_t = asyncio.get_running_loop().time() - stream_used = True - if self._profile: - logger.info( - "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", - activation_msg.nonce, - ctx.last_seq, - ctx.queue.qsize(), + else: + first_window = self._assigned_sorted[: self.window_size] + if first_window: + fut = loop.run_in_executor( + self.executor, + self._prepare_window_blocking, + list(first_window), + ) + self._prepared_by_nonce[activation_msg.nonce] = ( + list(first_window), + fut, + ) + stream_used = False + ctx = await self._ensure_stream(activation_msg.nonce) + if ( + ctx + and ctx.open + and not ctx.disabled + and hasattr(dnet_ring_pb2, "ActivationFrame") + ): + try: + ctx.last_seq += 1 + frame = dnet_ring_pb2.ActivationFrame( + request=request, + seq=ctx.last_seq, + end_of_request=False, ) - except Exception as e: - logger.warning( - "[STREAM] enqueue failed; fallback to unary: %s", e + await ctx.queue.put(frame) + ctx.last_activity_t = asyncio.get_running_loop().time() + stream_used = True + if self._profile: + logger.info( + "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", + activation_msg.nonce, + ctx.last_seq, + ctx.queue.qsize(), + ) + except Exception as e: + logger.warning( + "[STREAM] enqueue failed; fallback to unary: %s", e + ) + ctx.disabled = True + + if not stream_used: + # In fit mode, avoid long unary stalls: use short deadline and min retries + # Streaming should be the norm; unary is a quick safety valve only. + ring_timeout = 3.0 if self._mode == "fit" else 30.0 + ring_retries = ( + 1 + if self._mode == "fit" + else max(1, int(self._send_retries)) ) - ctx.disabled = True - - if not stream_used: - # In fit mode, avoid long unary stalls: use short deadline and min retries - # Streaming should be the norm; unary is a quick safety valve only. 
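+                            # Worst-case timing sketch (from the constants below): each
+                            # retryable failure reconnects and then sleeps
+                            # min(0.25 * attempt, 1.0)s, so fit mode (1 attempt, 3s
+                            # deadline) stalls a hop for roughly 3s, while offload mode
+                            # may spend up to self._send_retries attempts of 30s each.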
- ring_timeout = 3.0 if self._mode == "fit" else 30.0 - ring_retries = ( - 1 - if self._mode == "fit" - else max(1, int(self._send_retries)) - ) - # Emit a clear fallback log with reason/context - if self._profile: - if ctx is None: - reason = "no_stream_ctx" - elif not ctx.open: - reason = "stream_closed" - elif ctx.disabled: - reason = "stream_disabled" + # Emit a clear fallback log with reason/context + if self._profile: + if ctx is None: + reason = "no_stream_ctx" + elif not ctx.open: + reason = "stream_closed" + elif ctx.disabled: + reason = "stream_disabled" + else: + reason = "enqueue_failed" + logger.warning( + "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + self.node_id, + activation_msg.nonce, + reason, + self._mode, + ring_timeout, + ring_retries, + ) + t0 = time.perf_counter() + last_exc: Optional[Exception] = None + for attempt in range(1, ring_retries + 1): + try: + # FIXME: use response here? + _ = await self.next_node_stub.SendActivation( + request, timeout=ring_timeout + ) # type: ignore + break + except grpc.aio.AioRpcError as e: # type: ignore + last_exc = e + code = e.code() + if code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.DEADLINE_EXCEEDED, + }: + logger.warning( + "SendActivation attempt %s/%s failed (%s); reconnecting...", + attempt, + ring_retries, + code.name, + ) + await self._reconnect_next_node() + await asyncio.sleep(min(0.25 * attempt, 1.0)) + continue + raise else: - reason = "enqueue_failed" - logger.warning( - "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", self.node_id, activation_msg.nonce, - reason, - self._mode, - ring_timeout, - ring_retries, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, ) - t0 = time.perf_counter() - last_exc: Optional[Exception] = None - for attempt in range(1, ring_retries + 1): - try: - # FIXME: use response here? - _ = await self.next_node_stub.SendActivation( - request, timeout=ring_timeout - ) # type: ignore - break - except grpc.aio.AioRpcError as e: # type: ignore - last_exc = e - code = e.code() - if code in { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.CANCELLED, - grpc.StatusCode.DEADLINE_EXCEEDED, - }: - logger.warning( - "SendActivation attempt %s/%s failed (%s); reconnecting...", - attempt, - ring_retries, - code.name, - ) - await self._reconnect_next_node() - await asyncio.sleep(min(0.25 * attempt, 1.0)) - continue - raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) else: - logger.error( - "Cannot forward activation - no next node configured; end shard should sample inline." - ) + logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") else: logger.error( "Final activation reached send path unexpectedly; sampling should occur on end shard." 
) - # Clear scheduling at request end # Sequential offload: prefetch state is unused From cbd4995832e407e26e88d95093c3f7eae53964cf Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 11:51:07 -0700 Subject: [PATCH 100/172] trace ingress worker --- src/dnet/ring/shard/node.py | 328 +++++++++++++++++------------------- 1 file changed, 156 insertions(+), 172 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 3583fea1..6bae4147 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -619,6 +619,7 @@ async def reset_cache(self) -> None: except Exception as e: logger.error("Node %s: Error resetting cache: %s", self.node_id, e) + async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or forward.""" if self.input_pool is None: @@ -631,142 +632,45 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): t_recv = time.perf_counter() await self._connect_next_node() - try: - activation = request.activation - target_layer = activation.layer_id + 1 + with self.tracer.frame("grpc.receive", "connect_next_node"): + await self._connect_next_node() + with self.tracer.frame("grpc.receive", "process_activation") as f: try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - transport_ms = float(utc_epoch_now() - request.timestamp) - logger.info( - "[PROFILE][RX] node=%s nonce=%s target_layer=%s transport_ms=%.1f payload_kb=%.1f", - self.node_id, - request.nonce, - target_layer, - transport_ms, - (payload_bytes / 1024.0), - ) + activation = request.activation + target_layer = activation.layer_id + 1 # Detect new sequence per node: initialize per-nonce KV if request.nonce != self._active_nonce: self._active_nonce = request.nonce try: - self._get_or_make_kv(request.nonce) + payload_bytes = len(activation.data) except Exception: - pass + payload_bytes = -1 + f.event("process_payload") if target_layer in self._assigned_set: # Allocate input pool and copy payload (with optional decompression) t_alloc = time.perf_counter() if "|" in activation.dtype: - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return - - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs - if activation.dtype == "tokens": + with self.tracer.frame("grpc.receive", "decompress") as fr: try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 + deq = 
decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, ) - shp = (int(len(tokens)),) except Exception as e: logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) - return - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - if self._profile: - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - # Ensure dtype reflects token payload for compute path - activation_msg.dtype = "tokens" - activation_msg.shape = shp - else: - # Safety: byte length must match shape*dtype - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize - ) - actual = len(request.activation.data) - except Exception: - expected = -1 - actual = -1 - if expected != actual: - logger.error( - "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", - request.nonce, - expected, - actual, - activation.dtype, - activation.shape, + "Decompression failed for nonce %s: %s", request.nonce, e ) return + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) if pool_id is None: logger.warning( @@ -776,22 +680,103 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: - data = request.activation.data - input_data = np.frombuffer( - data, dtype=dtype_map[activation.dtype] - ) - buffer[: len(input_data)] = input_data + flat = deq.reshape(-1) + buffer[: flat.size] = flat alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", self.node_id, request.nonce, alloc_copy_ms, ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - if self._profile: - activation_msg.recv_perf_t = t_recv + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + else: + # Special token stream support: dtype='tokens' carries int32 token IDs + if activation.dtype == "tokens": + with self.tracer.frame("grpc.receive", "token_stream") as fr: + try: + tokens = np.frombuffer( + request.activation.data, dtype=np.int32 + ) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", + request.nonce, + e, + ) + return + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), + ) + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) 
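+                                # No free input-pool slot: drop this frame instead of
+                                # blocking the gRPC receive path.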
+
+                                return
+                            buffer = self.input_pool.get_buffer(pool_id)
+                            if buffer is not None:
+                                buffer[: len(tokens)] = tokens
+                            if self._profile:
+                                alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0
+                                logger.info(
+                                    "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)",
+                                    self.node_id,
+                                    request.nonce,
+                                    alloc_copy_ms,
+                                )
+                            activation_msg = ActivationMessage.from_proto(request, pool_id)
+                            # Ensure dtype reflects token payload for compute path
+                            activation_msg.dtype = "tokens"
+                            activation_msg.shape = shp
+                    else:
+                        with self.tracer.frame("grpc.receive", "default") as fr:
+                            # Safety: byte length must match shape*dtype
+                            try:
+                                expected = (
+                                    int(np.prod(activation.shape))
+                                    * np.dtype(dtype_map[activation.dtype]).itemsize
+                                )
+                                actual = len(request.activation.data)
+                            except Exception:
+                                expected = -1
+                                actual = -1
+                            if expected != actual:
+                                logger.error(
+                                    "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s",
+                                    request.nonce,
+                                    expected,
+                                    actual,
+                                    activation.dtype,
+                                    activation.shape,
+                                )
+                                return
+
+                            pool_id = self.input_pool.allocate_for_layer(
+                                layer_id=activation.layer_id,
+                                dtype=mlx_dtype_map[activation.dtype],
+                                shape=cast(tuple[int, ...], activation.shape),
+                            )
+                            if pool_id is None:
+                                logger.warning(
+                                    "Failed to allocate input pool buffer for nonce %s",
+                                    request.nonce,
+                                )
+                                return
+                            buffer = self.input_pool.get_buffer(pool_id)
+                            if buffer is not None:
+                                data = request.activation.data
+                                input_data = np.frombuffer(
+                                    data, dtype=dtype_map[activation.dtype]
+                                )
+                                buffer[: len(input_data)] = input_data
+                            alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0
+                            logger.info(
+                                "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f",
+                                self.node_id,
+                                request.nonce,
+                                alloc_copy_ms,
+                            )
+                        activation_msg = ActivationMessage.from_proto(request, pool_id)
 
             # Queue for processing — non-blocking back-off loop (cancellable)
             if self._profile:
@@ -821,8 +806,8 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
                 )
                 await self._forward_activation(request)
 
-        except Exception as e:
-            logger.exception("Error receiving activation: %s", e)
+            except Exception as e:
+                logger.exception("Error receiving activation: %s", e)
 
     async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None:
         """
@@ -846,64 +829,67 @@ async def _ingress_worker(self):
        finally enqueues for compute or forwards to the next shard. 
""" while self.running: - try: - req = await self.ingress_q.get() - except asyncio.CancelledError: - break - try: - t_recv = time.perf_counter() - await self._connect_next_node() - - activation = req.activation - target_layer = activation.layer_id + 1 + with self.tracer.frame("grpc", "ingress") as f: + with self.tracer.frame("grpc.ingress", "get"): + try: + req = await self.ingress_q.get() + except asyncio.CancelledError: + break try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - transport_ms = float(utc_epoch_now() - req.timestamp) - logger.info( - "[PROFILE][RX] node=%s nonce=%s target_layer=%s transport_ms=%.1f payload_kb=%.1f", - self.node_id, - req.nonce, - target_layer, - transport_ms, - (payload_bytes / 1024.0), - ) + with self.tracer.frame("grpc.ingress", "connect_next_node"): + await self._connect_next_node() + + activation = req.activation + target_layer = activation.layer_id + 1 # Detect new sequence per node: initialize per-nonce KV if req.nonce != self._active_nonce: self._active_nonce = req.nonce try: - self._get_or_make_kv(req.nonce) + payload_bytes = len(activation.data) except Exception: - pass + logger.error(f"Unable to read length of data for {req.nonce}") + payload_bytes = -1 + + f.set("nonce", req.nonce) + f.set("target", target_layer) + f.set("payload_bytes", payload_bytes) + f.event("received") if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) - loop = asyncio.get_running_loop() - try: - activation_msg = await loop.run_in_executor( - self.executor, - self._prepare_activation_message_blocking, - req, - ) - except Exception as e: - logger.error( - "Activation prepare failed for nonce %s: %s", req.nonce, e - ) - continue - if activation_msg is None: - continue - if self._profile: - activation_msg.recv_perf_t = t_recv + with self.tracer.frame("grpc.ingress", "prepare"): + loop = asyncio.get_running_loop() + try: + activation_msg = await loop.run_in_executor( + self.executor, + self._prepare_activation_message_blocking, + req, + ) + except Exception as e: + logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) + continue + if activation_msg is None: + continue + if self._profile: + activation_msg.recv_perf_t = t_recv # Enqueue for compute (cancellable back-off) - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", + with self.tracer.frame("grpc.ingress", "queue") as fr: + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, + ) + break + except Full: + await asyncio.sleep(0) + else: + logger.error( + "Failed to queue activation %s (node stopping)", activation_msg.nonce, ) break @@ -929,8 +915,6 @@ async def _ingress_worker(self): ) await self._forward_activation(req) - except Exception as e: - logger.error("Ingress worker error: %s", e) def _get_or_make_kv(self, nonce: str) -> list: """Return a per-nonce KV cache list for this shard's local layers.""" From 2124522966021f3b4adc1a734850647098de1806 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 11:51:44 -0700 Subject: [PATCH 101/172] trace token request stall --- src/dnet/ring/shard/comms.py | 45 +++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 8390457d..dc197637 100644 --- 
a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -270,28 +270,29 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) f.event("reset_api") - try: - req = shard_api_comm_pb2.TokenRequest( - nonce=activation_msg.nonce, - token_id=int(getattr(activation_msg, "token_id", -1)), - timestamp=utc_epoch_now(), - ) - resp = await self.api_stub.SendToken(req) # type: ignore - rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 - if not resp.success: - logger.error( - "API SendToken failed for %s: %s", - activation_msg.nonce, - resp.message, - ) - if self._profile: - logger.info( - "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", - self.node_id, - activation_msg.nonce, - int(getattr(activation_msg, "token_id", -1)), - rpc_ms, + with self.tracer.frame("grpc", "token_request") as fr: + try: + req = shard_api_comm_pb2.TokenRequest( + nonce=activation_msg.nonce, + token_id=int(getattr(activation_msg, "token_id", -1)), + timestamp=utc_epoch_now(), ) + resp = await self.api_stub.SendToken(req) # type: ignore + rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 + if not resp.success: + logger.error( + "API SendToken failed for %s: %s", + activation_msg.nonce, + resp.message, + ) + if self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) except Exception as e: logger.exception("Error sending token via gRPC: %s", e) else: @@ -495,6 +496,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) else: logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") + + # Final layer not annotated with 'is_final' else: logger.error( "Final activation reached send path unexpectedly; sampling should occur on end shard." 
From d6f63c86bb6b6ae810c659b04088d9b3c59485c0 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 11:52:10 -0700 Subject: [PATCH 102/172] stop printing state on empty prompt --- src/repl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/repl.py b/src/repl.py index 592424a6..862d7136 100644 --- a/src/repl.py +++ b/src/repl.py @@ -107,7 +107,8 @@ def loop(self): # Main tty loop cmd = sys.stdin.readline().strip() if cmd == "": - self.print_state() + #self.print_state() + continue elif cmd in [".exit", "exit", "quit"]: self.handle_terminate_signal() elif cmd in [".help", "help", "h"]: From 079d6ce0b76475c42522affd44cbdd8c3bfe4ef6 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 12:53:09 -0700 Subject: [PATCH 103/172] trace startup --- src/dnet/ring/shard/startup.py | 44 ++++++++++++++-------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py index 6790be64..f3b38fbb 100644 --- a/src/dnet/ring/shard/startup.py +++ b/src/dnet/ring/shard/startup.py @@ -48,34 +48,26 @@ class StartupMixin: async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): self.running = True - try: # Capture the main event loop for cross-thread scheduling - self._loop = asyncio.get_running_loop() - except Exception: - self._loop = None - await self._start_grpc_server() - await self._start_http_server(shutdown_trigger) - await asyncio.sleep(0.2) - - self.background_tasks = [ - asyncio.create_task(self._ingress_worker()), - asyncio.create_task(self._prefetch_worker()), - asyncio.create_task(self._send_worker()), - ] - # Start idle sweeper to close silent streams - try: - if getattr(self, "_streaming_enabled", False) and hasattr( - self, "_stream_sweeper" - ): - self.background_tasks.append( - asyncio.create_task(self._stream_sweeper()) - ) - except Exception: - pass - self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) - self.compute_thread.start() + with self.tracer.frame("startup", "workers"): + self.background_tasks = [ + asyncio.create_task(self._ingress_worker()), + asyncio.create_task(self._prefetch_worker()), + asyncio.create_task(self._send_worker()) ] + + try: # Start idle sweeper to close silent streams + if getattr(self, "_streaming_enabled", False) and hasattr(self, "_stream_sweeper"): + self.background_tasks.append( asyncio.create_task(self._stream_sweeper())) + except Exception: + pass + + with self.tracer.frame("startup", "compute"): + self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) + self.compute_thread.start() + + with self.tracer.frame("startup", "discovery"): + self._start_discovery() - self._start_discovery() logger.info( "Shard node %s started on gRPC port %s HTTP port %s", self.node_id, From a2c9fc95902729efc3ec142a1e26a3497c0ed2ba Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 12:53:43 -0700 Subject: [PATCH 104/172] trace prepare_activation --- src/dnet/ring/shard/node.py | 257 ++++++++++++++++++------------------ 1 file changed, 131 insertions(+), 126 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 6bae4147..a4786ec1 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -200,8 +200,6 @@ def __init__( self._sync_per_layer = obs.sync_per_layer self._sync_every_n = obs.sync_every_n self._prof = make_profiler(self._profile) - if self._profile: - logger.info("[PROFILE] enabled on shard node %s", self.node_id) # 
Debug tracing cfg = TraceConfig( @@ -265,6 +263,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse # Load model metadata with self.tracer.frame("memory", "model.load_metadata"): self.model_metadata = get_model_metadata(req.model_path) + self.assigned_layers = req.layers self._assigned_sorted = sorted(self.assigned_layers) self._assigned_set = set(self._assigned_sorted) @@ -442,29 +441,30 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - if self.next_node: - with self.tracer.frame("network", "connect.next_node"): + with self.tracer.frame("network", "connect.next_node"): + if self.next_node: await self._connect_next_node() else: logger.warning("Node %s: No next node configured", self.node_id) # Warmup: compile hot path and stabilize allocators before first request - if req.warmup and self._mode == "fit": - loop = asyncio.get_running_loop() - try: - await loop.run_in_executor(self.executor, self._warmup_shard) - except Exception: - # Fall back to direct call if executor is unavailable - self._warmup_shard() - elif req.warmup and self._mode != "fit": - # Offload/sliding-fit: perform a small, offload-safe warmup for the first window - loop = asyncio.get_running_loop() - try: - await loop.run_in_executor( - self.executor, self._warmup_shard_offload - ) - except Exception: - self._warmup_shard_offload() + with self.tracer.frame("memory", "warmup"): + if req.warmup and self._mode == "fit": + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(self.executor, self._warmup_shard) + except Exception: + # Fall back to direct call if executor is unavailable + self._warmup_shard() + elif req.warmup and self._mode != "fit": + # Offload/sliding-fit: perform a small, offload-safe warmup for the first window + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor( + self.executor, self._warmup_shard_offload + ) + except Exception: + self._warmup_shard_offload() if self._mode == "offload" and not ( self._warmup_completed and self._warmup_keep_flag @@ -485,10 +485,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse if m % self.window_size != 0: logger.warning( "Window size %s does not divide local layer count %s. Rounds per token will vary; consider setting k*w = %s.", - self.window_size, - m, - m, - ) + self.window_size, m, m) else: k = m // self.window_size logger.info( @@ -953,13 +950,14 @@ def _clear_kv(self, nonce: str) -> None: except Exception: pass + def _prepare_activation_message_blocking( self, request: dnet_ring_pb2.ActivationRequest ) -> Optional[ActivationMessage]: + """Blocking heavy prep: allocate pool buffer, copy/decompress payload, build ActivationMessage. + Returns None on failure. """ - Returns None on failure. 
- """ if self.input_pool is None: logger.error( "Node %s: Cannot prepare activation - input pool not initialized", @@ -969,113 +967,117 @@ def _prepare_activation_message_blocking( try: activation = request.activation - if "|" in activation.dtype: - # Compressed path: decompress to MLX array and copy to pool - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return None + if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - return activation_msg - elif activation.dtype == "tokens": - # Tokens path: parse int32 token IDs and stage them - try: - tokens = np.frombuffer(activation.data, dtype=np.int32) - shp = (int(len(tokens)),) - except Exception as e: - logger.error( - "Failed to parse tokens for nonce %s: %s", request.nonce, e - ) - return None - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = "tokens" - activation_msg.shape = shp - return activation_msg - else: - # Dense path: validate size and copy raw bytes view into pool buffer - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize + with self.tracer.frame("grpc.ingress.prepare_activation", "decompress") as f: + try: + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, + ) + except Exception as e: + logger.error( + "Decompression failed for nonce %s: %s", request.nonce, e + ) + return None + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) - actual = len(activation.data) - except Exception: - expected = -1 - actual = -1 - if expected != actual: - logger.error( - "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", - request.nonce, - expected, - actual, - activation.dtype, - activation.shape, + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + flat = deq.reshape(-1) + buffer[: flat.size] = flat + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, 
pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + return activation_msg + + elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them + with self.tracer.frame("grpc.ingress.prepare_activation", "tokens") as f: + try: + tokens = np.frombuffer(activation.data, dtype=np.int32) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", request.nonce, e + ) + return None + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), ) - return None + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = "tokens" + activation_msg.shape = shp + return activation_msg - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, + else: # Dense path: validate size and copy raw bytes view into pool buffer + with self.tracer.frame("grpc.ingress.prepare_activation", "default") as f: + try: + expected = ( + int(np.prod(activation.shape)) + * np.dtype(dtype_map[activation.dtype]).itemsize + ) + actual = len(activation.data) + except Exception: + expected = -1 + actual = -1 + if expected != actual: + logger.error( + "Payload size mismatch for nonce=%s: expected=%d actual=%d dtype=%s shape=%s", + request.nonce, + expected, + actual, + activation.dtype, + activation.shape, + ) + return None + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mlx_dtype_map[activation.dtype], + shape=cast(tuple[int, ...], activation.shape), ) - return None - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - data = request.activation.data - input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) - buffer[: len(input_data)] = input_data - activation_msg = ActivationMessage.from_proto(request, pool_id) - return activation_msg + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return None + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + data = request.activation.data + input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) + buffer[: len(input_data)] = input_data + activation_msg = ActivationMessage.from_proto(request, pool_id) + return activation_msg except Exception as e: logger.error("Activation prep error: %s", e) return None + def _next_local_layers(self, after_layer: int, count: int) -> List[int]: """Get next local layers after specified layer. 
@@ -1092,22 +1094,25 @@ def _next_local_layers(self, after_layer: int, count: int) -> List[int]: i = _bisect_left(s, after_layer + 1) return s[i : i + count] + def _compute_worker(self) -> None: """Compute thread worker.""" while self.running: try: # Get activation from queue (blocks until available) - activation_msg = self.activation_recv_queue.get(timeout=1.0) + with self.tracer.frame("compute", "dequeue"): + activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation with self.tracer.frame("compute", "forward"): - self._process_activation(activation_msg) + self._process_activation(activation_msg) except Empty: continue except Exception as e: logger.error("Compute worker error: %s", e) + async def shutdown(self) -> None: """Shutdown the node.""" self.running = False From 001dac5f7c51a808a708e61d43800860fe4ae376 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 12:54:21 -0700 Subject: [PATCH 105/172] don't set _prefetch_pause --- src/dnet/ring/shard/compute.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index d25705fe..41111154 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -92,6 +92,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Prepare input activation with self.tracer.frame("compute.thread", "activations.process") as f: if activation_msg.dtype == "tokens": # embed locally on start shard + logger.debug(f"Embedding tokens.") f.event("embed_tokens") numel = int(np.prod(activation_msg.shape)) tok_view = input_buffer[:numel].reshape(activation_msg.shape) @@ -395,6 +396,7 @@ def _process_activation(self, activation_msg: ActivationMessage): except Exception as e: logger.error("End-shard sampling failed: %s", e) return + output_msg = ActivationMessage( nonce=activation_msg.nonce, layer_id=last_layer, From 6c6c08895806d8a3283383767ecf7183e0432f38 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 14:36:57 -0700 Subject: [PATCH 106/172] compute mean, p50, p99, etc. 
per trace symbol and print --- src/dnet/perf/trace.py | 7 +- src/dnet/perf/utils/aggregator.py | 154 ++++++++++++++++++++---------- src/repl.py | 40 +++++--- 3 files changed, 135 insertions(+), 66 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index e4a92cfb..d2ed43b6 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -131,14 +131,10 @@ def _agg_exec(self) -> None: except Exception: logger.warining("Unable to close httpx client.") - # We don't have the API addr at init time def update_api_addr(self, addr): self.config.aggregate_url = addr logger.debug(f"Updated API Address: {self.config.aggregate_url}") - def update_confi(self, config): - pass - def start(self, *, reset: bool = True) -> None: self._active = bool(self.config.enabled) if not self._active: @@ -214,8 +210,9 @@ def _emit(self, ev: Dict[str, Any]) -> None: if len(self._events) < self._agg_max_events: return logger.debug(f"Queuing tracer frame batch of {len(self._events)}") batch = { "run_id": (self._req_id or "NONE"), - "node_id": (self.config.node_id or "UNKNOWN_NODE"), + "node_id": (self.config.node_id or "NODE"), "events": list(self._events)} + logger.debug(batch) try: self._agg_q.put_nowait(batch) except queue.Full: diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index 2736a25b..010e6922 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -11,11 +11,24 @@ Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @dataclass -class _OpenFrame: +class _ActiveSpan: + """Per-instance active span used for self-time accounting on a call stack.""" name: str t0: int child: int = 0 - children: List[Dict[str, Any]] = field(default_factory=list) + + +@dataclass +class _SymbolAgg: + """Aggregated statistics for a single trace symbol (name).""" + total_ms: float = 0.0 + count: int = 0 + durations: deque = field(default_factory=lambda: deque(maxlen=10000)) + + def add(self, self_ms: float) -> None: + self.total_ms += float(self_ms) + self.count += 1 + self.durations.append(float(self_ms)) # Sort known frames and compute averages by key @dataclass @@ -24,73 +37,74 @@ class RunAggregator: counts_by_name: Dict[str, int] = field(default_factory=dict) last_batch_seq: Dict[str, int] = field(default_factory=dict) - stacks: Dict[Key, List[_OpenFrame]] = field(default_factory=dict) + stacks: Dict[Key, List[_ActiveSpan]] = field(default_factory=dict) drops: int = 0 - roots_by_req: DefaultDict[str, List[Dict[str, Any]]] = field(default_factory=lambda: defaultdict(list)) + # Aggregated stats per symbol (primary source of truth) + symbols: Dict[str, _SymbolAgg] = field(default_factory=dict) + # Back-compat mirrors for existing readers (e.g., legacy REPL code) + durations_by_name: Dict[str, deque] = field(default_factory=dict) def _key(self, node_id: str, pid: Optional[int], tid: Optional[int], req_id: Optional[str]) -> Key: return (node_id, pid, tid, req_id or "") - def _push(self, key: Key, f: _OpenFrame) -> None: + def _push(self, key: Key, f: _ActiveSpan) -> None: self.stacks.setdefault(key, []).append(f) - def _pop(self, key: Key) -> Optional[_OpenFrame]: + def _pop(self, key: Key) -> Optional[_ActiveSpan]: st = self.stacks.get(key) if not st: return None return st.pop() - def _peek(self, key: Key) -> Optional[_OpenFrame]: + def _peek(self, key: Key) -> Optional[_ActiveSpan]: st = self.stacks.get(key) return st[-1] if st else None - def _acc_annotate(self, name: str, self_ms: float) -> None: - 
self.sums_by_name[name] = self.sums_by_name.get(name, 0.0) + self_ms - self.counts_by_name[name] = self.counts_by_name.get(name, 0) + 1 + def _accumulate(self, name: str, self_ms: float) -> None: + sym = self.symbols.get(name) + if sym is None: + sym = _SymbolAgg() + self.symbols[name] = sym + sym.add(self_ms) + + # FIXME: Remove + self.sums_by_name[name] = sym.total_ms + self.counts_by_name[name] = sym.count + dq = self.durations_by_name.get(name) + if dq is None: + dq = deque(maxlen=10000) + self.durations_by_name[name] = dq + dq.append(float(self_ms)) def ingest_event(self, node_id: str, ev: Dict[str, Any]) -> None: if not ev.get("name"): logger.error(f"Received trace frame without name {ev.get('ts')}") return - # Normalize ts to microseconds (accept float seconds or int microseconds) + # Normalize timestamp to microseconds ts_raw = ev.get("ts") - ts_us = 0 + ts = 0 try: if isinstance(ts_raw, float): - ts_us = int(ts_raw * 1_000_000) + ts = int(ts_raw * 1_000_000) elif isinstance(ts_raw, int): - ts_us = ts_raw + ts = ts_raw else: - ts_us = int(ts_raw or 0) + ts = int(ts_raw or 0) except Exception: - ts_us = 0 + ts = 0 req_id = ev.get("req_id") or "" key = self._key(node_id, ev.get("pid"), ev.get("tid"), req_id) if ev.get("type") == "B": - self._push(key, _OpenFrame(name=ev.get("name"), t0=ts_us)) + self._push(key, _ActiveSpan(name=ev.get("name"), t0=ts)) elif ev.get("type") == "E": fr = self._pop(key) if not fr: return - dur_us = max(0, ts_us - fr.t0) + dur_us = max(0, ts - fr.t0) self_us = max(0, dur_us - fr.child) self_ms = self_us / 1000.0 - self._acc_annotate(fr.name, self_ms) + self._accumulate(fr.name, self_ms) parent = self._peek(key) - completed = { - "name": fr.name, - "ts": fr.t0, - "dur_ms": dur_us / 1000.0, - "self_ms": self_ms, - "children": fr.children, - "pid": ev.get("pid"), - "tid": ev.get("tid"), - "req_id": req_id, - "node_id": node_id, - } if parent: parent.child += dur_us - parent.children.append(completed) - else: - self.roots_by_req[req_id or ""].append(completed) else: # TODO :Process other events pass @@ -126,28 +140,68 @@ def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, rep agg = self._req.get(run_id) if not agg: return [] + def _stats(xs: List[float]) -> Dict[str, float]: + if not xs: + return {"mean": 0.0, "p50": 0.0, "p90": 0.0, "p99": 0.0, "min": 0.0, "max": 0.0} + n = len(xs) + srt = sorted(xs) + def q(p: float) -> float: + if n == 1: + return srt[0] + k = int(round(p * (n - 1))) + k = max(0, min(k, n - 1)) + return srt[k] + return { + "mean": (sum(xs) / n), + "p50": q(0.5), + "p90": q(0.9), + "p99": q(0.99), + "min": srt[0], + "max": srt[-1], + } + + rows: List[Dict[str, Any]] = [] if not mapping: - rows = [ - {"name": k, "self_ms": v, "total_ms": v, "count": repeats or agg.counts_by_name.get(k, 0), "max_ms": None} - for k, v in agg.sums_by_name.items() - ] + for name, sym in agg.symbols.items(): + total = sym.total_ms + samples = list(sym.durations) + st = _stats(samples) + rows.append({ + "name": name, + "total": total, + "max": st["max"], + "mean": st["mean"], + "p50": st["p50"], + "p90": st["p90"], + "p99": st["p99"], + "samples": len(samples), + }) else: sums: Dict[str, float] = {} counts: Dict[str, int] = {} - for raw, val in agg.sums_by_name.items(): + dists: Dict[str, List[float]] = {} + for raw, sym in agg.symbols.items(): disp = mapping.get(raw, raw) - sums[disp] = sums.get(disp, 0.0) + val - counts[disp] = counts.get(disp, 0) + agg.counts_by_name.get(raw, 0) - rows = [ - {"name": k, "self_ms": v, "total_ms": v, 
"count": repeats or counts.get(k, 0), "max_ms": None} - for k, v in sums.items() - ] - rows.sort(key=lambda r: r["self_ms"], reverse=True) + sums[disp] = sums.get(disp, 0.0) + sym.total_ms + counts[disp] = counts.get(disp, 0) + sym.count + if sym.durations: + dists.setdefault(disp, []).extend(sym.durations) + for name, total in sums.items(): + samples = dists.get(name, []) + st = _stats(samples) + rows.append({ + "name": name, + "total": total, + "max": st["max"], + "mean": st["mean"], + "p50": st["p50"], + "p90": st["p90"], + "p99": st["p99"], + "samples": len(samples), + }) + rows.sort(key=lambda r: r["total"], reverse=True) return rows def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: - with self._lock: - agg = self._req.get(run_id) - if not agg: - return [] - return list(agg.roots_by_req.get(req_id or "", [])) + # Call-tree storage is disabled to reduce memory; keep API for compatibility. + return [] diff --git a/src/repl.py b/src/repl.py index 862d7136..c2f081fb 100644 --- a/src/repl.py +++ b/src/repl.py @@ -540,7 +540,7 @@ def do_trace(self, cmd): def __trace_cb(self, data): self._trace_agg.enqueue(data) - def __print_tr(self, symbol, ms, counts): + def __print_tr(self, row): sym = " " + symbol.ljust(40, ' ') pms = f"{ms:.10}".ljust(10, ' ') cns = f"{counts}".ljust(4, ' ') @@ -552,16 +552,34 @@ def print_trace_annotate( mapping: Optional[Dict[str, str]] = None, repeats: int = 0, ) -> List[Dict[str, Any]]: - names = " "*17 + "symbol" + " "*21 + "ms" + " "*4 + "counts" - dots = " " + "."*41 + " " + "."*10 + " " + "."*4 - dprint(f"{names}\n{dots}\n\n") - sums = self._trace_agg._req[run_id].sums_by_name - cnts = self._trace_agg._req[run_id].counts_by_name - for n, d in sums.items(): - self.__print_tr(n, d, cnts[n]) - - def get_trace_roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: - return self._trace_agg.roots(run_id, req_id) + + rows = self._trace_agg.annotate(run_id) + headers = ["name", "total","max","mean","p50","p90","p99","samples"] + limits = {"name": 50,} + w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers} + w["name"] = max(w["name"], 35) + + line = " ".join(h.ljust(w[h]) for h in headers); sys.stdout.write("\n") + sys.stdout.write(line + "\n") + sys.stdout.write(" ".join("."*w[h] for h in headers)); sys.stdout.write("\n") + for r in rows: + name = str(r["name"]) + if len(name) > w["name"]: name = name[:w["name"]-1] + "..." 
+ vals = { + "name": r["name"], + "total": r["total"], + "max": r["max"], + "mean": r["mean"], + "p50": r["p50"], + "p90": r["p90"], + "p99": r["p99"], + "samples": r["samples"], + } + sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]])) + sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:])) + sys.stdout.write("\n") + sys.stdout.write("\n\n") + sys.stdout.flush() def _print_nodes_table(self, rows: List[Any]) -> None: headers = ["name", "role", "addr", "http", "grpc", "status", "head"] From a62fc38f4db930c415fbed281adbb99772ea425f Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 22:35:29 -0700 Subject: [PATCH 107/172] append async symbols with 'wait' --- src/dnet/ring/shard/compute.py | 3 ++- src/dnet/ring/shard/node.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 41111154..794fc682 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -377,7 +377,8 @@ def _process_activation(self, activation_msg: ActivationMessage): except Exception: pass - with self.tracer.frame("compute.thread", "mdns.send"): + # Create and enqueue output message: either forward activations or finalize on end role + with self.tracer.frame("compute.thread", "grpc.send"): nxt = last_layer + 1 if nxt >= self.model_metadata.num_layers: # End of model try: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index a4786ec1..8f53da88 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -812,12 +812,13 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: while self.running: try: self.ingress_q.put_nowait(request) + logger.debug(f"[ENQUE] Enqueued activation request") return except asyncio.QueueFull: await asyncio.sleep(0) - # If we reached here, node is stopping; drop admission silently return + async def _ingress_worker(self): """Drains ingress queue and processes frames with heavy work offloaded. 
@@ -827,9 +828,10 @@ async def _ingress_worker(self):
         """
         while self.running:
             with self.tracer.frame("grpc", "ingress") as f:
-                with self.tracer.frame("grpc.ingress", "get"):
+                with self.tracer.frame("grpc.ingress", "get.wait"):
                     try:
                         req = await self.ingress_q.get()
+                        logger.debug(f"[DEQUE]Dequeued activation for processing {req}")
                     except asyncio.CancelledError:
                         break
@@ -1100,7 +1102,7 @@ def _compute_worker(self) -> None:
         while self.running:
             try:
                 # Get activation from queue (blocks until available)
-                with self.tracer.frame("compute", "dequeue"):
+                with self.tracer.frame("compute", "deque.wait"):
                     activation_msg = self.activation_recv_queue.get(timeout=1.0)

                 # Process the activation

From 60ada709fb68a1c6b0b2e2a4e43621485b2dbb7b Mon Sep 17 00:00:00 2001
From: Octavian
Date: Tue, 21 Oct 2025 22:36:12 -0700
Subject: [PATCH 108/172] min bench wrapper

---
 src/dnet/perf/bench.py | 145 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 145 insertions(+)
 create mode 100644 src/dnet/perf/bench.py

diff --git a/src/dnet/perf/bench.py b/src/dnet/perf/bench.py
new file mode 100644
index 00000000..ccbd040d
--- /dev/null
+++ b/src/dnet/perf/bench.py
@@ -0,0 +1,145 @@
+
+from __future__ import annotations
+
+import json
+import os
+import statistics
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, List, Optional
+
+from dnet.perf.trace import Tracer
+
+
+def _percentile(xs: List[float], q: float) -> float:
+    if not xs:
+        return 0.0
+    ys = sorted(xs)
+    k = int(round(q * (len(ys) - 1)))
+    k = max(0, min(k, len(ys) - 1))
+    return ys[k]
+
+
+def collect_stats(times_ms: List[float], *, bytes_total: float = 0.0, tokens_total: float = 0.0) -> Dict[str, Any]:
+    if not times_ms:
+        return {
+            "mean": 0.0,
+            "std": 0.0,
+            "min": 0.0,
+            "p50": 0.0,
+            "p90": 0.0,
+            "p99": 0.0,
+            "max": 0.0,
+            "samples": 0,
+            "mb_per_s": 0.0,
+            "tokens_per_s": 0.0,
+        }
+    total_ms = sum(times_ms)
+    mean = total_ms / len(times_ms)
+    std = statistics.pstdev(times_ms) if len(times_ms) > 1 else 0.0
+    total_s = max(total_ms / 1000.0, 1e-12)
+    return {
+        "mean": mean,
+        "std": std,
+        "min": min(times_ms),
+        "p50": _percentile(times_ms, 0.5),
+        "p90": _percentile(times_ms, 0.9),
+        "p99": _percentile(times_ms, 0.99),
+        "max": max(times_ms),
+        "samples": len(times_ms),
+        "mb_per_s": (bytes_total / 1_000_000.0) / total_s if bytes_total else 0.0,
+        "tokens_per_s": (tokens_total / total_s) if tokens_total else 0.0,
+    }
+
+
+def _ensure_dir(path: str) -> None:
+    d = os.path.dirname(path) or "."
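For a concrete sense of collect_stats, here is an illustrative call with invented numbers, assuming the new file above is importable as dnet.perf.bench:

    from dnet.perf.bench import collect_stats

    times_ms = [10.0, 12.0, 11.0]  # three timed spans, 33 ms total
    stats = collect_stats(times_ms, bytes_total=3_300_000, tokens_total=99)
    # total_s = 0.033, so:
    #   stats["mb_per_s"]     == 3.3 / 0.033 == 100.0
    #   stats["tokens_per_s"] == 99 / 0.033  == 3000.0
    #   stats["p50"] is the nearest-rank median, 11.0
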
+ os.makedirs(d, exist_ok=True) + + +@dataclass +class BenchCounters: + values: Dict[str, float] = field(default_factory=dict) + + def add_time(self, key: str, dt_ms: float) -> None: + self.values[key] = self.values.get(key, 0.0) + float(dt_ms) + + def add_bytes(self, *, direction: str, n: int) -> None: + k = "bytes_in" if direction == "in" else "bytes_out" + self.values[k] = self.values.get(k, 0.0) + float(n) + + def inc(self, key: str, delta: float = 1.0) -> None: + self.values[key] = self.values.get(key, 0.0) + float(delta) + + def snapshot(self, *, run_id: str, node: str, role: str = "shard") -> Dict[str, Any]: + snap = { + "run_id": run_id, + "node": node, + "role": role, + "counters": dict(self.values), + } + return snap + + +class TimedSpan: + __slots__ = ("_tracer", "_name", "_attrs", "_t0", "_frame", "_counters", "_counter_key") + + def __init__( + self, + tracer: Optional[Tracer], + name: str, + counters: Optional[BenchCounters] = None, + counter_key: Optional[str] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> None: + self._tracer = tracer + self._name = name + self._attrs = attrs or {} + self._t0 = 0.0 + self._frame = None + self._counters = counters + self._counter_key = counter_key + + def __enter__(self): + self._t0 = time.perf_counter() + if self._tracer is not None: + self._frame = self._tracer.frame("bench", self._name, self._attrs) + self._frame.__enter__() + return self + + def __exit__(self, ex_type, ex, tb) -> bool: + dt_ms = (time.perf_counter() - self._t0) * 1000.0 + if self._frame is not None: + try: + self._frame.__exit__(ex_type, ex, tb) + except Exception: + pass + if self._counters is not None and self._counter_key: + self._counters.add_time(self._counter_key, dt_ms) + return False + + +def aggregate_annotate( + snapshots: Iterable[Dict[str, Any]], + *, + mapping: Optional[Dict[str, str]] = None, + repeats: int = 0, +) -> List[Dict[str, Any]]: + + sums: Dict[str, float] = {} + for snap in snapshots: + ctr = snap.get("counters") if isinstance(snap, dict) else None + if not isinstance(ctr, dict): + continue + for k, v in ctr.items(): + name = mapping.get(k, k) if mapping else k + try: + sums[name] = sums.get(name, 0.0) + float(v) + except Exception: + continue + + rows = [ {"name": name, "self_ms": val, "total_ms": val, "count": repeats or 0, "max_ms": None} + for name, val in sums.items() if val > 0.0] + rows.sort(key=lambda r: r["self_ms"], reverse=True) + return rows + From 0b6e25f21b5845ed794ee54bce485ff75002894d Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 21 Oct 2025 22:47:57 -0700 Subject: [PATCH 109/172] various --- src/dnet/perf/bench.py | 1 - src/dnet/perf/trace.py | 1 - src/dnet/ring/shard/startup.py | 1 - src/dnet/utils/logger.py | 2 +- 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/dnet/perf/bench.py b/src/dnet/perf/bench.py index ccbd040d..0cdbadd2 100644 --- a/src/dnet/perf/bench.py +++ b/src/dnet/perf/bench.py @@ -19,7 +19,6 @@ def _percentile(xs: List[float], q: float) -> float: k = max(0, min(k, len(ys) - 1)) return ys[k] - def collect_stats(times_ms: List[float], *, bytes_total: float = 0.0, tokens_total: float = 0.0) -> Dict[str, Any]: if not times_ms: return { diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index d2ed43b6..7b5d79c1 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -212,7 +212,6 @@ def _emit(self, ev: Dict[str, Any]) -> None: batch = { "run_id": (self._req_id or "NONE"), "node_id": (self.config.node_id or "NODE"), "events": list(self._events)} - 
logger.debug(batch) try: self._agg_q.put_nowait(batch) except queue.Full: diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py index f3b38fbb..20ba93ec 100644 --- a/src/dnet/ring/shard/startup.py +++ b/src/dnet/ring/shard/startup.py @@ -286,7 +286,6 @@ async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: logger.info("Updated tracer config.") self.api_address = cfg.aggregate_url self.tracer.start_aggregator() - logger.debug(cfg) return TraceConfigResponse(ok=True) except Exception as e: logger.error(f"Unable to setup tracing on shard: {e}") diff --git a/src/dnet/utils/logger.py b/src/dnet/utils/logger.py index 19f09c46..bac462f9 100644 --- a/src/dnet/utils/logger.py +++ b/src/dnet/utils/logger.py @@ -14,7 +14,7 @@ def get_logger() -> logging.Logger: Returns: Configured logger instance """ - logLevelEnv = os.getenv("DNET_LOG", "INFO").strip().upper() + logLevelEnv = os.getenv("DNET_LOG", "DEBUG").strip().upper() logLevel = logging.INFO # default if logLevelEnv in logging._nameToLevel: logLevel = logging._nameToLevel[logLevelEnv] From aff55614ff07275180447ca72e089dc39a665a90 Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 01:13:59 -0700 Subject: [PATCH 110/172] fix indent and other rebase errors --- src/dnet/ring/shard/comms.py | 87 +++++++---- src/dnet/ring/shard/node.py | 294 +++++++++++++++++++++-------------- 2 files changed, 237 insertions(+), 144 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index dc197637..db5817e5 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -293,8 +293,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): int(getattr(activation_msg, "token_id", -1)), rpc_ms, ) - except Exception as e: - logger.exception("Error sending token via gRPC: %s", e) + except Exception as e: + logger.exception("Error sending token via gRPC: %s", e) else: logger.error( "No valid gRPC callback for token; cannot deliver nonce=%s", @@ -322,7 +322,6 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) t_ser = time.perf_counter() t_cast = t_ser - _len_bytes = int(getattr(shaped, "nbytes", 0)) _do_compress = bool( self._compress and _len_bytes >= self._compress_min_bytes @@ -360,7 +359,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("mxarray.cast") data = tensor_to_bytes(shaped) - activation_msg.dtype = self._wire_dtype_str + activation_msg.dtype = self._wire_dtype_str nxt = activation_msg.layer_id + 1 if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): @@ -426,6 +425,38 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) ctx.disabled = True + # Prefer streaming if enabled/available; fallback to unary + stream_used = False + ctx = await self._ensure_stream(activation_msg.nonce) + if ( + ctx + and ctx.open + and not ctx.disabled + and hasattr(dnet_ring_pb2, "ActivationFrame") + ): + try: + ctx.last_seq += 1 + frame = dnet_ring_pb2.ActivationFrame( + request=request, + seq=ctx.last_seq, + end_of_request=False, + ) + await ctx.queue.put(frame) + ctx.last_activity_t = asyncio.get_running_loop().time() + stream_used = True + if self._profile: + logger.info( + "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", + activation_msg.nonce, + ctx.last_seq, + ctx.queue.qsize(), + ) + except Exception as e: + logger.warning( + "[STREAM] enqueue failed; fallback to unary: %s", e + ) + ctx.disabled = True + if not stream_used: # In fit mode, avoid long unary stalls: use 
short deadline and min retries # Streaming should be the norm; unary is a quick safety valve only. @@ -443,17 +474,17 @@ async def _send_activation(self, activation_msg: ActivationMessage): reason = "stream_closed" elif ctx.disabled: reason = "stream_disabled" - else: - reason = "enqueue_failed" - logger.warning( - "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", - self.node_id, - activation_msg.nonce, - reason, - self._mode, - ring_timeout, - ring_retries, - ) + else: + reason = "enqueue_failed" + logger.warning( + "[STREAM->UNARY] node=%s nonce=%s reason=%s mode=%s timeout_s=%.1f retries=%d", + self.node_id, + activation_msg.nonce, + reason, + self._mode, + ring_timeout, + ring_retries, + ) t0 = time.perf_counter() last_exc: Optional[Exception] = None for attempt in range(1, ring_retries + 1): @@ -481,19 +512,19 @@ async def _send_activation(self, activation_msg: ActivationMessage): await asyncio.sleep(min(0.25 * attempt, 1.0)) continue raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) + else: + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", + self.node_id, + activation_msg.nonce, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, + ) else: logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 8f53da88..264d277b 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -444,8 +444,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse with self.tracer.frame("network", "connect.next_node"): if self.next_node: await self._connect_next_node() - else: - logger.warning("Node %s: No next node configured", self.node_id) + else: + logger.warning("Node %s: No next node configured", self.node_id) # Warmup: compile hot path and stabilize allocators before first request with self.tracer.frame("memory", "warmup"): @@ -697,59 +697,22 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): if activation.dtype == "tokens": with self.tracer.frame("grpc.receive", "token_stream") as fr: try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype, ) - shp = (int(len(tokens)),) except Exception as e: logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) - return - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, + "Decompression failed for nonce %s: %s", request.nonce, e ) return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - if self._profile: - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s 
alloc_copy_ms=%.3f (tokens)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - # Ensure dtype reflects token payload for compute path - activation_msg.dtype = "tokens" - activation_msg.shape = shp - else: - with self.tracer.frame("grpc.receive", "default") as fr: - # Safety: byte length must match shape*dtype - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize - ) - actual = len(request.activation.data) - except Exception: - pass + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape)), ) if pool_id is None: logger.warning( @@ -759,49 +722,135 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: - data = request.activation.data - input_data = np.frombuffer( - data, dtype=dtype_map[activation.dtype] - ) - buffer[: len(input_data)] = input_data + flat = deq.reshape(-1) + buffer[: flat.size] = flat alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", self.node_id, request.nonce, alloc_copy_ms, ) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - - # Queue for processing — non-blocking back-off loop (cancellable) - if self._profile: - activation_msg.enq_perf_t = time.perf_counter() - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", + + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + else: + # Special token stream support: dtype='tokens' carries int32 token IDs + if activation.dtype == "tokens": + with self.tracer.frame("grpc.receive", "token_stream") as fr: + try: + tokens = np.frombuffer( + request.activation.data, dtype=np.int32 + ) + shp = (int(len(tokens)),) + except Exception as e: + logger.error( + "Failed to parse tokens for nonce %s: %s", + request.nonce, + e, + ) + return + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp), + ) + if pool_id is None: + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) + return + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + if self._profile: + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (tokens)", + self.node_id, + request.nonce, + alloc_copy_ms, + ) + activation_msg = ActivationMessage.from_proto(request, pool_id) + # Ensure dtype reflects token payload for compute path + activation_msg.dtype = "tokens" + activation_msg.shape = shp + else: + with self.tracer.frame("grpc.receive", "default") as fr: + # Safety: byte length must match shape*dtype + try: + expected = ( + int(np.prod(activation.shape)) + * 
np.dtype(dtype_map[activation.dtype]).itemsize
+                        )
+                        actual = len(request.activation.data)
+                    except Exception:
+                        pass
+
+                    pool_id = self.input_pool.allocate_for_layer(
+                        layer_id=activation.layer_id,
+                        dtype=mlx_dtype_map[activation.dtype],
+                        shape=cast(tuple[int, ...], activation.shape),
+                    )
+                    if pool_id is None:
+                        logger.warning(
+                            "Failed to allocate input pool buffer for nonce %s",
+                            request.nonce,
+                        )
+                        return
+                    buffer = self.input_pool.get_buffer(pool_id)
+                    if buffer is not None:
+                        data = request.activation.data
+                        input_data = np.frombuffer(
+                            data, dtype=dtype_map[activation.dtype]
+                        )
+                        buffer[: len(input_data)] = input_data
+                        alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0
+                        logger.info(
+                            "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f",
+                            self.node_id,
+                            request.nonce,
+                            alloc_copy_ms,
+                        )
+                    activation_msg = ActivationMessage.from_proto(request, pool_id)
+                    activation_msg.dtype = activation.dtype
+                    activation_msg.shape = tuple(activation.shape)
+
+            # Queue for processing — non-blocking back-off loop (cancellable)
+            if self._profile:
+                activation_msg.enq_perf_t = time.perf_counter()
+            while self.running:
+                try:
+                    self.activation_recv_queue.put_nowait(activation_msg)
+                    logger.debug(
+                        "Queued activation for processing: nonce %s",
+                        activation_msg.nonce,
+                    )
+                    break
+                except Full:
+                    await asyncio.sleep(0)
+            else:
+                logger.error(
+                    "Failed to queue activation %s (node stopping)",
+                    activation_msg.nonce,
+                )
+                self.input_pool.release(pool_id)
+        else:
+            # Forward to next node (not our layer)
+            logger.debug(
+                "Forwarding activation (layer %s) to next node, nonce: %s",
+                target_layer,
+                request.nonce,
+            )
+            await self._forward_activation(request)
+
+        except Exception as e:
+            logger.exception("Error receiving activation: %s", e)
+

     async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None:
@@ -851,10 +900,10 @@ async def _ingress_worker(self):
                         logger.error(f"Unable to read length of data for {req.nonce}")
                         payload_bytes = -1

-                f.set("nonce", req.nonce)
-                f.set("target", target_layer)
-                f.set("payload_bytes", payload_bytes)
-                f.event("received")
+                    f.set("nonce", req.nonce)
+                    f.set("target", target_layer)
+                    f.set("payload_bytes", payload_bytes)
+                    f.event("received")

                 if target_layer in self._assigned_set:
                     # Heavy prep in executor (alloc/copy/decompress)
@@ -878,41 +927,54 @@ async def _ingress_worker(self):
                 with self.tracer.frame("grpc.ingress", "queue") as fr:
                     while self.running:
                         try:
-                            self.activation_recv_queue.put_nowait(activation_msg)
-                            logger.debug(
-                                "Queued activation for processing: nonce %s",
+                            activation_msg = await loop.run_in_executor(
+                                self.executor,
+                                self._prepare_activation_message_blocking,
+                                req,
+                            )
+                        except Exception as e:
+                            logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e)
+                            continue
+                        if activation_msg is None:
+                            continue
+                        if self._profile:
+                            activation_msg.recv_perf_t = t_recv
+
+                        # Enqueue for compute (cancellable back-off)
+                        with self.tracer.frame("grpc.ingress", "queue") as fr:
+                            while self.running:
+                                try:
+                                    self.activation_recv_queue.put_nowait(activation_msg)
+                                    logger.debug(
+                                        "Queued activation for processing: nonce %s",
+                                        activation_msg.nonce,
+                                    )
+                                    break
+                                except Full:
+                                    await asyncio.sleep(0)
+                            else:
+                                logger.error(
+                                    "Failed to queue activation %s (node stopping)",
+                                    activation_msg.nonce,
+                                )
+                                try:
+                                    if self.input_pool:
+                                        # FIXME: !!!
+                                        self.input_pool.release(activation_msg.pool_id)
+                                except Exception:
+                                    pass
                 else:
-                    # Forward to next node (not our layer)
-                    logger.debug(
-                        "Forwarding activation (layer %s) to next node, nonce: %s",
-                        target_layer,
-                        req.nonce,
-                    )
-                    await self._forward_activation(req)
+                    # Forward to next node (not our layer)
+                    logger.debug(
+                        "Forwarding activation (layer %s) to next node, nonce: %s",
+                        target_layer,
+                        req.nonce,
+                    )
+                    await self._forward_activation(req)
+
+        except Exception as e:
+            logger.error("Ingress worker error: %s", e)
+

     def _get_or_make_kv(self, nonce: str) -> list:

From 1a7ea70848d94fcb76682188c0d4ce508c39f6ab Mon Sep 17 00:00:00 2001
From: Octavian
Date: Wed, 22 Oct 2025 01:36:29 -0700
Subject: [PATCH 111/172] add accidentally removed code

---
 src/dnet/ring/shard/node.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index 264d277b..223a1943 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -495,14 +495,6 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse
                         k,
                     )

-        load_time_ms = (time.perf_counter() - start_time) * 1000.0
-        logger.info(
-            "Node %s: Successfully loaded model %s with layers %s in %.2fms",
-            self.node_id,
-            req.model_path,
-            req.layers,
-            load_time_ms,
-        )

         return ShardLoadModelResponse(
             success=True,

From 02747d47518b276847cdf577a9cac9189efe61ad Mon Sep 17 00:00:00 2001
From: Octavian
Date: Wed, 22 Oct 2025 02:48:51 -0700
Subject: [PATCH 112/172] fix indent for _send_activation last token

---
 src/dnet/ring/shard/comms.py | 26 ++++++++------------------
 src/dnet/ring/shard/node.py  |  7 ++-----
 2 files changed, 10 insertions(+), 23 deletions(-)

diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py
index db5817e5..bcc826c3 100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -226,6 +226,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
             )
             return
         try:
+            logger.debug("Sending activation")
             if activation_msg.is_final:
                 with self.tracer.frame("grpc", "send_activation.final") as f:
                     try:
@@ -339,10 +340,11 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                     if shaped.dtype != wire_np_dtype:
                         # FIXME: numpy vs mx array here
                         shaped = shaped.astype(wire_np_dtype, copy=False)
-                else:
-                    # MLX array -> cast to desired wire dtype
+
+                else:  # MLX array -> cast to desired wire dtype
                     if str(shaped.dtype) != self._wire_dtype_str:
                         shaped = shaped.astype(self._wire_mx_dtype)
+
                 activation_msg.dtype = self._wire_dtype_str

             t_cast = time.perf_counter()
@@ -364,6 +366,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
             nxt = activation_msg.layer_id + 1
             if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set):
                 if self.next_node_stub:
+
                     with self.tracer.frame("grpc", "send_activation.next") as f:
                         request = activation_msg.to_proto(data)
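The while/else admission loops above all follow one pattern: try a non-blocking put, yield to the event loop on Full, and treat falling out of the loop as "node is stopping". A minimal sketch of that pattern (simplified; the real code also releases pool buffers on failure, and the helper name is illustrative):

    import asyncio
    from queue import Full, Queue

    async def admit(q: Queue, item, running) -> bool:
        # running is a callable returning the node's liveness flag
        while running():
            try:
                q.put_nowait(item)
                return True              # enqueued successfully
            except Full:
                await asyncio.sleep(0)   # yield to the event loop, then retry
        return False                     # node is stopping; caller cleans up
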
request.timestamp = utc_epoch_now() @@ -428,12 +431,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): # Prefer streaming if enabled/available; fallback to unary stream_used = False ctx = await self._ensure_stream(activation_msg.nonce) - if ( - ctx - and ctx.open - and not ctx.disabled - and hasattr(dnet_ring_pb2, "ActivationFrame") - ): + if (ctx and ctx.open and not ctx.disabled and hasattr(dnet_ring_pb2, "ActivationFrame")): + logger.debug(f"Sending activation with stream") try: ctx.last_seq += 1 frame = dnet_ring_pb2.ActivationFrame( @@ -444,17 +443,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): await ctx.queue.put(frame) ctx.last_activity_t = asyncio.get_running_loop().time() stream_used = True - if self._profile: - logger.info( - "[PROFILE][STREAM-ENQ] nonce=%s seq=%s q=%s", - activation_msg.nonce, - ctx.last_seq, - ctx.queue.qsize(), - ) except Exception as e: - logger.warning( - "[STREAM] enqueue failed; fallback to unary: %s", e - ) + logger.warning("[STREAM] enqueue failed; fallback to unary: %s", e) ctx.disabled = True if not stream_used: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 223a1943..12a3bf46 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -846,10 +846,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: - """ - Lightweight admission for streaming: - enqueue protobuf frame to ingress queue, then return. - """ + """enqueue protobuf frame to ingress queue""" while self.running: try: self.ingress_q.put_nowait(request) @@ -872,7 +869,7 @@ async def _ingress_worker(self): with self.tracer.frame("grpc.ingress", "get.wait"): try: req = await self.ingress_q.get() - logger.debug(f"[DEQUE]Dequeued activation for processing {req}") + logger.debug(f"[DEQUE]Dequeued activation for processing") except asyncio.CancelledError: break From 928cbf9b189b9cb55900453cf716253c77aac7f5 Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 03:34:37 -0700 Subject: [PATCH 113/172] remove old startup file and add tracer endpoints in node.py:_setup_routes --- src/dnet/ring/shard/node.py | 31 +- src/dnet/ring/shard/startup.py | 574 --------------------------------- 2 files changed, 30 insertions(+), 575 deletions(-) delete mode 100644 src/dnet/ring/shard/startup.py diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 12a3bf46..e10c4d65 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -33,6 +33,8 @@ ShardProfileRequest, ShardProfileResponse, ShardUnloadModelResponse, + TraceConfigRequest, + TraceConfigResponse, ) from ..model.base import BaseRingModel @@ -211,7 +213,7 @@ def __init__( enabled = True, record_pid_tid = True, aggregate=False, - aggregate_url=None, # FIXME: This is set when we get a /profile req + aggregate_url=None, ) self.tracer = Tracer(cfg) self.tracer.start() @@ -1529,6 +1531,33 @@ async def measure_latency( logger.error(f"Error in /measure_latency endpoint: {e}") raise + @self.app.post("/trace") + async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: + logger.debug("Updating trace config") + try: + cfg = TraceConfig( + file=req.file, + streaming=req.streaming, + include_prefixes=req.include_prefixes, + include_c_calls=req.include_c_calls, + budget=req.budget, + enabled=req.enabled, + node_id=req.node_id, + record_pid_tid=req.record_pid_tid, + aggregate=req.aggregate, + 
aggregate_url=req.aggregate_url, + agg_max_events=req.agg_max_events, + ) + self.tracer.config = cfg + logger.info("Updated tracer config.") + self.api_address = cfg.aggregate_url + self.tracer.start_aggregator() + return TraceConfigResponse(ok=True) + except Exception as e: + logger.error(f"Unable to setup tracing on shard: {e}") + return TraceConfigResponse(ok=False) + + @self.app.post("/load_model") async def load_model_endpoint( req: ShardLoadModelRequest, diff --git a/src/dnet/ring/shard/startup.py b/src/dnet/ring/shard/startup.py deleted file mode 100644 index 20ba93ec..00000000 --- a/src/dnet/ring/shard/startup.py +++ /dev/null @@ -1,574 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from typing import Any, Dict, List, Mapping -import threading -from socket import gethostname -from secrets import token_hex - -import mlx.core as mx -from fastapi import Request -from fastapi.responses import JSONResponse -from grpc import aio as aio_grpc - -from hypercorn import Config -import hypercorn.asyncio as aio_hypercorn -from dnet_p2p.thunderbolt import ThunderboltConnection -from dnet_p2p import ( - DnetDeviceProperties, - discover_thunderbolt_connection, -) - -from dnet.perf.trace import TraceConfig - -from ...protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server -from .servicer import ShardServicer -from ...utils.logger import logger -from ...utils.serialization import tensor_to_bytes -from ...utils.latency import ( - DeviceLatencyResult, - LatencyMeasurement, - LatencyResults, - calculate_median_latency_seconds, -) -from .models import ( - HealthResponse, - ShardLoadModelRequest, - ShardLoadModelResponse, - ShardProfileRequest, - ShardProfileResponse, - ShardUnloadModelResponse, - TraceConfigRequest, - TraceConfigResponse, -) -from ...protos import dnet_ring_pb2 - - -class StartupMixin: - async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()): - self.running = True - - with self.tracer.frame("startup", "workers"): - self.background_tasks = [ - asyncio.create_task(self._ingress_worker()), - asyncio.create_task(self._prefetch_worker()), - asyncio.create_task(self._send_worker()) ] - - try: # Start idle sweeper to close silent streams - if getattr(self, "_streaming_enabled", False) and hasattr(self, "_stream_sweeper"): - self.background_tasks.append( asyncio.create_task(self._stream_sweeper())) - except Exception: - pass - - with self.tracer.frame("startup", "compute"): - self.compute_thread = threading.Thread(target=self._compute_worker, daemon=True) - self.compute_thread.start() - - with self.tracer.frame("startup", "discovery"): - self._start_discovery() - - logger.info( - "Shard node %s started on gRPC port %s HTTP port %s", - self.node_id, - self.grpc_port, - self.http_port, - ) - - def _start_discovery(self) -> None: - """Start mDNS discovery service.""" - hostname = gethostname() - # TODO: optionally take shard name from CLI - instance = f"shard-{token_hex(4)}-{hostname}" - self.discovery.create_instance( - instance, - hostname, - "0.0.0.0", # Binds to all addresses - self.http_port, # HTTP port - self.grpc_port, # gRPC port - is_manager=False, # Shard is never a manager - ) - self.discovery.start() - logger.info( - "Discovery service started for shard node %s with name %s", - self.node_id, - self.discovery.fullname(), - ) - - async def _start_grpc_server(self) -> None: - """Start gRPC server.""" - self.server = aio_grpc.server() - - # Add the ring servicer; shard acts as client for ShardApiService (to API) - servicer = 
ShardServicer(self) # type: ignore # FIXME: !!! - add_DnetRingServiceServicer_to_server(servicer, self.server) - - listen_addr = f"[::]:{self.grpc_port}" - self.server.add_insecure_port(listen_addr) - await self.server.start() - logger.info( - "Shard node %s gRPC server started on %s", self.node_id, listen_addr - ) - try: - await asyncio.get_running_loop().run_in_executor( - self.executor, self._warmup_serialization - ) - logger.info("Warmup serialization completed") - except Exception as e: - logger.warning("Warmup serialization failed: %s", e) - - def _warmup_serialization(self): - try: - dummy = mx.random.normal((1024, 1024), dtype=mx.float32) - dummy16 = dummy.astype(self._wire_mx_dtype) - _ = tensor_to_bytes(dummy16) - except Exception: - pass - - def _warmup_shard(self): - logger.info( - "[WARMUP] Starting shard warmup with window size %s", self.window_size - ) - batch_size, seq_len = 1, 1 - hidden_size = self.model_metadata.model_config.get("hidden_size", 2560) - x = mx.zeros((batch_size, seq_len, hidden_size), dtype=mx.bfloat16) - start_time = time.perf_counter() - try: - default_n = max(1, int(getattr(self, "_resident_windows", 1))) - except Exception: - default_n = 1 - try: - max_windows = max( - 1, - int( - getattr(self, "config", None).warmup_windows - if getattr(self, "config", None) - else default_n - ), - ) - except Exception: - max_windows = default_n - windows: list[list[int]] = [] - for window_start in range(0, len(self._assigned_sorted), self.window_size): - window_end = min( - window_start + self.window_size, len(self._assigned_sorted) - ) - windows.append(self._assigned_sorted[window_start:window_end]) - for wi, window_layers in enumerate(windows[:max_windows]): - weights_to_bind = {} - for layer_id in window_layers: - weights = self.weight_cache.get_weight(layer_id) - if weights: - for k, v in weights.items(): - weights_to_bind[k] = v - if weights_to_bind: - self.model.load_weights(list(weights_to_bind.items()), strict=False) - try: - for layer_id in window_layers: - x = self.model.apply_single_layer(layer_id, x, cache=None) - _s = mx.sum(x) - mx.eval(_s) - except Exception: - pass - try: - for lid in window_layers: - self.weight_cache.decrease_reference(lid) - except Exception: - pass - if not self._warmup_keep_flag: - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] - except Exception: - pass - try: - self.weight_cache.evict_layers(window_layers) - except Exception: - pass - total_time = (time.perf_counter() - start_time) * 1000 - self._warmup_completed = True - logger.info( - "[WARMUP] Shard warmup completed in %.2fms; windows=%s kept=%s", - total_time, - min(len(windows), max_windows), - int(self._warmup_keep_flag), - ) - - async def _start_http_server(self, shutdown_trigger: Any) -> None: - """Start HTTP server. 
- - Args: - shutdown_trigger: Shutdown trigger function - """ - await self._setup_routes() - - # Start HTTP server in background - config = Config.from_mapping( - bind=f"0.0.0.0:{self.http_port}", - log_level="info", - log_config=None, - use_reloader=False, - h2c=False, - ) - - # Start the server as a background task - self.http_server = asyncio.create_task( - aio_hypercorn.serve(self.app, config, shutdown_trigger=shutdown_trigger) # type: ignore - ) - logger.info( - "Shard node %s HTTP server started on port %s", self.node_id, self.http_port - ) - - async def _setup_routes(self) -> None: - """Setup HTTP routes.""" - - @self.app.get("/health") - async def health() -> HealthResponse: - try: - instance = self.discovery.instance_name() - except Exception: - instance = None - return HealthResponse( - status="ok", - node_id=self.node_id, - running=self.running, - model_loaded=self._check_model_loaded(), - model_path=self.model_path, - assigned_layers=self.assigned_layers, - queue_size=self.activation_recv_queue.qsize(), - grpc_port=self.grpc_port, - http_port=self.http_port, - instance=instance, - ) - - @self.app.post("/profile") - async def profile(req: ShardProfileRequest) -> ShardProfileResponse: - try: - latency_results = await self._measure_latency_to_devices( req.devices, req.thunderbolts, req.payload_sizes) - device_profile = await self._profile_device( req.repo_id, req.max_batch_exp) - - # Overwrite `t_comm` with median latency (subprocess returns a dict) - median_latency = calculate_median_latency_seconds(latency_results) - if median_latency is not None: - device_profile["t_comm"] = float(median_latency) - logger.info( - f"Set t_comm to median latency: {device_profile['t_comm']:.6f}s" - ) - else: - logger.warning( "No valid latency measurements, keeping default t_comm") - - # Return the dict payload directly - return ShardProfileResponse( - profile=device_profile, - latency=latency_results, - ) - except Exception as e: - logger.error(f"Error in /profile endpoint: {e}") - raise - - @self.app.post("/trace") - async def setup_trace(req: TraceConfigRequest) -> TraceConfigResponse: - try: - cfg = TraceConfig( - file=req.file, - streaming=req.streaming, - include_prefixes=req.include_prefixes, - include_c_calls=req.include_c_calls, - budget=req.budget, - enabled=req.enabled, - node_id=req.node_id, - record_pid_tid=req.record_pid_tid, - aggregate=req.aggregate, - aggregate_url=req.aggregate_url, - agg_max_events=req.agg_max_events, - ) - self.tracer.config = cfg - logger.info("Updated tracer config.") - self.api_address = cfg.aggregate_url - self.tracer.start_aggregator() - return TraceConfigResponse(ok=True) - except Exception as e: - logger.error(f"Unable to setup tracing on shard: {e}") - return TraceConfigResponse(ok=False) - - @self.app.post("/load_model") - async def load_model_endpoint( - req: ShardLoadModelRequest, - ) -> ShardLoadModelResponse: - """Load model with specified layers.""" - try: - logger.info( - f"HTTP /load_model: model={req.model_path}, layers={req.layers}, " - f"next_node={req.next_node or 'none'}, window_size={req.window_size}, " - f"total_layers={req.total_layers}, api_callback={req.api_callback_address or 'none'}" - ) - result = await self.load_model(req) - return result - - except Exception as e: - logger.error(f"Error in /load_model endpoint: {e}") - return ShardLoadModelResponse( - success=False, - message=f"Error: {str(e)}", - layers_loaded=[], - load_time_ms=0.0, - ) - - @self.app.post("/unload_model") - async def unload_model_endpoint() -> 
ShardUnloadModelResponse: - """Unload current model.""" - try: - logger.info("HTTP /unload_model") - result = await self.unload_model() - return result - - except Exception as e: - logger.error(f"Error in /unload_model endpoint: {e}") - return ShardUnloadModelResponse( - success=False, - message=f"Error: {str(e)}", - ) - - @self.app.post("/warm") - async def warm(request: Request) -> JSONResponse: - try: - body = await request.json() - start = int(body.get("start", -1)) - window = int(body.get("window", self.window_size)) - if start < 0: - return JSONResponse( - status_code=400, content={"error": "missing/invalid start"} - ) - start_idx = 0 - for i, lyr in enumerate(self._assigned_sorted): - if lyr >= start: - start_idx = i - break - else: - return JSONResponse(content={"prefetched": []}) - window_layers = self._assigned_sorted[ - start_idx : start_idx + max(1, window) - ] - for wl in window_layers: - self._prefetch_to_ram(wl) - self._enqueue_weight_prefetch(wl) - return JSONResponse(content={"prefetched": window_layers}) - except Exception as e: - logger.error("/warm failed: %s", e) - return JSONResponse(status_code=500, content={"error": str(e)}) - - async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: - """Profile device using dperf in a subprocess and return a dict. - - Args: - repo_id: Hugging Face repository ID - max_batch_exp: Maximum batch size exponent (2^max_batch_exp) - - Returns: - Device profile information as a plain dict - """ - from ...utils.profile_subproc import profile_device_via_subprocess - - profile_dict = profile_device_via_subprocess( - repo_id, max_batch_exp=max_batch_exp, debug=0 - ) - logger.info("Device profiling completed for node %s", self.node_id) - return profile_dict - - async def _connect_next_node(self) -> bool: - """Connect to next node in ring. 
- - Returns: - True if connected or no next node, False on failure - """ - if not self.next_node: - logger.info(f"Shard node {self.node_id} is the final shard (no next node)") - return True - - if self.next_node_channel: - logger.debug(f"Shard node {self.node_id} already connected to next node.") - return True - - try: - # use thunderbolt here if available - this_properties = self.discovery.get_own_properties() - thunderbolt_conn = discover_thunderbolt_connection( - this_properties, - self.next_node, - ) - next_ip = ( - thunderbolt_conn.ip_addr - if thunderbolt_conn - else self.next_node.local_ip - ) - address = f"{next_ip}:{self.next_node.shard_port}" - logger.info( - f"Shard node {this_properties.instance} connecting to next node {self.next_node.instance} at {address}" - ) - - self.next_node_channel = aio_grpc.insecure_channel(address) - from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub - - self.next_node_stub = DnetRingServiceStub(self.next_node_channel) - return True - except Exception as e: - logger.warning( - f"Shard node {self.node_id} failed to connect to next node {address}: {e}" - ) - self.next_node_channel = None - self.next_node_stub = None - return False - - async def _reconnect_next_node(self) -> bool: - try: - if self.next_node_channel: - await self.next_node_channel.close() - except Exception: - pass - self.next_node_channel = None - self.next_node_stub = None - return await self._connect_next_node() - - async def _health_check(self): - try: - health_request = dnet_ring_pb2.HealthRequest(requester_id=str(self.node_id)) - response = await self.next_node_stub.HealthCheck(health_request) # type: ignore - logger.info( - "Shard node %s successfully pinged: %s, healthy: %s", - self.node_id, - response.node_id, - response.healthy, - ) - return True - except Exception as e: - logger.warning( - "Shard node %s failed to ping next node %s: %s", - self.node_id, - self.next_node_address, - e, - ) - return False - - async def _measure_latency_to_devices( - self, - devices: Mapping[str, DnetDeviceProperties], - thunderbolts: Mapping[str, ThunderboltConnection], - payload_sizes: List[int], - ) -> LatencyResults: - """Measure latency to all devices except self. 
- - Args: - devices: Device information mapping - thunderbolts: Thunderbolt connection information - payload_sizes: List of payload sizes to test - - Returns: - Latency measurement results - """ - latency_results_dict: Dict[str, DeviceLatencyResult] = {} - - for service_name, device_info in devices.items(): - # Skip measuring latency to ourselves - if service_name.startswith(self.discovery.instance_name()): - logger.debug("Skipping latency measurement to self: %s", service_name) - continue - - # Skip measuring latency to API (manager) devices - if device_info.is_manager: - logger.debug( - "Skipping latency measurement to manager/API: %s", service_name - ) - continue - - try: - shard_port = device_info.shard_port - - # Check for Thunderbolt connection - if service_name in thunderbolts: - tb_data = thunderbolts[service_name] - service_ip = tb_data.ip_addr - logger.info( - "Using Thunderbolt for %s at %s, connected to instance %s", - service_name, - service_ip, - tb_data.instance, - ) - else: - # No Thunderbolt, use WiFi - service_ip = device_info.local_ip - - if not shard_port or not service_ip: - logger.warning( - "No shard_port or local_ip for device %s", service_name - ) - continue - - # Connect to target shard's gRPC server - target_address = f"{service_ip}:{shard_port}" - channel = aio_grpc.insecure_channel(target_address) - from ...protos.dnet_ring_pb2_grpc import DnetRingServiceStub - - stub = DnetRingServiceStub(channel) - - # Measure latency for each payload size - latency_measurements: List[LatencyMeasurement] = [] - for payload_size in payload_sizes: - # Create dummy payload - dummy_data = b"x" * payload_size - - start_time = time.perf_counter() - timestamp_ms = int(time.time() * 1000) - - request = dnet_ring_pb2.LatencyMeasureRequest( - requester_id=str(self.node_id), - payload_size=payload_size, - dummy_data=dummy_data, - timestamp=timestamp_ms, - ) - - response = await stub.MeasureLatency(request) # type: ignore - end_time = time.perf_counter() - - if response.success: - latency_ms = (end_time - start_time) * 1000 - latency_measurements.append( - LatencyMeasurement( - payload_size=payload_size, - latency_ms=round(latency_ms, 2), - success=True, - error=None, - ) - ) - else: - latency_measurements.append( - LatencyMeasurement( - payload_size=payload_size, - success=False, - error=response.message, - latency_ms=0, - ) - ) - - # Store results - result = DeviceLatencyResult( - target_node_id=response.node_id if response.success else None, - measurements=latency_measurements, - success=True, - error=None, - ) - latency_results_dict[service_name] = result - - # Close channel - await channel.close() - - except Exception as e: - logger.error("Error measuring latency to %s: %s", service_name, e) - result = DeviceLatencyResult( - target_node_id=None, - success=False, - error=str(e), - measurements=[], - ) - latency_results_dict[service_name] = result - - return LatencyResults(results=latency_results_dict) From 240300adbec86ddfe774d9f2480e871dfeba11a2 Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 22:52:27 -0700 Subject: [PATCH 114/172] runtime stats high-level frames --- src/dnet/ring/shard/node.py | 74 +++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index e10c4d65..ba62b0ef 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -376,6 +376,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse 
resident_windows=self._resident_windows, use_mxload_fastpath=self.config.mxload_fastpath, prefetch_mode=self.config.prefetch_mode, + tracer=self.tracer, ) # Load the model @@ -1159,7 +1160,7 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - with self.tracer.frame("compute", "forward"): + with self.tracer.frame("compute", "forward"): # NOTE: Symbol hardcoded for runtime stats self._process_activation(activation_msg) except Empty: @@ -1302,9 +1303,7 @@ def _warmup_serialization(self): pass def _warmup_shard(self): - logger.info( - "[WARMUP] Starting shard warmup with window size %s", self.window_size - ) + logger.info("[WARMUP] Starting shard warmup with window size %s", self.window_size) if not self.model or not self.model_metadata or not self.weight_cache: logger.warning("[WARMUP] No model loaded; skipping warmup") return @@ -1326,10 +1325,9 @@ def _warmup_shard(self): max_windows = max(1, self.config.warmup_windows) windows: list[list[int]] = [] for window_start in range(0, len(self._assigned_sorted), self.window_size): - window_end = min( - window_start + self.window_size, len(self._assigned_sorted) - ) + window_end = min(window_start + self.window_size, len(self._assigned_sorted)) windows.append(self._assigned_sorted[window_start:window_end]) + for wi, window_layers in enumerate(windows[:max_windows]): weights_to_bind = {} for layer_id in window_layers: @@ -1337,6 +1335,7 @@ def _warmup_shard(self): if weights: for k, v in weights.items(): weights_to_bind[k] = v + if weights_to_bind: # Serialize MLX parameter binding with self._mlx_lock: @@ -1350,6 +1349,7 @@ def _warmup_shard(self): mx.eval(_s) except Exception: pass + try: for lid in window_layers: self.weight_cache.decrease_reference(lid) @@ -1570,7 +1570,8 @@ async def load_model_endpoint( f"total_layers={req.total_layers}, kv_bits={req.kv_bits or 'default'}, " f"api_callback={req.api_callback_address or 'none'}" ) - result = await self.load_model(req) + with self.tracer.frame("memory", "model.load"): # NOTE: Symbol hardcoded for runtime stats + result = await self.load_model(req) return result except Exception as e: @@ -1587,7 +1588,8 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: """Unload current model.""" try: logger.info("HTTP /unload_model") - result = await self.unload_model() + with self.tracer.frame("memory", "model.unload"): # NOTE: Symbol hardcoded for runtime stats + result = await self.unload_model() return result except Exception as e: @@ -1601,29 +1603,30 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: # FIXME: add pydantic type here async def warm(request: Request) -> JSONResponse: try: - body = await request.json() - start = int(body.get("start", -1)) - window = int(body.get("window", self.window_size)) - if start < 0: - return JSONResponse( - status_code=400, content={"error": "missing/invalid start"} - ) - start_idx = 0 - for i, lyr in enumerate(self._assigned_sorted): - if lyr >= start: - start_idx = i - break - else: - return JSONResponse(content={"prefetched": []}) - window_layers = self._assigned_sorted[ - start_idx : start_idx + max(1, window) - ] - for wl in window_layers: - # Prefetch disabled in fit mode; allow only when non-fit and enabled - if self._mode != "fit" and self.config.prefetch_mode != "off": - self._prefetch_to_ram(wl) - self._enqueue_weight_prefetch(wl) - return JSONResponse(content={"prefetched": window_layers}) + with self.tracer.frame("memory", "model.warm"): # NOTE: Symbol 
hardcoded for runtime stats + body = await request.json() + start = int(body.get("start", -1)) + window = int(body.get("window", self.window_size)) + if start < 0: + return JSONResponse( + status_code=400, content={"error": "missing/invalid start"} + ) + start_idx = 0 + for i, lyr in enumerate(self._assigned_sorted): + if lyr >= start: + start_idx = i + break + else: + return JSONResponse(content={"prefetched": []}) + window_layers = self._assigned_sorted[ + start_idx : start_idx + max(1, window) + ] + for wl in window_layers: + # Prefetch disabled in fit mode; allow only when non-fit and enabled + if self._mode != "fit" and self.config.prefetch_mode != "off": + self._prefetch_to_ram(wl) + self._enqueue_weight_prefetch(wl) + return JSONResponse(content={"prefetched": window_layers}) except Exception as e: logger.error("/warm failed: %s", e) return JSONResponse(status_code=500, content={"error": str(e)}) @@ -1665,9 +1668,10 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: Returns: Device profile information as a plain dict """ - profile_dict = profile_device_via_subprocess( - repo_id, max_batch_exp=max_batch_exp, debug=0 - ) + with self.tracer.frame("startup", "profile.device"): # NOTE: Symbol hardcoded for runtime stats + profile_dict = profile_device_via_subprocess( + repo_id, max_batch_exp=max_batch_exp, debug=0 + ) logger.info("Device profiling completed for node %s", self.node_id) return profile_dict From ca8147f545eed480a32dbf2d2e845d56ee1b5909 Mon Sep 17 00:00:00 2001 From: Octavian Date: Wed, 22 Oct 2025 22:52:54 -0700 Subject: [PATCH 115/172] runtime stats aggregator --- src/dnet/perf/utils/aggregator.py | 64 +++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index 010e6922..34ba7fa8 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -7,6 +7,7 @@ from collections import defaultdict, deque from dnet.utils.logger import logger +from dnet.ring.common import LayerAssignment, TopologyInformation Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @@ -205,3 +206,66 @@ def q(p: float) -> float: def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: # Call-tree storage is disabled to reduce memory; keep API for compatibility. 
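The runtime stats added below are built on RunAggregator's self-time accounting: each "B" event pushes a span, each "E" event pops it, subtracts time spent in children, and charges the remainder to the symbol. A compressed standalone model of that logic (illustrative only, not the class itself):

    stack, self_ms = [], {}

    def on_event(name, ts_us, kind):
        if kind == "B":                        # frame opened
            stack.append([name, ts_us, 0])     # name, start, child time
        else:                                  # kind == "E", frame closed
            sym, t0, child = stack.pop()
            dur = max(0, ts_us - t0)
            # charge only time not spent in children to the symbol
            self_ms[sym] = self_ms.get(sym, 0.0) + max(0, dur - child) / 1000.0
            if stack:
                stack[-1][2] += dur            # parent absorbs full child duration

    on_event("outer", 0, "B"); on_event("inner", 100, "B")
    on_event("inner", 600, "E"); on_event("outer", 1000, "E")
    # self_ms == {"inner": 0.5, "outer": 0.5}
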
return []
+
+
+# Runtime statistics
+# Use a RunAggregator to get raw frames per request, then transform into _RuntimeStats
+
+# Track a single request, use multiple for a full benchmark
+@dataclass
+class _RuntimeStats:
+    model: str  # Model name
+    tokenizer: str  # Tokenizer name
+    run_id: str  # ID of request serviced (for later mapping)
+    ttft: Dict[str, float]  # Time to first token percentiles, e.g. {"p50": 0.0} (ms)
+    itl: Dict[str, float]  # Inter-token latency percentiles, e.g. {"p50": 0.0} (ms)
+    requests: int  # Number of requests serviced
+    failed: int  # Number of failed requests
+    prompt_tokens: int  # Number of prompt tokens per request
+    generated_tokens: int  # Number of generated tokens per request
+
+    latencys: Dict[Tuple[str, str, str], float]  # Map of latencies: (node0, node1, "p50") -> 0.0
+    latency_per_layer: Dict[int, float]  # Map of {layer: 0.0}
+    latency_per_shard: Dict[str, float]  # Map of {shard: 0.0}
+    total_latency: int  # Total runtime of requests
+    throughput: float  # Generated tokens per second over the run
+
+    startup_t: float  # Time to start shard (ms)
+    layer_assignment_t: float  # Time to layer assignment (ms)
+    topo: Optional[TopologyInformation] = None  # Topology information for this request (keep here since it might change)
+    assignment: Optional[LayerAssignment] = None  # Map of layer to shard IDs
+
+
+# NOTE: Hardcodes some high-level trace frame symbols
+def to_runstats(agg: RunAggregator):
+    pass
+
+# Process stats + handle per-request data
+class StatsAggregator:
+    def __init__(self) -> None:
+        self._req: Dict[str, _RuntimeStats] = {}  # Map req_id : RuntimeStats obj
+        self._lock = threading.Lock()
+
+    # Ingest raw data from tracer
+    def add(self, run: _RuntimeStats) -> bool:
+        with self._lock:
+            self._req[run.run_id] = run
+        return True
+
+    # Return data for total, per req, worker or model (maybe add per layer too?)
+    def stats(
+        self,
+        req_id: Optional[str],
+        worker: Optional[str],
+        model: Optional[str]
+    ):
+
+        if req_id:
+            pass
+
+        elif worker:
+            pass
+
+        elif model:
+            pass
+
+        else:  # Return stats of all counters
+            pass

From 4630c41cac6b5b5e5da061c09dd4aea8aed06208 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Thu, 23 Oct 2025 00:24:40 -0700
Subject: [PATCH 116/172] track per-nonce in-flight and in-wait times and append to ingress trace frame

---
 src/dnet/ring/shard/node.py | 38 +++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index ba62b0ef..7f22e475 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -218,6 +218,10 @@ def __init__(
         self.tracer = Tracer(cfg)
         self.tracer.start()

+        # Get in-flight and in-wait time per request
+        self._rx_ingress_t: Dict[str, float] = {}  # Mapping of nonce -> perf_counter()
+        self._rx_inflight_t: Dict[str, float] = {}  # Track per-request inflight
+
         # Per-nonce KV caches (concurrent requests)
         self._kv_by_nonce: Dict[str, list] = {}
         self._kv_last_seen: Dict[str, float] = {}
@@ -852,6 +856,10 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None:
         """enqueue protobuf frame to ingress queue"""
         while self.running:
             try:
+                rx_t = time.perf_counter()
+                self._rx_ingress_t[request.nonce] = rx_t
+                self._rx_inflight_t[request.nonce] = rx_t - request.timestamp
+
                 self.ingress_q.put_nowait(request)
                 logger.debug(f"[ENQUE] Enqueued activation request")
                 return
@@ -868,17 +876,22 @@ async def _ingress_worker(self):
         finally enqueues for compute or forwards to the next shard.
""" while self.running: - with self.tracer.frame("grpc", "ingress") as f: - with self.tracer.frame("grpc.ingress", "get.wait"): - try: - req = await self.ingress_q.get() - logger.debug(f"[DEQUE]Dequeued activation for processing") - except asyncio.CancelledError: - break + with self.tracer.frame("network.ingress", "wait"): # NOTE: bad counter + try: + req = await self.ingress_q.get() + logger.debug(f"[DEQUE]Dequeued activation for processing") + except asyncio.CancelledError: + break + + # Trace processing of request, in-flight and in-wait times + with self.tracer.frame("network.ingress", "process") as f: + f.set("inflight", self._rx_inflight_t[req.nonce]) + f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) + f.set("nonce", req.nonce) try: - with self.tracer.frame("grpc.ingress", "connect_next_node"): - await self._connect_next_node() + #with self.tracer.frame("grpc.ingress", "connect_next_node"): + await self._connect_next_node() activation = req.activation target_layer = activation.layer_id + 1 @@ -1025,7 +1038,7 @@ def _prepare_activation_message_blocking( activation = request.activation if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - with self.tracer.frame("grpc.ingress.prepare_activation", "decompress") as f: + with self.tracer.frame("network.ingress.prepare_activation", "decompress") as f: try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1061,7 +1074,7 @@ def _prepare_activation_message_blocking( return activation_msg elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them - with self.tracer.frame("grpc.ingress.prepare_activation", "tokens") as f: + with self.tracer.frame("network.ingress.prepare_activation", "tokens") as f: try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1090,7 +1103,7 @@ def _prepare_activation_message_blocking( return activation_msg else: # Dense path: validate size and copy raw bytes view into pool buffer - with self.tracer.frame("grpc.ingress.prepare_activation", "default") as f: + with self.tracer.frame("network.ingress.prepare_activation", "default") as f: try: expected = ( int(np.prod(activation.shape)) @@ -1570,6 +1583,7 @@ async def load_model_endpoint( f"total_layers={req.total_layers}, kv_bits={req.kv_bits or 'default'}, " f"api_callback={req.api_callback_address or 'none'}" ) + self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name with self.tracer.frame("memory", "model.load"): # NOTE: Symbol hardcoded for runtime stats result = await self.load_model(req) return result From 7bf6bbbe140c1eaccf666b55ceee618a30c94d06 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 00:30:44 -0700 Subject: [PATCH 117/172] stop tracking bytes and target, change 'grpc' to 'network' for cleaner frame tagging --- src/dnet/ring/shard/node.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 7f22e475..1f52c2c7 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -851,7 +851,6 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): logger.exception("Error receiving activation: %s", e) - async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: """enqueue protobuf frame to ingress queue""" while self.running: @@ -884,13 +883,13 @@ async def _ingress_worker(self): break # Trace processing of 
request, in-flight and in-wait times - with self.tracer.frame("network.ingress", "process") as f: + with self.tracer.frame("network", "ingress") as f: f.set("inflight", self._rx_inflight_t[req.nonce]) f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) f.set("nonce", req.nonce) try: - #with self.tracer.frame("grpc.ingress", "connect_next_node"): + #with self.tracer.frame("network.ingress", "connect_next_node"): await self._connect_next_node() activation = req.activation @@ -946,7 +945,7 @@ async def _ingress_worker(self): activation_msg.recv_perf_t = t_recv # Enqueue for compute (cancellable back-off) - with self.tracer.frame("grpc.ingress", "queue") as fr: + with self.tracer.frame("network.ingress", "enque") as fr: while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) @@ -958,18 +957,15 @@ async def _ingress_worker(self): except Full: await asyncio.sleep(0) else: - logger.error( - "Failed to queue activation %s (node stopping)", - activation_msg.nonce, - ) + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) try: if self.input_pool: # FIXME: !!! self.input_pool.release(activation_msg.pool_id) except Exception: pass - else: - # Forward to next node (not our layer) + + else: # Forward to next node (not our layer) logger.debug( "Forwarding activation (layer %s) to next node, nonce: %s", target_layer, From 242839f848b464474daa4d4eb76b65119b964b0d Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 00:31:22 -0700 Subject: [PATCH 118/172] remove profiling logs --- src/dnet/ring/shard/compute.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 794fc682..fce05d21 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -65,16 +65,13 @@ def _delta_swap_eviction( return len(evicted) def _process_activation(self, activation_msg: ActivationMessage): - if ( - not self.model + if ( not self.model or not self.model_metadata or not self.weight_cache or not self.input_pool or not self.output_pool ): - logger.error( - "Node %s: Cannot process activation - model not loaded", self.node_id - ) + logger.error("Node %s: Cannot process activation - model not loaded", self.node_id) return try: @@ -172,7 +169,6 @@ def _process_activation(self, activation_msg: ActivationMessage): ) to_bind: Dict[str, mx.array] = {} if not skip_scan: - t_w_ready = time.perf_counter() for wl in window_layers: weights = self.weight_cache.get_weight(wl) if weights is None: @@ -267,7 +263,8 @@ def _process_activation(self, activation_msg: ActivationMessage): """ for lid in window_layers: - self.weight_cache.decrease_reference(lid) + #self.weight_cache.decrease_reference(lid) + pass with self.tracer.frame("compute.thread", "execute.evict_and_unload"): try: From 2c2608bf15217f50741e2859c882c46243f9663a Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 01:06:22 -0700 Subject: [PATCH 119/172] aggregate per-nonce --- src/dnet/perf/utils/aggregator.py | 87 ++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 19 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index 34ba7fa8..b9cd78c5 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -7,7 +7,8 @@ from collections import defaultdict, deque from dnet.utils.logger import logger -from dnet.ring.common import LayerAssignment, TopologyInformation +from dnet.ring import 
LayerAssignment, TopologyInfo +from dnet.perf import _Frame Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @@ -119,22 +120,13 @@ def __init__(self) -> None: def enqueue(self, batch: Dict[str, Any]) -> None: run_id = batch.get("run_id") node_id = batch.get("node_id") - logger.debug(f"Enquing trace buffer from {run_id}, {node_id}") - if not run_id or not node_id: - return events = batch.get("events") or [] - batch_seq = int(batch.get("batch_seq") or 0) + logger.debug(f"Enquing trace buffer from {run_id}, {node_id}") + if not run_id or not node_id: return # Drop batch with self._lock: agg = self._req.setdefault(run_id, RunAggregator()) - last = agg.last_batch_seq.get(node_id) - if (last is not None) and (batch_seq != last + 1): - agg.drops += abs(batch_seq - (last + 1)) - agg.last_batch_seq[node_id] = batch_seq for ev in events: - try: - agg.ingest_event(node_id, ev) - except Exception: - continue + agg.ingest_event(node_id, ev) def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]: with self._lock: @@ -216,7 +208,8 @@ def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: class _RuntimeStats: model: str # Model name tokenizer: str # Tokenizer name - run_id: str # ID of request serviced (for later mapping) + run_id: str # ID of session (for later mapping) + nonce: List[str] # List of serviced requests ttft: Dict[str, float] # Time to first token, map: p50 : 0.0 (ms) itl: Dict[str, float] # Inter-token latency, mapL p50 : 0.0 (ms) requests: int # Number of requests serviced @@ -229,11 +222,11 @@ class _RuntimeStats: latency_per_shard: Dict[str, float] # Map of {shard: 0.0} total_latency: int # Total runtime of requests throughput: float # aaa + startup_t: float # Time to start shard (ms) + layer_assignment_t: float # Time to layer assignment (ms) topo: TopologyInfo = None # Topology information for this request (keep here since it might change) assignment: LayerAssignment = None # Map of layer to shard IDs - startup_t: float # Time to start shard (ms) - layer_assignment_t: float # Time to layer assignment (ms) # NOTE: Hardcodes some high-level trace frame symbols @@ -243,12 +236,67 @@ def to_runstats(agg: RunAggregator): # Process stats + handle per-request data class StatsAggregator: def __init__(self) -> None: - self._req: Dict[str, _RuntimeStats] = {} # Map req_id : RuntimeStats obj self._lock = threading.Lock() + self._max_resident_rq = 50 # per node FIXME: modify from repl + self._workers: Dict[str, Dict[str, Dict[str, _Frame]]] = {} # Store frames per nonce, per node_id + + self._nonces = [] # Tracked nonces (either in-flight or done) + self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds + self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill + self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames + self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames + # Ingest raw data from tracer - def add(self, run: _RuntimeStats) -> bool: - run_id = run.get("run_id") + def add(self, data: Dict[str, Any]) -> None: + run_id = data.run_id + node_id = data.node_id + events = data.events or [] + name = data.name + if not run_id or not node_id: return # Drop the batch + + with self._lock: + for i, ev in enumerate(events): + nonce = ev.attrs["nonce"] or f"ERR_{i}" + + if node_id not in self._workers: + self._workers[node_id] = {} + + if nonce not in self._workers[node_id]: + self._workers[node_id][nonce] = {} + + if name 
not in self._workers[node_id][nonce]: + self._workers[node_id][nonce][name] = [ev, ] + continue + + if len(self._workers[node_id]) >= self._max_resident_req: # remove oldest entry + del self._workers[self._nonces[0]] + del self._nonces[0] + + self._workers[node_id][name].append(ev) + self._nonces.push(nonce) + + # Construct RuntimeStats + assert "model" in self._frames, "No model found in trace data." + + rt_stat = self._req.setdefault(run_id, _RuntimeStats) + #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] + rt_stat.tokenizer = + rt_stat.run_id = run_id + rt_stat.ttft = {} + + for n in self._nonces: # accumulate new data for each nonce + for shard in self._workers: + + if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]: + self._nonce_round_finish[nonce] = True + if not self._nonce_prefill[nonce]: # This is prefill, append ttft + + + acc_ttt = 0 # accumulated time to token + acc_ttt += shard["network.ingress"][-1] + inflight = shard['network.ingress'][] + # Return data for total, per req, worker or model (maybe add per layer too?) def stats( @@ -268,4 +316,5 @@ def stats( pass else: # Return stats of all counters + pass From adf3edfe62ed10415e31c905e80115f42116167a Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 02:08:31 -0700 Subject: [PATCH 120/172] construct new request on embedding event --- src/dnet/perf/utils/aggregator.py | 81 +++++++++++++++++++------------ 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index b9cd78c5..de1b508f 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -8,7 +8,6 @@ from dnet.utils.logger import logger from dnet.ring import LayerAssignment, TopologyInfo -from dnet.perf import _Frame Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @@ -210,8 +209,8 @@ class _RuntimeStats: tokenizer: str # Tokenizer name run_id: str # ID of session (for later mapping) nonce: List[str] # List of serviced requests - ttft: Dict[str, float] # Time to first token, map: p50 : 0.0 (ms) - itl: Dict[str, float] # Inter-token latency, mapL p50 : 0.0 (ms) + ttft: float # Time to first token + itl: float # Inter-token latency requests: int # Number of requests serviced failed: int # Number of failed requests prompt_tokens: int # Number of prompt tokens per request (req_id: #) @@ -229,23 +228,24 @@ class _RuntimeStats: assignment: LayerAssignment = None # Map of layer to shard IDs -# NOTE: Hardcodes some high-level trace frame symbols -def to_runstats(agg: RunAggregator): - pass - # Process stats + handle per-request data +# NOTE: Hardcodes some high-level trace frame symbols +# TODO: Use a bitmap to track the stages for each req and prevent double-count class StatsAggregator: def __init__(self) -> None: self._lock = threading.Lock() - self._max_resident_rq = 50 # per node FIXME: modify from repl - self._workers: Dict[str, Dict[str, Dict[str, _Frame]]] = {} # Store frames per nonce, per node_id + self._max_inflight_rq = 20 # per node FIXME: modify from repl + + # Frames are kept in here while in-flight, then remove the frame objects and append to _stats + self._workers: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id self._nonces = [] # Tracked nonces (either in-flight or done) self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill 
self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames + self._open_frames: Dict[str, Dict[str, Any]] # We got 'B' event but not 'E' (per nonce) # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: @@ -256,8 +256,13 @@ def add(self, data: Dict[str, Any]) -> None: if not run_id or not node_id: return # Drop the batch with self._lock: + + # Ensure we register workers and nodes for i, ev in enumerate(events): - nonce = ev.attrs["nonce"] or f"ERR_{i}" + if "nonce" not in ev.attrs: ev.attrs["nonce"] = f"N_{i}" + nonce = ev.attrs["nonce"] + + new_frames.append(ev) if node_id not in self._workers: self._workers[node_id] = {} @@ -265,33 +270,45 @@ def add(self, data: Dict[str, Any]) -> None: if nonce not in self._workers[node_id]: self._workers[node_id][nonce] = {} - if name not in self._workers[node_id][nonce]: - self._workers[node_id][nonce][name] = [ev, ] - continue - if len(self._workers[node_id]) >= self._max_resident_req: # remove oldest entry - del self._workers[self._nonces[0]] - del self._nonces[0] + del self._workers[self._nonces[0]] + del self._nonces[0] - self._workers[node_id][name].append(ev) self._nonces.push(nonce) - # Construct RuntimeStats - assert "model" in self._frames, "No model found in trace data." - - rt_stat = self._req.setdefault(run_id, _RuntimeStats) - #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] - rt_stat.tokenizer = - rt_stat.run_id = run_id - rt_stat.ttft = {} - - for n in self._nonces: # accumulate new data for each nonce - for shard in self._workers: - + # Update in-flight events or register new ones + for e in new_events: + nonce = e.attrs["nonce"] + assert nonce is not None, "" + + if not node_id and nonce: return # Drop invalid frames + stats = self._running_stats[nonce] + + # Register new request + if e.name == "compute.embedding": + #assert "model" in self._frames, "No model found in trace data." 
+ rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats) + #rt_stat.model = self._workers[0]["model"][-1].attrs["model"] + #rt_stat.tokenizer = + rt_stat.run_id = run_id + rt_stat.nonce = nonce + rt_stat.ttft = {} + + if e.name == "network.ingress": + if e.type == "B": self._open_frames[nonce][e.name] = e + n_rt = e.attrs["inflight"] + e.attrs["inwait"] + n_rt += self._open_frames[nonce][e.name].t0 + if self._nonce_prefill[nonce]: + stats.ttft += n_rt + continue + stats.itl += n_rt + + if f.name == "compute.forward": + + # Request is finished, construct _RuntimeStats and remove from memory if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]: - self._nonce_round_finish[nonce] = True - if not self._nonce_prefill[nonce]: # This is prefill, append ttft - + self._nonce_round_finish[nonce] = True + if not self._nonce_prefill[nonce]: # This is prefill, append ttft acc_ttt = 0 # accumulated time to token acc_ttt += shard["network.ingress"][-1] From c60624951c994bbb6dd03825a7efc6e3ac0ddbd5 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 11:50:58 -0700 Subject: [PATCH 121/172] handle frame with custom cost function --- src/dnet/perf/utils/aggregator.py | 67 ++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index de1b508f..d1419cee 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -210,13 +210,11 @@ class _RuntimeStats: run_id: str # ID of session (for later mapping) nonce: List[str] # List of serviced requests ttft: float # Time to first token - itl: float # Inter-token latency - requests: int # Number of requests serviced - failed: int # Number of failed requests + itl: List[float] # Inter-token latency per round prompt_tokens: int # Number of prompt tokens per request (req_id: #) generated_tokens: int # Number of generated tokens per request (req_id: #) - latencys: Dict[List[str, str, str], int] # Map of latencys: [node0, node1, p50]: 0.0 + latencies: Dict[List[str, str, str], int] # Map of inter-node latencies: [node0, node1, p50]: 0.0 latency_per_layer: Dict[int, float] # Map of {layer: 0.0} latency_per_shard: Dict[str, float] # Map of {shard: 0.0} total_latency: int # Total runtime of requests @@ -246,6 +244,7 @@ def __init__(self) -> None: self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames self._open_frames: Dict[str, Dict[str, Any]] # We got 'B' event but not 'E' (per nonce) + self._model_per_run: Dict[str, str] = {} # Track model per run_id # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: @@ -277,33 +276,43 @@ def add(self, data: Dict[str, Any]) -> None: self._nonces.push(nonce) # Update in-flight events or register new ones - for e in new_events: + for e in events: nonce = e.attrs["nonce"] assert nonce is not None, "" if not node_id and nonce: return # Drop invalid frames - stats = self._running_stats[nonce] - # Register new request - if e.name == "compute.embedding": - #assert "model" in self._frames, "No model found in trace data." 
-                rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats)
-                #rt_stat.model = self._workers[0]["model"][-1].attrs["model"]
-                #rt_stat.tokenizer =
-                rt_stat.run_id = run_id
-                rt_stat.nonce = nonce
-                rt_stat.ttft = {}
+            if e.name == "embedding": # Register new request
+                rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats(
+                    model="",
+                    tokenizer="",
+                    run_id=run_id,
+                    nonce=nonce,
+                    ttft=0.0,
+                    itl=[0.0],
+                    generated_tokens=0,
+                    prompt_tokens=e.attrs["prompt_tokens"],
+                    latencies=[],
+                    latency_per_layer={},
+                    latency_per_shard={},
+                    total_latency=0.0,
+                    throughput=0.0,
+                    startup_t=0.0,
+                    layer_assignment_t=0.0,
+                    assignment=None,
+                    topo=None,
+                ))
+
+            # FIXME: We might receive other frames than "embed" from shards
+            # so we need to handle the creation of this better
+            stats = self._running_stats[nonce]
 
             if e.name == "network.ingress":
-                if e.type == "B": self._open_frames[nonce][e.name] = e
-                n_rt = e.attrs["inflight"] + e.attrs["inwait"]
-                n_rt += self._open_frames[nonce][e.name].t0
-                if self._nonce_prefill[nonce]:
-                    stats.ttft += n_rt
-                    continue
-                stats.itl += n_rt
+                _cost: lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"]
+                self._handle_frame(e, stats, _cost)
 
-            if f.name == "compute.forward":
+            if e.name == "compute.forward":
+                _cost = lambda e: e.attrs["ms"]
+                self._handle_frame(e, stats, _cost)
+
+            if e.name == ""
 
             # Request is finished, construct _RuntimeStats and remove from memory
             if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]:
@@ -314,6 +323,18 @@ def add(self, data: Dict[str, Any]) -> None:
                 acc_ttt += shard["network.ingress"][-1]
                 inflight = shard['network.ingress'][]
 
+    # Handle cost aggregation of frames
+    def _handle_frame(self, e: Any, stats: _RuntimeStats, _cost_fnc: Any):
+        nonce = e.attrs["nonce"]
+        if e.type == 'B':
+            self._open_frames.setdefault(nonce, {})[e.name] = e
+            return
+        elif e.type == 'E':
+            n_rt = _cost_fnc(e) # Custom cost function for each frame
+            if self._nonce_prefill[nonce]:
+                stats.ttft += n_rt
+            else:
+                stats.itl[-1] += n_rt
+            del self._open_frames[nonce][e.name]
 
     # Return data for total, per req, worker or model (maybe add per layer too?)
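     # A possible reduction once a request completes (a sketch, not part of
     # this patch): collapse the per-round ITL samples into percentiles plus a
     # throughput figure. `_percentile` is illustrative and defined nowhere
     # else in the tree; it assumes `stats.itl` holds one latency per round
     # and `stats.total_latency` is in seconds.
     #
     #   def _percentile(samples: List[float], p: float) -> float:
     #       if not samples:
     #           return 0.0
     #       ordered = sorted(samples)
     #       k = min(len(ordered) - 1, int(round(p * (len(ordered) - 1))))
     #       return ordered[k]
     #
     #   itl_p50 = _percentile(stats.itl, 0.50)
     #   itl_p95 = _percentile(stats.itl, 0.95)
     #   throughput = stats.generated_tokens / max(stats.total_latency, 1e-9)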
def stats(

From fe47f6892a41b4eff7a0f206d631d728186369 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Thu, 23 Oct 2025 14:08:10 -0700
Subject: [PATCH 122/172] update canonical traces for stats, rename ingress to
 rx and egress to tx

---
 src/dnet/ring/shard/comms.py   |  52 ++++------
 src/dnet/ring/shard/compute.py |  14 ++-
 src/dnet/ring/shard/node.py    | 168 +++++++++++----------------------
 3 files changed, 89 insertions(+), 145 deletions(-)

diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py
index bcc826c3..b36f8250 100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -172,18 +172,21 @@ async def _send_worker(self):
         ):
             try:
                 activation_msg = await self.activation_computed_queue.get()
-                if activation_msg.tx_enq_perf_t and self._profile:
-                    q_wait_ms = (
-                        time.perf_counter() - activation_msg.tx_enq_perf_t
-                    ) * 1000.0
-                    logger.info(
-                        "[PROFILE][QUEUE-TX] node=%s nonce=%s wait_ms=%.3f size=%s",
-                        self.node_id,
-                        activation_msg.nonce,
-                        q_wait_ms,
-                        self.activation_computed_queue.qsize(),
-                    )
-                await self._send_activation(activation_msg)
+                with self.tracer.frame("network", "tx") as f:
+                    if activation_msg.tx_enq_perf_t and self._profile:
+                        f.set("inwait", time.perf_counter() - activation_msg.tx_enq_perf_t)
+                        f.set("nonce", activation_msg.nonce)
+                        q_wait_ms = (
+                            time.perf_counter() - activation_msg.tx_enq_perf_t
+                        ) * 1000.0
+                        logger.info(
+                            "[PROFILE][QUEUE-TX] node=%s nonce=%s wait_ms=%.3f size=%s",
+                            self.node_id,
+                            activation_msg.nonce,
+                            q_wait_ms,
+                            self.activation_computed_queue.qsize(),
+                        )
+                    await self._send_activation(activation_msg)
             except asyncio.CancelledError:
                 break
             except Exception as e:
@@ -247,28 +250,22 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                 pass
         cb = activation_msg.callback_url or ""
         parsed = urlparse(cb) if cb else None
-        t_rpc = time.perf_counter()
         if parsed and parsed.scheme == "grpc":
             addr = parsed.netloc
             if not addr:
                 logger.error("Invalid gRPC callback URL for token: %s", cb)
                 return
-            # Ensure API channel/stub
-            if (self.api_channel is None) or (addr != self.api_address):
-                # close existing channel if any
-                try:
+
+            if (self.api_channel is None) or (addr != self.api_address): # Ensure API channel/stub
+                try: # close existing channel if any
                     if self.api_channel is not None:
                         await self.api_channel.close()
                 except Exception:
                     pass
                 self.api_address = addr
-                self.api_channel = aio_grpc.insecure_channel(
-                    addr, options=GRPC_AIO_OPTIONS
-                )
-                self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub(
-                    self.api_channel
-                )
+                self.api_channel = aio_grpc.insecure_channel(addr, options=GRPC_AIO_OPTIONS)
+                self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub(self.api_channel)
                 f.event("reset_api")
 
             with self.tracer.frame("grpc", "token_request") as fr:
@@ -279,21 +276,12 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                     timestamp=utc_epoch_now(),
                 )
                 resp = await self.api_stub.SendToken(req)  # type: ignore
-                rpc_ms = (time.perf_counter() - t_rpc) * 1000.0
                 if not resp.success:
                     logger.error(
                         "API SendToken failed for %s: %s",
                         activation_msg.nonce,
                         resp.message,
                     )
-                if self._profile:
-                    logger.info(
-                        "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f",
-                        self.node_id,
-                        activation_msg.nonce,
-                        int(getattr(activation_msg, "token_id", -1)),
-                        rpc_ms,
-                    )
             except Exception as e:
                 logger.exception("Error sending token via gRPC: %s", e)
         else:
diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py
index fce05d21..01825bac 100644
--- a/src/dnet/ring/shard/compute.py
+++ b/src/dnet/ring/shard/compute.py @@ -65,7 +65,7 @@ def _delta_swap_eviction( return len(evicted) def _process_activation(self, activation_msg: ActivationMessage): - if ( not self.model + if (not self.model or not self.model_metadata or not self.weight_cache or not self.input_pool @@ -88,13 +88,21 @@ def _process_activation(self, activation_msg: ActivationMessage): # Prepare input activation with self.tracer.frame("compute.thread", "activations.process") as f: + f.set("nonce", activation_msg.nonce) if activation_msg.dtype == "tokens": # embed locally on start shard logger.debug(f"Embedding tokens.") - f.event("embed_tokens") numel = int(np.prod(activation_msg.shape)) tok_view = input_buffer[:numel].reshape(activation_msg.shape) toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) + x = self.model.embed(toks[None]) + + # NOTE: Used to track start of request in perf stats + self.tracer.mark("embedding", { + "nonce": actication_msg.nonce, + "prompt_tokens": toks.size, + }) + if x.dtype != self._wire_mx_dtype: x = x.astype(self._wire_mx_dtype) @@ -382,6 +390,8 @@ def _process_activation(self, activation_msg: ActivationMessage): with self._mlx_lock: y = self.model.normalize(x_cast) y = self.model.lm_project(y) + self.tracer.mark("lm_head", {"nonce": actication_msg.nonce}) # NOTE: canonical stats end + # Greedy sample last position if y.ndim == 3: logits_2d = y[:, -1, :] diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 1f52c2c7..d0023d38 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -108,7 +108,7 @@ def __init__( self._assigned_sorted = sorted(self.assigned_layers or []) self._assigned_set = set(self._assigned_sorted) - # Topology (configured later) + # Topology self.next_node: Optional[DnetDeviceProperties] = None self.total_layers: int = 0 # Total layers in model self.api_callback_address: Optional[str] = None @@ -219,8 +219,10 @@ def __init__( self.tracer.start() # Get in-flight and in-wait time per request - self._rx_ingress_t: Dict[str, float] = {} # Mapping of nonce -> perf_counter() - self._rx_inflight_t: Dict[str, float] = {} # Track per-request inflight + self._rx_ingress_t: Dict[str, float] = {} # Timestamp we enqued the request + self._rx_inflight_t: Dict[str, float] = {} # Per-request inflight time + self._ex_enque_t: Dict[str, float] = {} # req is queued for execution + self._tx_enque_t: Dict[str, float] = {} # req is queued for sendoff # Per-nonce KV caches (concurrent requests) self._kv_by_nonce: Dict[str, list] = {} @@ -598,13 +600,11 @@ async def unload_model(self) -> ShardUnloadModelResponse: async def reset_cache(self) -> None: """Reset LLM KV cache.""" if not self.model: - logger.warning( - "Node %s: Cannot reset cache - no model loaded", self.node_id - ) + logger.warning("Node %s: Cannot reset cache - no model loaded", self.node_id) return - try: - with self.tracer.frame("memory", "cache.reset"): + with self.tracer.frame("memory", "cache.reset"): + try: self.cache = make_cache( self.model, # type: ignore[arg-type] kv_mode=self.config.kv_cache.mode, @@ -612,26 +612,19 @@ async def reset_cache(self) -> None: kv_group=self.config.kv_cache.group_size, ) logger.info("Node %s: Cache reset successfully", self.node_id) - except Exception as e: - logger.error("Node %s: Error resetting cache: %s", self.node_id, e) + # FIXME This seems to still be dead code async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or 
forward.""" if self.input_pool is None: - logger.error( - "Node %s: Cannot receive activation - input pool not initialized", - self.node_id, - ) + logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id) return - t_recv = time.perf_counter() - await self._connect_next_node() - - with self.tracer.frame("grpc.receive", "connect_next_node"): + with self.tracer.frame("network.rx", "connect_next_node"): await self._connect_next_node() - with self.tracer.frame("grpc.receive", "process_activation") as f: + with self.tracer.frame("network.rx", "process_activation") as f: try: activation = request.activation target_layer = activation.layer_id + 1 @@ -699,87 +692,62 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) + dtype_with_metadata=activation.dtype) except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) + logger.error("Decompression failed for nonce %s: %s", request.nonce, e) return - with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: + with self.tracer.frame("network.rx", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) + shape=cast(tuple[int, ...], tuple(deq.shape))) + if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) return + buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: flat = deq.reshape(-1) buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) # Update activation message with true dtype/shape new_dtype_str = str(deq.dtype) activation_msg = ActivationMessage.from_proto(request, pool_id) activation_msg.dtype = new_dtype_str activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs + + else: # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": - with self.tracer.frame("grpc.receive", "token_stream") as fr: + with self.tracer.frame("network.rx", "token_stream") as fr: try: - tokens = np.frombuffer( - request.activation.data, dtype=np.int32 - ) - shp = (int(len(tokens)),) + tokens = np.frombuffer(request.activation.data, dtype=np.int32) + shp = (int(len(tokens)), ) except Exception as e: - logger.error( - "Failed to parse tokens for nonce %s: %s", - request.nonce, - e, - ) + logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e,) return + pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=mx.int32, - shape=cast(tuple[int, ...], shp), - ) + shape=cast(tuple[int, ...], shp)) + if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) return + buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: buffer[: len(tokens)] = tokens - if self._profile: - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s 
alloc_copy_ms=%.3f (tokens)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) activation_msg = ActivationMessage.from_proto(request, pool_id) + # Ensure dtype reflects token payload for compute path activation_msg.dtype = "tokens" activation_msg.shape = shp + else: - with self.tracer.frame("grpc.receive", "default") as fr: + with self.tracer.frame("network.ex", "default") as fr: # Safety: byte length must match shape*dtype try: expected = ( @@ -793,58 +761,37 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape), - ) + shape=cast(tuple[int, ...], activation.shape)) + if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) return + buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: data = request.activation.data - input_data = np.frombuffer( - data, dtype=dtype_map[activation.dtype] - ) + input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) buffer[: len(input_data)] = input_data - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f", - self.node_id, - request.nonce, - alloc_copy_ms, - ) + activation_msg = ActivationMessage.from_proto(request, pool_id) activation_msg.dtype = new_dtype_str activation_msg.shape = tuple(deq.shape) # Queue for processing — non-blocking back-off loop (cancellable) - if self._profile: - activation_msg.enq_perf_t = time.perf_counter() while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", - activation_msg.nonce, - ) + self._ex_enque_t[activation_msg.nonce] = time.perf_counter() + logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) break except Full: await asyncio.sleep(0) else: - logger.error( - "Failed to queue activation %s (node stopping)", - activation_msg.nonce, - ) + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce) self.input_pool.release(pool_id) - else: - # Forward to next node (not our layer) - logger.debug( - "Forwarding activation (layer %s) to next node, nonce: %s", - target_layer, - request.nonce, - ) + + else: # Forward to next node (not our layer) + logger.debug("Forwarding activation (layer %s) to next node, nonce: %s", target_layer, request.nonce) await self._forward_activation(request) except Exception as e: @@ -872,10 +819,10 @@ async def _ingress_worker(self): Admission (servicer) is lightweight; this worker performs per-frame processing, offloading alloc/copy/decompress to the threadpool, and - finally enqueues for compute or forwards to the next shard. - """ + finally enqueues for compute or forwards to the next shard. 
""" + while self.running: - with self.tracer.frame("network.ingress", "wait"): # NOTE: bad counter + with self.tracer.frame("network.rx", "wait"): # NOTE: bad counter try: req = await self.ingress_q.get() logger.debug(f"[DEQUE]Dequeued activation for processing") @@ -883,9 +830,9 @@ async def _ingress_worker(self): break # Trace processing of request, in-flight and in-wait times - with self.tracer.frame("network", "ingress") as f: - f.set("inflight", self._rx_inflight_t[req.nonce]) + with self.tracer.frame("network", "rx") as f: f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) + f.set("inflight", self._rx_inflight_t[req.nonce]) f.set("nonce", req.nonce) try: @@ -941,11 +888,9 @@ async def _ingress_worker(self): continue if activation_msg is None: continue - if self._profile: - activation_msg.recv_perf_t = t_recv - # Enqueue for compute (cancellable back-off) - with self.tracer.frame("network.ingress", "enque") as fr: + # Enqueue for compute + with self.tracer.frame("network.rx", "enque") as fr: while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) @@ -1034,7 +979,7 @@ def _prepare_activation_message_blocking( activation = request.activation if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool - with self.tracer.frame("network.ingress.prepare_activation", "decompress") as f: + with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1070,7 +1015,7 @@ def _prepare_activation_message_blocking( return activation_msg elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them - with self.tracer.frame("network.ingress.prepare_activation", "tokens") as f: + with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1099,7 +1044,7 @@ def _prepare_activation_message_blocking( return activation_msg else: # Dense path: validate size and copy raw bytes view into pool buffer - with self.tracer.frame("network.ingress.prepare_activation", "default") as f: + with self.tracer.frame("network.rx.prepare_activation", "default") as f: try: expected = ( int(np.prod(activation.shape)) @@ -1169,7 +1114,8 @@ def _compute_worker(self) -> None: activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation - with self.tracer.frame("compute", "forward"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("inwait", time.perf_counter() - self._ex_enque_t) self._process_activation(activation_msg) except Empty: From 4e7e12d05bf84c6c265557559226f8ee48b85f2c Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 14:08:45 -0700 Subject: [PATCH 123/172] filter canonical frames --- src/dnet/perf/utils/aggregator.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/dnet/perf/utils/aggregator.py b/src/dnet/perf/utils/aggregator.py index d1419cee..6807d47e 100644 --- a/src/dnet/perf/utils/aggregator.py +++ b/src/dnet/perf/utils/aggregator.py @@ -304,20 +304,27 @@ def add(self, data: Dict[str, Any]) -> None: # so we need to handle the creation of this better stats = self._running_stats[nonce] - if e.name == "network.ingress": - _cost: lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"] + if e.name == "network.rx": + # Time in transport, 
ingress queue and ingress_worker
+                _cost = lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"]
                 self._handle_frame(e, stats, _cost)
+                #TODO: change shard in metadata
 
             if e.name == "compute.forward":
-                _cost = lambda e: e.attrs["ms"]
+                _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # compute queue + execution
                 self._handle_frame(e, stats, _cost)
 
-            if e.name == ""
+            if e.name == "network.tx.send":
+                _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # tx queue + sendoff
+                self._handle_frame(e, stats, _cost)
 
-            # Request is finished, construct _RuntimeStats and remove from memory
-            if "final" in self._workers[node_id][nonce] and not self._nonce_round_finish[nonce]:
+            if e.name == "lm_head" and not self._nonce_round_finish[nonce]: # Finish request
                 self._nonce_round_finish[nonce] = True
-                if not self._nonce_prefill[nonce]: # This is prefill, append ttft
+
+                # TODO: Remove frame and stats from working set and append to finished
+                st_obj = self._running_stats[nonce]
+                del self._running_stats[nonce]
+                self._stats[nonce] = st_obj
 
                 acc_ttt = 0 # accumulated time to token
                 acc_ttt += shard["network.ingress"][-1]
 

From df240c512b9884f4d34f540b6e5bf7689d2bea44 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Thu, 23 Oct 2025 21:27:02 -0700
Subject: [PATCH 124/172] move to on-request vars to avoid race conditions

---
 src/dnet/protos/dnet_ring.proto |  6 ++++--
 src/dnet/ring/data_types.py     |  5 +++++
 src/dnet/ring/shard/compute.py  |  5 ++---
 src/dnet/ring/shard/node.py     | 20 ++++++++------------
 4 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto
index 0b46c5be..5452a559 100644
--- a/src/dnet/protos/dnet_ring.proto
+++ b/src/dnet/protos/dnet_ring.proto
@@ -32,8 +32,10 @@ message ActivationRequest {
     string nonce = 1;
     Activation activation = 2;
     int64 timestamp = 3;
-    string node_origin = 4;
-    string callback_url = 5;
+    float rx_enq_t = 4;
+    float rx_inflight_t = 5;
+    string node_origin = 6;
+    string callback_url = 7;
 }
 
 // Response message for activation sending
diff --git a/src/dnet/ring/data_types.py b/src/dnet/ring/data_types.py
index 5da3e65a..db9438e5 100644
--- a/src/dnet/ring/data_types.py
+++ b/src/dnet/ring/data_types.py
@@ -25,7 +25,12 @@ class ActivationMessage:
 
     recv_perf_t: float = 0.0
     enq_perf_t: float = 0.0  # TX queue enqueue time (perf_counter seconds)
+    rx_enq_perf_t: float = 0.0
     tx_enq_perf_t: float = 0.0
+    rx_ingress_t: float = 0.0
+    rx_inflight_t: float = 0.0
+    ex_enq_t: float = 0.0
+    tx_enq_t: float = 0.0
     # Final token path (end-shard sampling)
     is_final: bool = False
     token_id: int = -1
diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py
index 01825bac..d51343a3 100644
--- a/src/dnet/ring/shard/compute.py
+++ b/src/dnet/ring/shard/compute.py
@@ -94,12 +94,11 @@ def _process_activation(self, activation_msg: ActivationMessage):
             numel = int(np.prod(activation_msg.shape))
             tok_view = input_buffer[:numel].reshape(activation_msg.shape)
             toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32)
-
             x = self.model.embed(toks[None])
 
             # NOTE: Used to track start of request in perf stats
             self.tracer.mark("embedding", {
-                "nonce": actication_msg.nonce,
+                "nonce": activation_msg.nonce,
                 "prompt_tokens": toks.size,
             })
 
@@ -390,7 +389,7 @@ def _process_activation(self, activation_msg: ActivationMessage):
         with self._mlx_lock:
             y = self.model.normalize(x_cast)
             y = self.model.lm_project(y)
-            self.tracer.mark("lm_head", {"nonce": actication_msg.nonce}) # NOTE: canonical stats end
+            #self.tracer.mark("lm_head", {"nonce": activation_msg.nonce}) # NOTE: canonical stats end
 
         # Greedy sample last position
         if y.ndim == 3:
diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py
index d0023d38..3ed08470 100644
--- a/src/dnet/ring/shard/node.py
+++ b/src/dnet/ring/shard/node.py
@@ -218,12 +218,6 @@ def __init__(
         self.tracer = Tracer(cfg)
         self.tracer.start()
 
-        # Get in-flight and in-wait time per request
-        self._rx_ingress_t: Dict[str, float] = {} # Timestamp we enqued the request
-        self._rx_inflight_t: Dict[str, float] = {} # Per-request inflight time
-        self._ex_enque_t: Dict[str, float] = {} # req is queued for execution
-        self._tx_enque_t: Dict[str, float] = {} # req is queued for sendoff
-
         # Per-nonce KV caches (concurrent requests)
         self._kv_by_nonce: Dict[str, list] = {}
         self._kv_last_seen: Dict[str, float] = {}
@@ -781,7 +775,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest):
             while self.running:
                 try:
                     self.activation_recv_queue.put_nowait(activation_msg)
-                    self._ex_enque_t[activation_msg.nonce] = time.perf_counter()
+                    activation_msg.ex_enq_t = time.perf_counter()
                     logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce)
                     break
                 except Full:
@@ -803,8 +797,8 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None:
         while self.running:
             try:
                 rx_t = time.perf_counter()
-                self._rx_ingress_t[request.nonce] = rx_t
-                self._rx_inflight_t[request.nonce] = rx_t - request.timestamp
+                request.rx_enq_t = rx_t
+                request.rx_inflight_t = rx_t - request.timestamp
 
                 self.ingress_q.put_nowait(request)
                 logger.debug(f"[ENQUE] Enqueued activation request")
@@ -831,8 +825,8 @@ async def _ingress_worker(self):
 
             # Trace processing of request, in-flight and in-wait times
             with self.tracer.frame("network", "rx") as f:
-                f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce])
-                f.set("inflight", self._rx_inflight_t[req.nonce])
+                f.set("inwait", time.perf_counter() - req.rx_enq_t)
+                f.set("inflight", req.rx_inflight_t)
                 f.set("nonce", req.nonce)
 
             try:
@@ -1115,7 +1109,9 @@ def _compute_worker(self) -> None:
 
                 # Process the activation
                 with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats
-                    f.set("inwait", time.perf_counter() - self._ex_enque_t)
+                    f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t)
+                    if (self.model_metadata.num_layers - 1) in self.assigned_layers:
+                        f.set("lm_head", True)
                     self._process_activation(activation_msg)
 
             except Empty:
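
One caveat in the hunk above: `rx_inflight_t` is computed as `time.perf_counter() - request.timestamp`, but `perf_counter()` is a local monotonic clock with an arbitrary epoch, while the `timestamp` field appears to carry a wall-clock epoch value from the sending shard, so the difference is not a meaningful duration. A sketch of a same-domain variant, assuming both shards stamp with `time.time()` and stay NTP-synced (field names as in the proto above):

    # sender side, just before the activation RPC is issued
    request.timestamp = int(time.time() * 1000)   # epoch milliseconds

    # receiver side, in admit_frame()
    request.rx_inflight_t = time.time() * 1000 - request.timestamp  # transport ms
    request.rx_enq_t = time.perf_counter()        # monotonic; used only for the
                                                  # local queue-wait delta

`perf_counter()` remains the right choice for the intra-node `inwait` measurement, since both readings there come from the same clock.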
{"nonce": actication_msg.nonce}) # NOTE: canonical stats end # Greedy sample last position if y.ndim == 3: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index d0023d38..3ed08470 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -218,12 +218,6 @@ def __init__( self.tracer = Tracer(cfg) self.tracer.start() - # Get in-flight and in-wait time per request - self._rx_ingress_t: Dict[str, float] = {} # Timestamp we enqued the request - self._rx_inflight_t: Dict[str, float] = {} # Per-request inflight time - self._ex_enque_t: Dict[str, float] = {} # req is queued for execution - self._tx_enque_t: Dict[str, float] = {} # req is queued for sendoff - # Per-nonce KV caches (concurrent requests) self._kv_by_nonce: Dict[str, list] = {} self._kv_last_seen: Dict[str, float] = {} @@ -781,7 +775,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) - self._ex_enque_t[activation_msg.nonce] = time.perf_counter() + activatino_msg.ex_enq_t = time.perf_counter() logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) break except Full: @@ -803,8 +797,8 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: while self.running: try: rx_t = time.perf_counter() - self._rx_ingress_t[request.nonce] = rx_t - self._rx_inflight_t[request.nonce] = rx_t - request.timestamp + request.rx_enq_t = rx_t + request.rx_inflight_t = rx_t - request.timestamp self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") @@ -831,8 +825,8 @@ async def _ingress_worker(self): # Trace processing of request, in-flight and in-wait times with self.tracer.frame("network", "rx") as f: - f.set("inwait", time.perf_counter() - self._rx_ingress_t[req.nonce]) - f.set("inflight", self._rx_inflight_t[req.nonce]) + f.set("inwait", time.perf_counter() - req.rx_enq_t) + f.set("inflight", req.rx_inflight_t) f.set("nonce", req.nonce) try: @@ -1115,7 +1109,9 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats - f.set("inwait", time.perf_counter() - self._ex_enque_t) + f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + if (self.model_metadata.num_layers - 1) in self.assigned_layers: + f.set("lm_head", True) self._process_activation(activation_msg) except Empty: From df9705ae91a5126a9a72c1605608246179243b5d Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:27:25 -0700 Subject: [PATCH 125/172] auto topo and load model --- src/repl.py | 146 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 108 insertions(+), 38 deletions(-) diff --git a/src/repl.py b/src/repl.py index c2f081fb..f9cb363b 100644 --- a/src/repl.py +++ b/src/repl.py @@ -29,8 +29,17 @@ logger = get_api_logger() from dnet.perf.trace import TraceConfig, Tracer -from dnet.perf.utils import TraceAggregator +from dnet.perf.utils import TraceAggregator, StatsAggregator #from dnet.perf.bench import +from dnet.ring.common import TopologyInfo + +from dnet.ring.api.models import ( + PrepareTopologyManualRequest, + PrepareTopologyRequest, + PrepareTopologyResponse, + APILoadModelRequest, + APILoadModelResponse, +) # Handle restricted repos from importlib import import_module @@ -58,6 +67,7 @@ class REPLState: api_http_port: int = 8080 api_grpc_port: int = 50500 window_size = 2 # Number of layers per node per visit (also 
+    topo: TopologyInfo = None
 
 
 class REPL(cmd.Cmd):
@@ -92,13 +102,17 @@ def __init__(self, model="NULL", nodes=1):
             aggregate=True,
             agg_max_events=50,
         )
-
         self._tracing = threading.Event()
         self._tracer = None
         self._trace_file = f"trace-{time.strftime("%Y%m%d-%H%M%S")}"
         self._trace_cursor = 0 # keep track of registered events in buffer
         self._trace_agg = TraceAggregator()
 
+        # Runtime stats (ingests data from tracer)
+        self._stats_agg = StatsAggregator()
+        self._stats = threading.Event()
+        self._stats.set() # Trace runtime information by default
+
     def loop(self): # Main tty loop
         sys.stdout.write(self.WELCOME)
@@ -123,9 +137,15 @@ def loop(self): # Main tty loop
             elif cmd.startswith("nodes"):
                 self.print_mdns_nodes()
                 continue
+            elif cmd.startswith("load"):
+                self.load_model()
+                continue
             elif cmd.startswith(("trace", ".trace")):
                 self.do_trace(cmd.split(" "))
                 continue
+            elif cmd.startswith(("perf", ".perf")):
+                self.do_perf(cmd.split(" "))
+                continue
             elif cmd.startswith(("topo", ".topo")):
                 self.do_topo(cmd.split(" "))
                 continue
@@ -187,8 +207,10 @@ def do_topo(self, cmd: List[str]) -> None:
             dprint("Invalid topology command. Type 'help' for a list of valid commands.\n")
             return
         if cmd[1] == "search":
+            self.print_mdns_nodes()
             pass
-        elif cmd[1] == "auto":
+        elif cmd[1] == "auto" or cmd[1] == "build":
+            self.prepare_topo()
             pass
         elif cmd[1] == "setup":
             pass
@@ -238,7 +260,8 @@ def _print_hf(cmd, desc, examples=[""]):
         _print_hf("trace focus [SUBSYSTEM] ", "Focus the trace on [SUBSYSTEM]. Do 'trace focus' for a list of available subsystems.")
         _print_hf("trace stream [ON|OFF] ", "Stream the trace spans to current terminal.")
         _print_hf("trace set [BUDGET] ", "Set the maximum amount of recorded events.")
-        _print_hf("profile [REPO] ", "Estimate the total FLOPS of the model from [REPO]")
+        _print_hf("perf ", "Prints the current state of runtime performance tracking.")
+        _print_hf("perf stat [REQ_ID | WORKER_ID | MODEL] ", "Prints the runtime statistics of the target system.")
         _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
         _print_hf("bench [KERNEL]", "Benchmark the system using base kernel [KERNEL]")
         _print_hf("bench [NODE]", "Benchmark the network latency between the current system and [NODE]")
@@ -502,43 +525,65 @@ async def _await_then_set():
 
     def do_trace(self, cmd):
         if len(cmd) < 2:
             dprint(f"Tracing is currently {"ON" if self._trace_cfg.enabled else "OFF"}\n")
-        elif cmd[1] in ("on", "ON"):
-            self._trace_cfg.enabled = True
-            if self._api_running:
-                self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
-            dprint("Tracing is now ON\n")
-        elif cmd[1] in ("off", "OFF"):
-            self._trace_cfg.enabled = False
-            if self._api_running:
-                self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
-            dprint("Tracing is not OFF\n")
-        elif cmd[1] == "focus":
-            #self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards
-            dprint("Subsystems not yet implemented.\n")
-        elif cmd[1] == "stream":
-            if len(cmd) == 2:
-                dprint(f"Trace is {"streaming to file: "+str(self._trace_cfg.file) if self._trace_cfg.streaming else "not streaming."}\n")
-            elif cmd[2] == "on":
-                self._trace_cfg.streaming = True
-                dprint(f"Streaming trace frames to {self._trace_cfg.file}\n")
-            elif cmd[2] == "off":
-                self._trace_cfg.streaming = False
-                dprint("Trace streaming is OFF.\n")
-        elif cmd[1] == "set":
-            if len(cmd) == 2:
-                dprint("Use: trace set [BUDGET], eg.
2000\n") - else: - dprint("Not implemented yet\n") - # FIXME: Implement - elif cmd[1] == "status": - dprint(f"Frames: {len(self._trace_agg._req)}\n") - - elif cmd[1] == "annotate": - self.print_trace_annotate("NONE") + return + + match cmd[1]: + case s if s in ["on", "ON"]: + self._trace_cfg.enabled = True + dprint("Tracing is now ON\n") + + case s if s in ["off", "OFF"]: + self._trace_cfg.enabled = False + dprint("Tracing is now OFF\n") + + case s if s == "focus": + dprint("Subsystems not yet implemented.\n") + + case s if s == "stream": + if len(cmd) == 2: + dprint(f"Trace is {"streaming to file: "+str(self._trace_cfg.file) if self._trace_cfg.streaming else "not streaming."}\n") + elif cmd[2] == "on": + self._trace_cfg.streaming = True + dprint(f"Streaming trace frames to {self._trace_cfg.file}\n") + elif cmd[2] == "off": + self._trace_cfg.streaming = False + dprint("Trace streaming is OFF.\n") + + case s if s == "set": + if len(cmd) == 2: + dprint("Use: trace set [BUDGET], eg. 2000\n") + else: + dprint("Not implemented yet\n") + + case s if s == "annotate": + self.print_trace_annotate("NONE") + + case _: + dprint("Unknown trace command. Type 'help' for a list of available commands.\n") + + if self._api_running.is_set() and self._api_ready.is_set(): + self.api_call("_forward_trace_config", self._trace_cfg) # Send trace config to all shards + + # Performance trackers + def do_perf(self, cmd): + if len(cmd) < 2 or cmd[1] == "stat": + dprint("Runtime performance metrics are ON by default.\n") + dprint("Turn tracking off with 'perf off'. Do 'perf stat' for statistics on previous requests or 'help' for more commands.\n\n") + return + + match cmd[1]: + case s if s in "...": + pass + case _: + pass # Trace callback registered with API Thread + # This forwards the tracer frames back to the REPL for printing def __trace_cb(self, data): - self._trace_agg.enqueue(data) + if self._tracing.is_set(): + self._trace_agg.enqueue(data) + if self._stats.is_set(): + self._stats_agg.add(data) def __print_tr(self, row): sym = " " + symbol.ljust(40, ' ') @@ -638,6 +683,31 @@ def print_mdns_nodes(self) -> None: except Exception as e: dprint(f"Could not list nodes: {e}\n") + def print_topo(self, topo): + line = "="*20+" Topology " + "="*20 + sys.stdout.write(f"{line}\nModel: {topo.model}\nLayers: {topo.num_layers}\n") + sys.stdout.write(f"Devices: {topo.devices}\n\n") + # TODO: Better print here + + def prepare_topo(self): + req = PrepareTopologyRequest(model="Qwen/Qwen3-4B-MLX-4bit") + try: + topo = self.api_call("_handle_prepare_topology", req, timeout=30) + except Exception as e: + dprint(f"Unable to create topology: {e}\n\n") + return + self.state.topo = topo + self.print_topo(topo) + + def load_model(self): + req = APILoadModelRequest(model="Qwen/Qwen3-4B-MLX-4bit") + try: + res = self.api_call("_handle_load_model", req, timeout=30) + except Exception as e: + dprint(f"Failed to load model: {e}\n\n") + return + + # ===== Handle shutdown def handle_shutdown(self): From 37a63b9d1a7d3052e6e4219b7e8a656ca9468d03 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:27:50 -0700 Subject: [PATCH 126/172] wrap in trace frames --- src/dnet/ring/weight_cache.py | 128 +++++++++++++++++++++++----------- 1 file changed, 89 insertions(+), 39 deletions(-) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index f7c9ae93..fcb5f53a 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -26,19 +26,19 @@ def __init__( model_metadata: ModelMetadata, window_size: 
Optional[int] = None, prefetch_threads: int = 2, + tracer=None, *, resident_windows: int = 2, use_mxload_fastpath: bool = False, prefetch_mode: str = "off", ): self.assigned_layers = assigned_layers - # Resident budget: enforce up to N windows resident - resident_windows = max(1, int(resident_windows)) + resident_windows = max(1, int(resident_windows)) # Resident budget if window_size is not None and window_size > 0: self.max_weights = min( - len(self.assigned_layers), max(1, resident_windows * int(window_size)) - ) + len(self.assigned_layers), + max(1, resident_windows * int(window_size))) else: self.max_weights = len(self.assigned_layers) self.cache = {} # layer_id -> (data, access_time) @@ -51,14 +51,18 @@ def __init__( prefetch_mode=prefetch_mode, ) self.lock = threading.Lock() + + if not tracer: + logger.error("Invalid tracer object passed to WeightCache.") + self.tracer = tracer + # Track in-flight materializations so compute can wait on prefetch self.loading_futures: Dict[int, Future] = {} self.prefetch_futures: Dict[int, Future] = {} logger.info("WeightCache resident budget: max_weights=%d", self.max_weights) - def get_weight( - self, layer_id: int, *, inc_ref: bool = True - ) -> Optional[Dict[str, mx.array]]: + + def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" # First, fast path under lock for cache hit or in-flight load with self.lock: @@ -72,41 +76,85 @@ def get_weight( ) return data - # If a load is in-flight, wait on it outside the lock - inflight = self.loading_futures.get(layer_id) - if inflight is None: - # Prepare eviction decision now to avoid overfilling once loaded - need_evict = len(self.cache) >= self.max_weights - if need_evict: - # Evict under lock, then proceed to load - self._evict_lru() - # Install a new future marker so others wait - fut = Future() - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False - - if creator: - # Perform the blocking load without holding the cache lock - try: - t0 = time.perf_counter() - data = self.layer_manager.load_layer_to_gpu(layer_id) - dt_ms = (time.perf_counter() - t0) * 1000.0 - # Estimate bytes by summing tensor sizes for the layer + with self.tracer.frame("weights.cache", "search") as f: + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + logger.debug("Cache hit for layer %s, ref=%d inc=%d", + layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) + return data + + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False + + if creator: # Perform the blocking load without holding the cache lock + with self.tracer.frame("weights.cache", "load") as f: try: - winfo = self.layer_manager.weight_info.get(layer_id, {}) - total_bytes = sum(w.size_bytes for w in winfo.values()) - except Exception: - total_bytes = 0 - # Commit to cache under lock - with self.lock: + data = 
self.layer_manager.load_layer_to_gpu(layer_id)
+                        f.event("load")
+
+                        try: # Estimate bytes by summing tensor sizes for the layer
+                            winfo = self.layer_manager.weight_info.get(layer_id, {})
+                            total_bytes = sum(w.size_bytes for w in winfo.values())
+                            f.set("bytes", total_bytes)
+                        except Exception:
+                            total_bytes = 0
+
+                        with self.lock: # Commit to cache under lock
+                            self.cache[layer_id] = (data, time.time())
+                            if inc_ref:
+                                self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1)
+                            else:
+                                self.reference_counts.setdefault(layer_id, 0)
+
+                        try: # Resolve future and clear from tracking
+                            fut = self.loading_futures.pop(layer_id, None)
+                            if fut is not None and not fut.done():
+                                fut.set_result(True)
+                        except Exception:
+                            self.loading_futures.pop(layer_id, None)
+                        return data
+
+                    except Exception as e:
+                        with self.lock: # Signal failure to any waiters
+                            try:
+                                fut = self.loading_futures.pop(layer_id, None)
+                                if fut is not None and not fut.done():
+                                    fut.set_exception(e)
+                            except Exception:
+                                self.loading_futures.pop(layer_id, None)
+                        logger.exception("Failed to load weight %s: %s", layer_id, e)
+                        return None
+
+            else: # Not the creator: wait for the in-flight load to complete
+                with self.tracer.frame("weights.cache", "wait") as f:
+                    try:
+                        inflight.result() # block until the creator completes
+                    except Exception as e:
+                        logger.error("Wait for layer %s load failed: %s", layer_id, e)
+                        return None
+
+                with self.lock: # Return from cache
+                    data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment]
+                    if data is None:
+                        logger.error("Wait for layer %s load failed: data not in cache", layer_id)
+                        return None
+                    self.cache[layer_id] = (data, time.time())
                 if inc_ref:
-                    self.reference_counts[layer_id] = (
-                        self.reference_counts.get(layer_id, 0) + 1
-                    )
+                    self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1)
                 else:
                     self.reference_counts.setdefault(layer_id, 0)
             # Resolve future and clear from tracking
@@ -184,6 +232,7 @@ def prefetch_to_ram(self, layer_id: int):
         except Exception:
             return None
 
+
     def cancel_all_prefetch(self):
         """Cancel any in-flight prefetch tasks and clear tracking."""
         with self.lock:
@@ -195,6 +244,7 @@ def cancel_all_prefetch(self):
                 pass
             self.prefetch_futures.clear()
 
+
     def _evict_lru(self):
         """Evict least recently used weight with zero references"""
         candidates = [
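
The get_weight() rework above is a single-flight cache: the first caller for a layer installs a Future under the lock, performs the blocking load outside it, and resolves the Future, while concurrent callers for the same layer wait on that Future instead of loading the weights twice. A minimal standalone sketch of the same pattern (SingleFlight, load, and key are illustrative names, not from this tree):

    import threading
    from concurrent.futures import Future
    from typing import Any, Callable, Dict

    class SingleFlight:
        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._inflight: Dict[int, Future] = {}
            self._cache: Dict[int, Any] = {}

        def get(self, key: int, load: Callable[[int], Any]) -> Any:
            with self._lock:
                if key in self._cache:
                    return self._cache[key]      # fast path: already resident
                fut = self._inflight.get(key)
                creator = fut is None
                if creator:
                    fut = Future()
                    self._inflight[key] = fut    # later callers wait on this
            if not creator:
                return fut.result()              # block until the creator finishes
            try:
                value = load(key)                # blocking load, lock released
            except Exception as exc:
                with self._lock:
                    self._inflight.pop(key, None)
                fut.set_exception(exc)           # propagate to any waiters
                raise
            with self._lock:
                self._cache[key] = value
                self._inflight.pop(key, None)
            fut.set_result(value)
            return value

The property that matters, mirrored in the patch, is that the load itself runs outside the lock, so a slow disk read never serializes unrelated cache hits.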
_RuntimeStats +# Use a RunAggregator to get raw frames per request, then transform into ReqStats # Track a single request, use multiple for a full benchmark @dataclass -class _RuntimeStats: +class ReqStats: model: str # Model name tokenizer: str # Tokenizer name run_id: str # ID of session (for later mapping) - nonce: List[str] # List of serviced requests + nonce: str # List of serviced requests ttft: float # Time to first token itl: List[float] # Inter-token latency per round prompt_tokens: int # Number of prompt tokens per request (req_id: #) generated_tokens: int # Number of generated tokens per request (req_id: #) - latencies: Dict[List[str, str, str], int] # Map of inter-node latencies: [node0, node1, p50]: 0.0 + latencies: List[List[str, str, str, int]] # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] # Map of {layer: 0.0} latency_per_shard: Dict[str, float] # Map of {shard: 0.0} total_latency: int # Total runtime of requests @@ -233,47 +233,47 @@ class StatsAggregator: def __init__(self) -> None: self._lock = threading.Lock() - self._max_inflight_rq = 20 # per node FIXME: modify from repl + self._max_inflight_req = 20 # per node FIXME: modify from repl # Frames are kept in here while in-flight, then remove the frame objects and append to _stats - self._workers: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id + self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id self._nonces = [] # Tracked nonces (either in-flight or done) self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill - self._running_stats: Dict[str, _RuntimeStats] = {} # Unfinished stat frames - self._stats: Dict[str, _RuntimeStats] = {} # Finished stat frames - self._open_frames: Dict[str, Dict[str, Any]] # We got 'B' event but not 'E' (per nonce) + self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames + self._stats: Dict[str, ReqStats] = {} # Finished stat frames + self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) self._model_per_run: Dict[str, str] = {} # Track model per run_id # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: - run_id = data.run_id - node_id = data.node_id - events = data.events or [] - name = data.name + run_id = data["run_id"] + node_id = data["node_id"] + events = data["events"] or [] + name = data["name"] if not run_id or not node_id: return # Drop the batch with self._lock: # Ensure we register workers and nodes for i, ev in enumerate(events): - if "nonce" not in ev.attrs: ev.attrs["nonce"] = f"N_{i}" - nonce = ev.attrs["nonce"] + if "nonce" not in ev["attrs"]: ev["attrs"]["nonce"] = f"N_{i}" + nonce = ev["attrs"]["nonce"] new_frames.append(ev) - if node_id not in self._workers: - self._workers[node_id] = {} + if node_id not in self._frames: + self._frames[node_id] = {} - if nonce not in self._workers[node_id]: - self._workers[node_id][nonce] = {} + if nonce not in self._frames[node_id]: + self._frames[node_id][nonce] = {} - if len(self._workers[node_id]) >= self._max_resident_req: # remove oldest entry - del self._workers[self._nonces[0]] + if len(self._frames[node_id]) >= self._max_resident_req: # remove oldest entry + del self._frames[self._nonces[0]] del self._nonces[0] - self._nonces.push(nonce) + self._nonces.append(nonce) # Update in-flight events or register new ones for e in events: @@ -282,8 +282,8 
@@ def add(self, data: Dict[str, Any]) -> None: if not node_id and nonce: return # Drop invalid frames - if e.name == "embedding": # Register new request - rt_stat = self._running_stats.setdefault(run_id, _RuntimeStats( + if e["name"] == "embedding": # Register new request + rt_stat = self._running_stats.setdefault(nonce, ReqStats( model="", tokenizer="", run_id=run_id, @@ -304,34 +304,31 @@ def add(self, data: Dict[str, Any]) -> None: # so we need to handle the creation of this better stats = self._running_stats[nonce] - if e.name == "network.rx": + if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - _cost = lambda e: e.attrs["inflight"] + e.attrs["inwait"] + e.attrs["ms"] - self._handle_frame(e, stats, _cost) + _cost = lambda e: e["attrs"]["inflight"] + e["attrs"]["inwait"] + e["attrs"]["ms"] + self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata - if e.name == "compute.forward": - _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # compute queue + execution - self._handle_frame(e, stats, _cost) + if e["name"] == "compute.forward": + _cost = lambda e: e["attrs"]["inwait"] + e["attrs"]["ms"] # compute queue + execution + self._handle_frame(e, nonce, stats, _cost) - if e.name == "network.tx.send": - _cost = lambda e: e.attrs["inwait"] + e.attrs["ms"] # tx queue + sendoff - self._handle_frame(e, stats, _cost) - - if e.name = "lm_head" and not self._nonce.round_finish[nonce]: # Finish request + # Finish request + if "lm_head" in e.attrs and not self._nonce_round_finish[nonce]: self._nonce_round_finish[nonce] = True - - # TODO: Remove frame and stsats from working and append st_obj = self._running_stats[nonce] + self._stats[nonce] = st_obj del self._running_stats[nonce] - self._stats.append(st_obj) - - acc_ttt = 0 # accumulated time to token - acc_ttt += shard["network.ingress"][-1] - inflight = shard['network.ingress'][] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API + + if e["name"] == "network.tx.send": + _cost = lambda e: e["attrs"]["inwait"] + e["attrs"]["ms"] # tx queue + sendoff + self._handle_frame(e, nonce, stats, _cost) # Handle cost aggregation of frames - def _handle_frame(e: Any, stats: _RuntimeStats, _cost_fnc: Any): + def _handle_frame(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): if e.type == 'B': self._open_frames[nonce][e.name] = e return @@ -360,6 +357,6 @@ def stats( elif model: pass - else: # Return stats of all counters + else: # Sort per model, per request (node info only when requested) pass From 1e16b46857e42ef7ce04f10a46ee0c406029687d Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 23 Oct 2025 21:48:57 -0700 Subject: [PATCH 128/172] various small stuff --- src/dnet/perf/trace.py | 6 +++++- src/dnet/ring/__init__.py | 2 ++ src/dnet/ring/api/api_logging.py | 13 +++---------- src/dnet/ring/api/node.py | 5 +++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 7b5d79c1..65accca1 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -224,8 +224,12 @@ def frame(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): return _NoopFrame() return _Frame(self, f"{scope}.{name}", attrs) + # Same as normal frame but signals that this trace is a canonical event (required for runtime stats) + def canonical(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = None): + return self.frame(scope, name, attrs) + # Mark an event outside of a frame - def mark(self, name: str, **attrs: 
Any) -> None: + def mark(self, name: str, attrs: Any) -> None: self._emit({"type": "I", "name": name, "args": attrs}) # Helpers diff --git a/src/dnet/ring/__init__.py b/src/dnet/ring/__init__.py index e69de29b..adb40d20 100644 --- a/src/dnet/ring/__init__.py +++ b/src/dnet/ring/__init__.py @@ -0,0 +1,2 @@ + +from .common import LayerAssignment, TopologyInfo diff --git a/src/dnet/ring/api/api_logging.py b/src/dnet/ring/api/api_logging.py index d90c526c..48999dac 100644 --- a/src/dnet/ring/api/api_logging.py +++ b/src/dnet/ring/api/api_logging.py @@ -5,17 +5,9 @@ from logging.handlers import RotatingFileHandler from pathlib import Path - _CONFIGURED_FLAG = "_dnet_api_logger_configured" - def get_api_logger() -> logging.Logger: - """Return a process‑local logger for the API server. - - - Does not propagate to the root logger (so it won't spam the REPL TTY). - - Writes to logs/api.log with rotation. - - Level is controlled by DNET_API_LOG (default: INFO). - """ log = logging.getLogger("dnet.api") if getattr(log, _CONFIGURED_FLAG, False): return log @@ -23,7 +15,8 @@ def get_api_logger() -> logging.Logger: # Level from env, fallback INFO level_name = (os.getenv("DNET_API_LOG", "INFO") or "INFO").strip().upper() level = getattr(logging, level_name, logging.INFO) - log.setLevel(level) + #log.setLevel(level) + log.setLevel(logging.DEBUG) # Do not bubble to root (console) log.propagate = False @@ -36,7 +29,7 @@ def get_api_logger() -> logging.Logger: # Attach a rotating file handler try: - fh = RotatingFileHandler("logs/api.log", maxBytes=10_000_000, backupCount=5) + fh = RotatingFileHandler("logs/api.log", maxBytes=10000000, backupCount=5) fmt = logging.Formatter( "%(asctime)s %(levelname)s [%(threadName)s] %(name)s: %(message)s" ) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 18cf4059..7978e7d0 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -425,6 +425,7 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: return TraceIngestResponse(ok=False, accepted=0, message=str(e)) async def _forward_trace_config(self, cfg: Any) -> bool: + logger.debug("Forwarding Trace config") shards = self._get_shards_from_discovery() this = self.discovery.get_own_properties() api_endpoint = f"http://{this.local_ip}:{this.server_port}/trace/ingest" @@ -450,9 +451,9 @@ async def _forward_trace_config(self, cfg: Any) -> bool: try: res = await client.post(url, json=dict(payload)) if res.status_code != 200: - logger.warning(f"Failed to POST tracer config to node {name}.") + logger.error(f"Failed to POST tracer config to {url}.: {res.text}") except Exception as e: - logger.warning(f"Failed to POST tracer config: {e}") + logger.error(f"Failed to POST tracer config: {e}") return False return True From ffa89a2a32f871844f06f89b393fa1c43dab754f Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 00:04:41 -0700 Subject: [PATCH 129/172] add tracer to api and send frames back to repl. 
emit special frames when api starts and ends a chat request --- src/dnet/perf/trace.py | 12 +++---- src/dnet/perf/utils/aggregators.py | 51 ++++++++++++++++-------------- src/dnet/ring/api/node.py | 23 ++++++++++++++ 3 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 65accca1..48f6e49b 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -229,7 +229,7 @@ def canonical(self, scope: str, name: str, attrs: Optional[Dict[str, Any]] = Non return self.frame(scope, name, attrs) # Mark an event outside of a frame - def mark(self, name: str, attrs: Any) -> None: + def mark(self, name: str, attrs: Any = {}) -> None: self._emit({"type": "I", "name": name, "args": attrs}) # Helpers @@ -250,7 +250,7 @@ def profile_block(self, outfile: Optional[str] = None, sort: str = "cumtime", li with open(outfile, "w", encoding="utf-8") as f: f.write(out) else: - self._emit({"type": "PROFILE", "name": "cprofile", "args": {"sort": sort, "limit": limit, "report": out}}) + self._emit({"type": "PROFILE", "name": "cprofile", "attrs": {"sort": sort, "limit": limit, "report": out}}) @contextmanager def callgraph( @@ -284,13 +284,13 @@ def prof(frame, event, arg): key = f"{filename}:{code.co_firstlineno}:{name}" if event == "call": stack.append((key, time.perf_counter())) - self._emit({"type": "B", "name": f"py.{name}", "args": {"file": filename, "line": code.co_firstlineno}}) + self._emit({"type": "B", "name": f"py.{name}", "attrs": {"file": filename, "line": code.co_firstlineno}}) emitted += 1 else: if stack and stack[-1][0] == key: _, t0 = stack.pop() dt_ms = (time.perf_counter() - t0) * 1000.0 - self._emit({"type": "E", "name": f"py.{name}", "args": {"ms": round(dt_ms, 3)}}) + self._emit({"type": "E", "name": f"py.{name}", "attrs": {"ms": round(dt_ms, 3)}}) emitted += 1 elif inc_c and event in ("c_call", "c_return"): func = getattr(arg, "__name__", None) @@ -298,10 +298,10 @@ def prof(frame, event, arg): if not func: return if event == "c_call": - self._emit({"type": "B", "name": f"c.{mod}.{func}", "args": {}}) + self._emit({"type": "B", "name": f"c.{mod}.{func}", "attrs": {}}) emitted += 1 else: - self._emit({"type": "E", "name": f"c.{mod}.{func}", "args": {}}) + self._emit({"type": "E", "name": f"c.{mod}.{func}", "attrs": {}}) emitted += 1 prev = sys.getprofile() diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 78d5bbd6..fad29a7d 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -238,30 +238,27 @@ def __init__(self) -> None: # Frames are kept in here while in-flight, then remove the frame objects and append to _stats self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id - self._nonces = [] # Tracked nonces (either in-flight or done) + self._nonces: List[str] = [] # Tracked nonces (either in-flight or done) self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill - self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames - self._stats: Dict[str, ReqStats] = {} # Finished stat frames + self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames + self._stats: Dict[str, ReqStats] = {} # Finished stat frames self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) self._model_per_run: Dict[str, str] = {} # Track model per run_id # Ingest raw data from tracer def add(self, 
data: Dict[str, Any]) -> None: - run_id = data["run_id"] - node_id = data["node_id"] + run_id = data["run_id"] or "NONE" + node_id = data["node_id"] or "NONE" events = data["events"] or [] - name = data["name"] if not run_id or not node_id: return # Drop the batch with self._lock: # Ensure we register workers and nodes for i, ev in enumerate(events): - if "nonce" not in ev["attrs"]: ev["attrs"]["nonce"] = f"N_{i}" - nonce = ev["attrs"]["nonce"] - - new_frames.append(ev) + if "nonce" not in ev["args"]: ev["args"]["nonce"] = f"N_" + nonce = ev["args"]["nonce"] if node_id not in self._frames: self._frames[node_id] = {} @@ -269,21 +266,24 @@ def add(self, data: Dict[str, Any]) -> None: if nonce not in self._frames[node_id]: self._frames[node_id][nonce] = {} - if len(self._frames[node_id]) >= self._max_resident_req: # remove oldest entry + if len(self._frames[node_id]) >= self._max_inflight_req: # remove oldest entry del self._frames[self._nonces[0]] del self._nonces[0] - - self._nonces.append(nonce) + if nonce not in self._nonces: + self._nonces.append(nonce) # Update in-flight events or register new ones for e in events: - nonce = e.attrs["nonce"] + nonce = e["args"]["nonce"] assert nonce is not None, "" - if not node_id and nonce: return # Drop invalid frames + if not node_id or not nonce: return # Drop invalid frames - if e["name"] == "embedding": # Register new request - rt_stat = self._running_stats.setdefault(nonce, ReqStats( + if e["name"] == "chat.request.end": + print(e) + if e["name"] == "chat.request.start": + print(e) + self._running_stats[nonce] = ReqStats( model="", tokenizer="", run_id=run_id, @@ -291,31 +291,36 @@ def add(self, data: Dict[str, Any]) -> None: ttft=0.0, itl=[0.0], generated_tokens=0, - prompt_tokens=e.attrs["prompt_tokens"], + prompt_tokens=e["args"]["prompt_tokens"], latencies={}, latency_per_layer={}, latency_per_shard={}, total_latency=0.0, assignment=None, topo=None, - )) + ) + if e["name"] == "embedding": # Register new request + pass # FIXME: We might receive other frames then "embed" from shards # so we need to handle the creation of this better + if nonce not in self._running_stats: + continue + stats = self._running_stats[nonce] if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - _cost = lambda e: e["attrs"]["inflight"] + e["attrs"]["inwait"] + e["attrs"]["ms"] + _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata if e["name"] == "compute.forward": - _cost = lambda e: e["attrs"]["inwait"] + e.attrs["ms"] # compute queue + execution + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution self._handle_frame(e, nonce, stats, _cost) # Finish request - if "lm_head" in e.attrs and not self._nonce_round_finish[nonce]: + if "lm_head" in e["args"] and not self._nonce_round_finish[nonce]: self._nonce_round_finish[nonce] = True st_obj = self._running_stats[nonce] self._stats[nonce] = st_obj @@ -324,7 +329,7 @@ def add(self, data: Dict[str, Any]) -> None: # TODO: Handle latency of transfer back to API if e["name"] == "network.tx.send": - _cost = lambda e: e["attrs"]["inwait"] + e["attrs"]["ms"] # tx queue + sendoff + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff self._handle_frame(e, nonce, stats, _cost) # Handle cost aggregation of frames diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 7978e7d0..ff37d34b 100644 --- a/src/dnet/ring/api/node.py +++ 
b/src/dnet/ring/api/node.py @@ -89,6 +89,7 @@ from .servicer import ShardApiServicer from ..common import TopologyInfo, LayerAssignment +from dnet.perf import Tracer, TraceConfig async def arange(count: int): """Async range generator.""" @@ -151,12 +152,27 @@ def __init__( except Exception: pass + cfg = TraceConfig( + file="./trace.json", + streaming=False, + include_prefixes = ("src/dnet/"), + include_c_calls = False, + budget = 10000, + enabled = True, + record_pid_tid = True, + aggregate=False, + aggregate_url=None, + ) + self.tracer = Tracer(cfg) + self.tracer.start() + logger.info( "API node initialized on HTTP port %s, gRPC port %s", self.http_port, self.grpc_port, ) + async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: """Start the API node. @@ -404,6 +420,11 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: if self._trace_ingest_cb is not None: logger.debug("Forwarding trace batch to REPL.") self._trace_ingest_cb(batch.model_dump()) + + _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) } + #self.tracer._events.clear() + self._trace_ingest_cb(_t_batch) # FIXME: Move this + return TraceIngestResponse(ok=True, accepted=len(batch.events)) try: @@ -1315,6 +1336,7 @@ async def _handle_chat_completion(self, req: ChatRequestModel) -> ChatResponseMo Returns: Chat response """ + self.tracer.mark("chat.request.start") stop_id_sequences: List[List[int]] = [ self.tokenizer.encode(stop_word, add_special_tokens=False) # type: ignore for stop_word in req.stop # type: ignore @@ -1445,6 +1467,7 @@ async def _handle_completion( ) # Build optional metrics + self.tracer.mark("chat.request.end") metrics = None if profile_enabled: t_end = time.perf_counter() From 6fb77406889040b023060f502b8dae535c7e46fa Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 00:57:28 -0700 Subject: [PATCH 130/172] track prompt tokens --- src/dnet/ring/api/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index ff37d34b..cd74defe 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -422,7 +422,7 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: self._trace_ingest_cb(batch.model_dump()) _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) } - #self.tracer._events.clear() + self.tracer._events.clear() self._trace_ingest_cb(_t_batch) # FIXME: Move this return TraceIngestResponse(ok=True, accepted=len(batch.events)) @@ -1336,7 +1336,7 @@ async def _handle_chat_completion(self, req: ChatRequestModel) -> ChatResponseMo Returns: Chat response """ - self.tracer.mark("chat.request.start") + self.tracer.mark("chat.request.start", {"prompt_tokens": len(req.messages[0].content)}) stop_id_sequences: List[List[int]] = [ self.tokenizer.encode(stop_word, add_special_tokens=False) # type: ignore for stop_word in req.stop # type: ignore From b4580396f0026d929a9269b7fd7b0ea87c208175 Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 01:25:49 -0700 Subject: [PATCH 131/172] move trace frame and add correct nonce and other metadata --- src/dnet/ring/api/node.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index cd74defe..92384336 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -1336,7 +1336,6 @@ async def _handle_chat_completion(self, req: 
ChatRequestModel) -> ChatResponseMo Returns: Chat response """ - self.tracer.mark("chat.request.start", {"prompt_tokens": len(req.messages[0].content)}) stop_id_sequences: List[List[int]] = [ self.tokenizer.encode(stop_word, add_special_tokens=False) # type: ignore for stop_word in req.stop # type: ignore @@ -1410,6 +1409,13 @@ async def _handle_completion( t_start = time.perf_counter() t_first_token = None nonce = f"chatcmpl-{uuid.uuid4()}" + + self.tracer.mark("chat.request.start", { + "tokenizer": None, + "prompt_tokens": prompt.size, + "nonce": nonce, + }) + detokenizer = self.tokenizer.detokenizer # type: ignore detokenizer.reset() tokens: List[int] = [] @@ -1466,8 +1472,12 @@ async def _handle_completion( else detokenizer.text[: -len(stop_sequence_suffix)] ) + self.tracer.mark("chat.request.end", { + "generated_tokens": len(tokens), + "nonce": nonce, + }) + # Build optional metrics - self.tracer.mark("chat.request.end") metrics = None if profile_enabled: t_end = time.perf_counter() From 4600a5f7c1a66ddb9bc6eb2e3f7ae13cfea3054f Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 03:26:35 -0700 Subject: [PATCH 132/172] track nonce on all frames --- src/dnet/ring/shard/node.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 3ed08470..d231c907 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -615,10 +615,12 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id) return - with self.tracer.frame("network.rx", "connect_next_node"): + with self.tracer.frame("network.rx", "connect_next_node") as f: + f.set("nonce", request.nonce) await self._connect_next_node() with self.tracer.frame("network.rx", "process_activation") as f: + f.set("nonce", request.nonce) try: activation = request.activation target_layer = activation.layer_id + 1 @@ -692,6 +694,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return with self.tracer.frame("network.rx", "alloc.buffer") as fr: + fr.set("nonce", request.nonce) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, @@ -715,6 +718,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": with self.tracer.frame("network.rx", "token_stream") as fr: + fr.set("nonce", request.nonce) try: tokens = np.frombuffer(request.activation.data, dtype=np.int32) shp = (int(len(tokens)), ) @@ -742,6 +746,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: with self.tracer.frame("network.ex", "default") as fr: + fr.set("nonce", request.nonce) # Safety: byte length must match shape*dtype try: expected = ( @@ -852,7 +857,8 @@ async def _ingress_worker(self): if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) - with self.tracer.frame("grpc.ingress", "prepare"): + with self.tracer.frame("grpc.ingress", "prepare") as fr: + fr.set("nonce", req.nonce) loop = asyncio.get_running_loop() try: activation_msg = await loop.run_in_executor( @@ -885,6 +891,7 @@ async def _ingress_worker(self): # Enqueue for compute with self.tracer.frame("network.rx", "enque") as fr: + fr.set("nonce", req.nonce) while self.running: try: 
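# NOTE: non-blocking enqueue with retry: put_nowait() raises a queue-full error when the compute queue is saturated, and the except branch (elided here) presumably retries on the next `while self.running` pass instead of blocking the event loop with a synchronous put().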
self.activation_recv_queue.put_nowait(activation_msg) @@ -974,6 +981,7 @@ def _prepare_activation_message_blocking( if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("nonce", request.nonce) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1010,6 +1018,7 @@ def _prepare_activation_message_blocking( elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("nonce", request.nonce) try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1039,6 +1048,7 @@ def _prepare_activation_message_blocking( else: # Dense path: validate size and copy raw bytes view into pool buffer with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("nonce", request.nonce) try: expected = ( int(np.prod(activation.shape)) @@ -1109,6 +1119,7 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("nonce", activation_msg.nonce) f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) if (self.model_metadata.num_layers - 1) in self.assigned_layers: f.set("lm_head", True) @@ -1522,7 +1533,7 @@ async def load_model_endpoint( f"api_callback={req.api_callback_address or 'none'}" ) self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name - with self.tracer.frame("memory", "model.load"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("memory", "model.load") as f: # NOTE: Symbol hardcoded for runtime stats result = await self.load_model(req) return result @@ -1540,7 +1551,7 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: """Unload current model.""" try: logger.info("HTTP /unload_model") - with self.tracer.frame("memory", "model.unload"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("memory", "model.unload") as f: # NOTE: Symbol hardcoded for runtime stats result = await self.unload_model() return result @@ -1555,7 +1566,9 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: # FIXME: add pydantic type here async def warm(request: Request) -> JSONResponse: try: - with self.tracer.frame("memory", "model.warm"): # NOTE: Symbol hardcoded for runtime stats + # FIXME: Append warmup config? 
Or something to distinguish + with self.tracer.frame("memory", "model.warm") as f: # NOTE: Symbol hardcoded for runtime stats body = await request.json() + f.set("nonce", str(body.get("nonce", ""))) start = int(body.get("start", -1)) window = int(body.get("window", self.window_size)) @@ -1620,7 +1633,7 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: Returns: Device profile information as a plain dict """ - with self.tracer.frame("startup", "profile.device"): # NOTE: Symbol hardcoded for runtime stats + with self.tracer.frame("startup", "profile.device") as f: # NOTE: Symbol hardcoded for runtime stats profile_dict = profile_device_via_subprocess( repo_id, max_batch_exp=max_batch_exp, debug=0 ) From d941191aa5eb2e970d3377913edfade0d7e7a39c Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 14:02:50 -0700 Subject: [PATCH 133/172] better track in-wait time and default to 0 for single shard --- src/dnet/protos/dnet_ring.proto | 7 ++++--- src/dnet/protos/shard_api_comm.proto | 3 ++- src/dnet/ring/data_types.py | 1 + src/dnet/ring/shard/comms.py | 5 ++++- src/dnet/ring/shard/node.py | 10 ++++++++-- src/repl.py | 19 ++++++++++++++----- 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index 5452a559..8009601f 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -33,9 +33,10 @@ message ActivationRequest { Activation activation = 2; int64 timestamp = 3; float rx_enq_t = 4; - float rx_inflight_t = 5; - string node_origin = 6; - string callback_url = 7; + float tx_enq_prev_t = 5; + float rx_inflight_t = 6; + string node_origin = 7; + string callback_url = 8; } // Response message for activation sending diff --git a/src/dnet/protos/shard_api_comm.proto b/src/dnet/protos/shard_api_comm.proto index ec80cf18..6540f65f 100644 --- a/src/dnet/protos/shard_api_comm.proto +++ b/src/dnet/protos/shard_api_comm.proto @@ -35,6 +35,7 @@ message TokenRequest { string nonce = 1; int32 token_id = 2; int64 timestamp = 3; + float tx_enq_prev_t = 4; } // Response for token reception @@ -49,4 +50,4 @@ message RingError { string failed_node = 2; string error_code = 3; string error = 4; -} \ No newline at end of file +} diff --git a/src/dnet/ring/data_types.py b/src/dnet/ring/data_types.py index db9438e5..ee6f81e6 100644 --- a/src/dnet/ring/data_types.py +++ b/src/dnet/ring/data_types.py @@ -27,6 +27,7 @@ class ActivationMessage: # TX queue enqueue time (perf_counter seconds) rx_enq_perf_t: float = 0.0 tx_enq_perf_t: float = 0.0 + tx_enq_prev_t: float = 0.0 rx_ingress_t: float = 0.0 rx_inflight_t: float = 0.0 ex_enq_t: float = 0.0 diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index b36f8250..a8eff0ac 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -274,6 +274,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): nonce=activation_msg.nonce, token_id=int(getattr(activation_msg, "token_id", -1)), timestamp=utc_epoch_now(), + tx_enq_prev_t=time.perf_counter(), ) resp = await self.api_stub.SendToken(req) # type: ignore if not resp.success: @@ -355,7 +356,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): if self.next_node_stub: - with self.tracer.frame("grpc", "send_activation.next") as f: + with self.tracer.frame("network", "send_activation.next") as f: request = activation_msg.to_proto(data) request.timestamp 
= utc_epoch_now() if self._mode == "offload" and self.window_size > 0: @@ -416,6 +417,8 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) ctx.disabled = True + request.tx_enq_prev_t = time.perf_counter() + # Prefer streaming if enabled/available; fallback to unary stream_used = False ctx = await self._ensure_stream(activation_msg.nonce) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index d231c907..d498dd89 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -803,7 +803,8 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: try: rx_t = time.perf_counter() request.rx_enq_t = rx_t - request.rx_inflight_t = rx_t - request.timestamp + request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request.tx_enq_prev_t + logger.debug(f"rx_t {rx_t} --- tx_enq {request.tx_enq_prev_t}") self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") @@ -1120,9 +1121,14 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats f.set("nonce", activation_msg.nonce) - f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + if activation_msg.ex_enq_t == 0.0: # FIXME float comparison + f.set("inwait", 0.0) + else: + f.set("inwait", time.perf_counter() - activation_msg.ex_enq_t) + if (self.model_metadata.num_layers - 1) in self.assigned_layers: f.set("lm_head", True) + self._process_activation(activation_msg) except Empty: diff --git a/src/repl.py b/src/repl.py index f9cb363b..dc629123 100644 --- a/src/repl.py +++ b/src/repl.py @@ -536,6 +536,9 @@ def do_trace(self, cmd): self._trace_cfg.enabled = False dprint("Tracing is now OFF\n") + case s if s in ["status"]: + dprint(f"Captured {len(self._trace_agg._req)} frames.\n") + case s if s == "focus": dprint("Subsystems not yet implemented.\n") @@ -572,7 +575,10 @@ def do_perf(self, cmd): return match cmd[1]: - case s if s in "...": + case s if s == "stats": + print(f"{self._stats_agg._nonces}") + print(f"{self._stats_agg._running_stats}") + print(f"{self._stats_agg._stats}") pass case _: pass @@ -580,10 +586,13 @@ def do_perf(self, cmd): # Trace callback registered with API Thread # This forwards the tracer frames back to the REPL for printing def __trace_cb(self, data): - if self._tracing.is_set(): - self._trace_agg.enqueue(data) - if self._stats.is_set(): - self._stats_agg.add(data) + try: + if self._tracing.is_set(): + self._trace_agg.enqueue(data) + if self._stats.is_set(): + self._stats_agg.add(data) + except Exception as e: + print(f"Unable to ingest trace buffer into REPL: {e}") def __print_tr(self, row): sym = " " + symbol.ljust(40, ' ') From 3ac46e6d41bf6eb0dca9506a91eca979c042a72e Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 15:55:13 -0700 Subject: [PATCH 134/172] started filtering counters and printing --- src/dnet/perf/utils/aggregators.py | 145 ++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 36 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index fad29a7d..6a0653d0 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -1,6 +1,7 @@ from __future__ import annotations +import sys import threading from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Optional, DefaultDict @@ -213,6 +214,7 @@ class ReqStats: itl: List[float] # Inter-token latency per round 
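# NOTE (inferred from _handle_round below): while a request is in flight, itl[-1] temporarily holds the previous round's timestamp; it is rewritten to a duration once the next compute.forward frame closes the round.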
prompt_tokens: int # Number of prompt tokens per request (req_id: #) generated_tokens: int # Number of generated tokens per request (req_id: #) + total_tokens: int # Total number of tokens processed latencies: List[List[str, str, str, int]] # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] # Map of {layer: 0.0} @@ -248,10 +250,13 @@ def __init__(self) -> None: # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: - run_id = data["run_id"] or "NONE" - node_id = data["node_id"] or "NONE" + run_id = data["run_id"] + node_id = data["node_id"] events = data["events"] or [] - if not run_id or not node_id: return # Drop the batch + + if not run_id or not node_id: + print("Dropped batch") + return # Drop the batch with self._lock: @@ -269,6 +274,7 @@ def add(self, data: Dict[str, Any]) -> None: if len(self._frames[node_id]) >= self._max_inflight_req: # remove oldest entry del self._frames[self._nonces[0]] del self._nonces[0] + if nonce not in self._nonces: self._nonces.append(nonce) @@ -277,82 +283,120 @@ def add(self, data: Dict[str, Any]) -> None: nonce = e["args"]["nonce"] assert nonce is not None, "" + if not node_id or not nonce: return # Drop invalid frames - if e["name"] == "chat.request.end": - print(e) if e["name"] == "chat.request.start": - print(e) + print(e["args"]) + self._open_frames[nonce] = {} + self._nonce_prefill[nonce] = True self._running_stats[nonce] = ReqStats( - model="", - tokenizer="", + model=e["args"]["model"], + tokenizer=e["args"]["tokenizer"], run_id=run_id, nonce=nonce, - ttft=0.0, + ttft=e["args"]["t0"], # set to initial timestamp then compute itl=[0.0], generated_tokens=0, prompt_tokens=e["args"]["prompt_tokens"], + total_tokens=e["args"]["prompt_tokens"], latencies={}, latency_per_layer={}, latency_per_shard={}, total_latency=0.0, assignment=None, topo=None, + layer_assignment_t=None, + throughput=0.0, + startup_t=0.0, ) + + if e["name"] == "embedding": # Register new request pass # FIXME: We might receive other frames then "embed" from shards # so we need to handle the creation of this better - if nonce not in self._running_stats: - continue + if nonce not in self._running_stats: + continue stats = self._running_stats[nonce] - if e["name"] == "network.rx": - # Time in transport, ingress queue and ingress_worker + if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker + print(f"\n{e["name"]}\n{e["args"]["inflight"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata if e["name"] == "compute.forward": + print(f"\n{e["name"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution self._handle_frame(e, nonce, stats, _cost) + self._nonce_round_finish[nonce] = False + + # End a cycle on compute done (inter-node queue wait is computed in next) + if self._nonce_prefill[nonce]: + stats.ttft = e["args"]["t0"] - stats.ttft + else: + stats.itl[-1] = e["args"]["t0"] - stats.itl[-1] + stats.itl.append(e["args"]["t0"]) + + if e["name"] == "chat.request.end": + if self._nonce_round_finish[nonce]: + self._nonce_round_finish[nonce] = True + pass + self._nonce_round_finish[nonce] = True + st_obj = self._running_stats[nonce] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + self._stats[nonce] = st_obj + del 
self._running_stats[nonce] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API - # Finish request - if "lm_head" in e["args"] and not self._nonce_round_finish[nonce]: - self._nonce_round_finish[nonce] = True - st_obj = self._running_stats[nonce] - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - if e["name"] == "network.tx.send": _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff self._handle_frame(e, nonce, stats, _cost) # Handle cost aggregation of frames def _handle_frame(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): - if e.type == 'B': - self._open_frames[nonce][e.name] = e - return - elif e.type == 'E': - n_rt = _cost_fnc(e) # Custom cost function for each farme - if self._nonce_prefill[nonce]: - stats.ttft += n_rt - else: - stats.itl[-1] += n_rt - del self._open_frames[nonce][e.name] + try: + if e["type"] == 'B': + self._open_frames[nonce][e["name"]] = e + return + elif e["type"] == 'E': + n_rt = _cost_fnc(e) # Custom cost function for each farme + if self._nonce_prefill[nonce]: + stats.ttft += n_rt + else: + stats.itl[-1] += n_rt + if e["name"] in self._open_frames[nonce]: + del self._open_frames[nonce][e["name"]] + except Exception as ex: + print(f"{ex}") # Return data for total, per req, worker or model (maybe add per layer too?) def stats( self, - req_id: Optional[str], - worker: Optional[str], - model: Optional[str] + req_id: Optional[str] = None, + worker: Optional[str] = None, + model: Optional[str] = None ): + fields = [ # 0 is native, 1 compound + (0, "prompt_tokens", ""), + (0, "generated_tokens", ""), + (0, "total_tokens", ""), + (0, -1, ""), # special for empty line + (0, "ttft", "ms"), + (1, "tokens_per_second", "ms"), + (1, "inter_token_latency", "ms"), + (0, "throughput", "ms"), + (0, -1, ""), + (1, "workers", "ms"), + (1, "estimated_compute", "GFLOPs") + ] + if req_id: pass @@ -363,5 +407,34 @@ def stats( pass else: # Sort per model, per request (node info only when requested) - pass + if len(self._stats) < 1: + print("No tracked stats in memory. Track a request first.\n") + return + stats = self._stats[list(self._stats.keys())[-1]] + sys.stdout.write(f"\n Performance counters stats for model '{stats.model}':\n\n") + for tag, n, unit in fields: + if tag == 0: # Native counter + if n == -1: + sys.stdout.write("\n") + continue + nr = getattr(stats, n) + if isinstance(nr, int): + nr_str = f"{nr}" + elif isinstance(nr, float): + nr_str = f"{nr:15.5}" + elif isinstance(nr, str): + if len(nr) > 20: + nr_str = nr[:15] + "..." 
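# NOTE: native (tag 0) counters are formatted by type: ints print raw, floats with five significant digits (the `:15.5` spec), and strings longer than 20 chars are truncated so the right-aligned 20-column value field stays readable.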
+ else: + nr_str = nr + elif tag == 1: + match n: + case "tokens_per_second": + pass + case "inter_token_latency": + pass + case _: + pass + sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(4)}\t{n}\n") + sys.stdout.write("\n\n") + return + From 20af919a557000f3998d107bda4ec1460fc3cf0a Mon Sep 17 00:00:00 2001 From: Octavian Date: Fri, 24 Oct 2025 23:20:26 -0700 Subject: [PATCH 135/172] basic counters working, ttft, tps, itl, token_count --- src/dnet/perf/utils/aggregators.py | 238 ++++++++++++++++------------- src/dnet/ring/api/node.py | 8 +- src/repl.py | 11 +- 3 files changed, 143 insertions(+), 114 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 6a0653d0..51b4494d 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -3,13 +3,17 @@ import sys import threading +import statistics from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Optional, DefaultDict from collections import defaultdict, deque -from dnet.utils.logger import logger +#from dnet.utils.logger import logger +from dnet.ring.api.api_logging import get_api_logger from dnet.ring import LayerAssignment, TopologyInfo +logger = get_api_logger() + Key = Tuple[str, Optional[int], Optional[int], str] # (node_id, pid, tid, req_id) @dataclass @@ -206,23 +210,31 @@ def roots(self, run_id: str, req_id: str) -> List[Dict[str, Any]]: # Runtime statistics # Track a single request, use multiple for a full benchmark @dataclass class ReqStats: - model: str # Model name - tokenizer: str # Tokenizer name - run_id: str # ID of session (for later mapping) - nonce: str # List of serviced requests - ttft: float # Time to first token - itl: List[float] # Inter-token latency per round - prompt_tokens: int # Number of prompt tokens per request (req_id: #) - generated_tokens: int # Number of generated tokens per request (req_id: #) - total_tokens: int # Total number of tokens processed - - latencies: List[List[str, str, str, int]] # List of inter-node latencies: [node0, node1, p50, 0.0] - latency_per_layer: Dict[int, float] # Map of {layer: 0.0} - latency_per_shard: Dict[str, float] # Map of {shard: 0.0} - total_latency: int # Total runtime of requests - throughput: float # aaa - startup_t: float # Time to start shard (ms) - layer_assignment_t: float # Time to layer assignment (ms) + model: str = "" # Model name + tokenizer: str = "" # Tokenizer name + run_id: str = "" # ID of session (for later mapping) + nonce: str = "" # Request nonce (chat completion id) + ttft: float = 0.0 # Time to first token + itl: List[float] = None # Inter-token latency per round + prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) + generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) + total_tokens: int = -1 # Total number of tokens processed + + latencies: List[Tuple[str, str, str, float]] = None # List of inter-node latencies: [node0, node1, p50, 0.0] + latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} + latency_per_shard: Dict[str, float] = None # Map of {shard: 0.0} + total_latency: int = -1 # Total runtime of requests + startup_t: float = 0.0 # Time to start shard (ms) + layer_assignment_t: float = 0.0 # Time to layer assignment (ms) + + # Per-worker data + compute_per_worker: Dict[str, float] = None + inwait_per_worker: Dict[str, float] = None + inflight_per_worker: Dict[str, float] = None + + # Network info + tx_bytes_per_node: Dict[str, int] = None # Volume of traffic per node + rx_bytes_per_node: Dict[str, int] = None topo: 
TopologyInfo = None # Topology information for this request (keep here since it might change) assignment: LayerAssignment = None # Map of layer to shard IDs @@ -259,7 +271,6 @@ def add(self, data: Dict[str, Any]) -> None: return # Drop the batch with self._lock: - # Ensure we register workers and nodes for i, ev in enumerate(events): if "nonce" not in ev["args"]: ev["args"]["nonce"] = f"N_" @@ -283,11 +294,9 @@ def add(self, data: Dict[str, Any]) -> None: nonce = e["args"]["nonce"] assert nonce is not None, "" - if not node_id or not nonce: return # Drop invalid frames if e["name"] == "chat.request.start": - print(e["args"]) self._open_frames[nonce] = {} self._nonce_prefill[nonce] = True self._running_stats[nonce] = ReqStats( @@ -295,23 +304,19 @@ def add(self, data: Dict[str, Any]) -> None: tokenizer=e["args"]["tokenizer"], run_id=run_id, nonce=nonce, - ttft=e["args"]["t0"], # set to initial timestamp then compute - itl=[0.0], - generated_tokens=0, + ttft= e["args"]["t0"], + itl=[ e["args"]["t0"], ], prompt_tokens=e["args"]["prompt_tokens"], total_tokens=e["args"]["prompt_tokens"], latencies={}, latency_per_layer={}, latency_per_shard={}, - total_latency=0.0, assignment=None, - topo=None, - layer_assignment_t=None, - throughput=0.0, - startup_t=0.0, + compute_per_worker={}, + inwait_per_worker={}, + inflight_per_worker={}, ) - if e["name"] == "embedding": # Register new request pass @@ -323,55 +328,45 @@ def add(self, data: Dict[str, Any]) -> None: stats = self._running_stats[nonce] if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - print(f"\n{e["name"]}\n{e["args"]["inflight"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] - self._handle_frame(e, nonce, stats, _cost) #TODO: change shard in metadata if e["name"] == "compute.forward": - print(f"\n{e["name"]}\n{e["args"]["inwait"]}\n{e["args"]["ms"]}") - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution - self._handle_frame(e, nonce, stats, _cost) - self._nonce_round_finish[nonce] = False - - # End a cycle on compute done (inter-node queue wait is computed in next) - if self._nonce_prefill[nonce]: - stats.ttft = e["args"]["t0"] - stats.ttft - else: - stats.itl[-1] = e["args"]["t0"] - stats.itl[-1] - stats.itl.append(e["args"]["t0"]) - - if e["name"] == "chat.request.end": - if self._nonce_round_finish[nonce]: - self._nonce_round_finish[nonce] = True - pass - self._nonce_round_finish[nonce] = True - st_obj = self._running_stats[nonce] - st_obj.generated_tokens = e["args"]["generated_tokens"] - st_obj.total_tokens += e["args"]["generated_tokens"] - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - - if e["name"] == "network.tx.send": - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff - self._handle_frame(e, nonce, stats, _cost) + try: + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution + self._handle_round(e, nonce, stats, _cost) + except Exception as e: + print(f"{e}") + + try: + if e["name"] == "chat.request.end": + st_obj = self._running_stats[nonce] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + print("Adding to stats") + self._stats[nonce] = st_obj + del self._running_stats[nonce] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API + + + 
if e["name"] == "network.tx.send": + _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff + + except Exception as e: + print(f"{e}") # Handle cost aggregation of frames - def _handle_frame(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): + def _handle_round(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): try: - if e["type"] == 'B': - self._open_frames[nonce][e["name"]] = e - return - elif e["type"] == 'E': - n_rt = _cost_fnc(e) # Custom cost function for each farme - if self._nonce_prefill[nonce]: - stats.ttft += n_rt - else: - stats.itl[-1] += n_rt - if e["name"] in self._open_frames[nonce]: - del self._open_frames[nonce][e["name"]] + logger.error(f"TTFT: {e["args"]["t0"]} - {stats.ttft}") + if self._nonce_prefill[nonce]: + stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 + self._nonce_prefill[nonce] = False + else: + if e["args"]["t0"] > 0.0: + stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1]) + stats.itl.append(e["args"]["t0"]) except Exception as ex: print(f"{ex}") @@ -383,6 +378,7 @@ def stats( model: Optional[str] = None ): + # FIXME: Allow manual selection of counters (and push to tracer) fields = [ # 0 is native, 1 compound (0, "prompt_tokens", ""), (0, "generated_tokens", ""), @@ -391,50 +387,78 @@ def stats( (0, "ttft", "ms"), (1, "tokens_per_second", "ms"), (1, "inter_token_latency", "ms"), - (0, "throughput", "ms"), (0, -1, ""), - (1, "workers", "ms"), - (1, "estimated_compute", "GFLOPs") + (1, "estimated_compute", "GFLOPs"), + (1, "compute_time_per_worker", "ms"), + (1, "inwait_time_per_worker", "ms"), + (1, "inflight_time_per_worker", "ms"), + (0, -1, ""), + (1, "network_latency", "ms"), ] - if req_id: - pass - - elif worker: - pass - - elif model: - pass + # FIXME: Allow filtering by these + if req_id: pass + elif worker: pass + elif model: pass else: # Sort per model, per request (node info only when requested) if len(self._stats) < 1: print("No tracked stats in memory. Track a request first.\n") return stats = self._stats[list(self._stats.keys())[-1]] - sys.stdout.write(f"\n Performance counters stats for model '{stats.model}':\n\n") - for tag, n, unit in fields: - if tag == 0: # Native counter - if n == -1: - sys.stdout.write("\n") - continue - nr = getattr(stats, n) - if isinstance(nr, int): - nr_str = f"{nr}" - elif isinstance(nr, float): - nr_str = f"{nr:15.5}" - elif isinstance(nr, str): - if len(nr) > 20: - nr_str = nr[:15] + "..." - else: - nr_str = nr - elif tag == 1: - match n: - case "tokens_per_second": - case "inter_token_latency": - case _: - pass - sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(4)}\t{n}\n") - sys.stdout.write("\n\n") + #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n") + sys.stdout.write(f"Performance stats for request '{stats.nonce}':\n\n") + try: + for tag, n, unit in fields: + if tag == 0: # Native counter + if n == -1: + sys.stdout.write("\n") + continue + nr = getattr(stats, n) + if isinstance(nr, int): + nr_str = f"{nr}" + elif isinstance(nr, float): + nr_str = f"{nr:.2f}" + elif isinstance(nr, str): + if len(nr) > 20: + nr_str = nr[:15] + "..." 
+ else: + nr_str = nr + sys.stdout.write(f"{nr_str.rjust(20)} {unit.ljust(5)}\t{n}\n") + + # Compound trackers + elif tag == 1: + match n: + case "tokens_per_second": + tps = [ 1 / rt for rt in stats.itl ] + #median = statistics.median(tps) + mean = sum(tps) / len(tps) + sys.stdout.write(f"{mean:.2f}".rjust(20)+" tok/s".rjust(5)+" \ttokens_per_second") + sys.stdout.write(f"\t# {statistics.median(stats.itl):.3f} s/tok\n") + + case "inter_token_latency": + itl = stats.itl[:-1] # FIXME: last element is super big + median = statistics.median(itl) + p90 = statistics.quantiles(itl, n=100)[89] + p99 = statistics.quantiles(itl, n=100)[98] + sys.stdout.write(f"{median:.4f}".rjust(20) + " s".ljust(5) + "\tmedian_inter_token_latency (ITL)\n") + sys.stdout.write(" "*35 + f"{p90:.3f} (p90), {p99:.3f} (p99)\n") + sys.stdout.write(" "*35 +f"{min(itl):.3f} (min), {max(itl):.3f} (max)\n") + + case "estimated_compute": + sys.stdout.write("UNKNOWN".rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") + + case "compute_time_per_worker": + pass + + case _: + pass + + except Exception as e: + logger.error(f"{e}") + + # Per-node information + sys.stdout.write("\n") return diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 92384336..36f590f7 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -422,8 +422,8 @@ async def trace_ingest(batch: TraceIngestBatch) -> TraceIngestResponse: # type: self._trace_ingest_cb(batch.model_dump()) _t_batch = { "run_id": "NONE", "node_id": "API", "events": list(self.tracer._events) } - self.tracer._events.clear() self._trace_ingest_cb(_t_batch) # FIXME: Move this + self.tracer._events.clear() return TraceIngestResponse(ok=True, accepted=len(batch.events)) @@ -1411,9 +1411,12 @@ async def _handle_completion( nonce = f"chatcmpl-{uuid.uuid4()}" self.tracer.mark("chat.request.start", { - "tokenizer": "", + "tokenizer": "", + "model": req.model, + "temperature": req.temperature, "prompt_tokens": prompt.size, "nonce": nonce, + "t0": time.perf_counter(), }) detokenizer = self.tokenizer.detokenizer # type: ignore @@ -1475,6 +1478,7 @@ async def _handle_completion( self.tracer.mark("chat.request.end", { "generated_tokens": len(tokens), "nonce": nonce, + "t0": time.perf_counter(), }) # Build optional metrics diff --git a/src/repl.py b/src/repl.py index dc629123..f0e9fb6d 100644 --- a/src/repl.py +++ b/src/repl.py @@ -1,4 +1,5 @@ +import io import os import sys import logging @@ -6,6 +7,7 @@ import time import argparse import subprocess +import contextlib from dataclasses import dataclass from typing import Optional, List, Any, Dict @@ -146,7 +148,7 @@ def loop(self): # Main tty loop elif cmd.startswith(("perf", ".perf")): self.do_perf(cmd.split(" ")) continue - elif cmd.startswith(("topo", ".topo")): + elif cmd.startswith(("topo", ".topo", "t ")): self.do_topo(cmd.split(" ")) continue elif cmd.startswith((".model", "model", "m ")): @@ -209,7 +211,7 @@ def do_topo(self, cmd: List[str]) -> None: if cmd[1] == "search": self.print_mdns_nodes() pass - elif cmd[1] == "auto" or cmd[1] == "build": + elif cmd[1] in ("auto", "build", "b"): self.prepare_topo() pass elif cmd[1] == "setup": @@ -391,6 +393,7 @@ def handle_start_worker(self): # ===== Handle API server async def _api_main(self) -> None: # main thread loop + logging.disable(logging.CRITICAL) self._api_loop = asyncio.get_running_loop() self._api_shutdown_e = asyncio.Event() self._node = RingApiNode( @@ -576,9 +579,7 @@ def do_perf(self, cmd): match cmd[1]: case s if s == "stats": - 
print(f"{self._stats_agg._nonces}") - print(f"{self._stats_agg._running_stats}") - print(f"{self._stats_agg._stats}") + self._stats_agg.stats() pass case _: pass From 6674083d35a3d35b9773965321492290ab54cb09 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 00:51:11 -0700 Subject: [PATCH 136/172] track node_id for every frame --- src/dnet/ring/shard/comms.py | 6 +++++- src/dnet/ring/shard/compute.py | 20 +++++++++++++------ src/dnet/ring/shard/node.py | 36 +++++++++++++++++++++++++++++----- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index a8eff0ac..b827bcf0 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -174,8 +174,9 @@ async def _send_worker(self): activation_msg = await self.activation_computed_queue.get() with self.tracer.frame("network", "tx") as f: if activation_msg.tx_enq_perf_t and self._profile: - f.set("inwait", time.perf_counter() - self._rx_enque_t) + f.set("inwait", time.perf_counter() - activation_msg.tx_enq_perf_t) f.set("nonce", activation_msg.nonce) + f.set("node", self._instance_name) q_wait_ms = ( time.perf_counter() - activation_msg.tx_enq_perf_t ) * 1000.0 @@ -232,6 +233,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): logger.debug(f"Sending activation") if activation_msg.is_final: with self.tracer.frame("grpc", "send_activation.final") as f: + f.set("node", self._instance_name) try: if self._mode == "offload" and self.window_size > 0: first_window = self._assigned_sorted[: self.window_size] @@ -269,6 +271,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("reset_api") with self.tracer.frame("grpc", "token_request") as fr: + fr.set("node", self._instance_name) try: req = shard_api_comm_pb2.TokenRequest( nonce=activation_msg.nonce, @@ -357,6 +360,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): if self.next_node_stub: with self.tracer.frame("network", "send_activation.next") as f: + f.set("node", self._instance_name) request = activation_msg.to_proto(data) request.timestamp = utc_epoch_now() diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index d51343a3..50744b55 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -76,11 +76,13 @@ def _process_activation(self, 
activation_msg: ActivationMessage): did_early_swap = False with self.tracer.frame("compute.thread", "weights.prepare") as f: + f.set("node", self._instance_name) # Determine contiguous local window starting at current_layer window_layers: List[int] = [] @@ -217,7 +221,8 @@ def _process_activation(self, activation_msg: ActivationMessage): pass # Execute the window - with self.tracer.frame("compute.thread", "execute"): + with self.tracer.frame("compute.thread", "execute") as f: + f.set("node", self._instance_name) self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) try: # Prevent prefetch touching during encode/compute to minimize UMA pressure @@ -273,7 +278,8 @@ def _process_activation(self, activation_msg: ActivationMessage): #self.weight_cache.decrease_reference(lid) pass - with self.tracer.frame("compute.thread", "execute.evict_and_unload"): + with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: + f.set("node", self._instance_name) try: # Sliding-fit delta swap: maintain a single resident set by evicting # only what's needed to fit the next window into the budget. Prefer @@ -368,7 +374,8 @@ def _process_activation(self, activation_msg: ActivationMessage): continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - with self.tracer.frame("compute.thread", "execute.enqueue_prefetch"): + with self.tracer.frame("compute.thread", "execute.enqueue_prefetch") as f: + f.set("node", self._instance_name) x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) try: self._compute_busy.clear() @@ -382,7 +389,8 @@ def _process_activation(self, activation_msg: ActivationMessage): pass # Create and enqueue output message: either forward activations or finalize on end role - with self.tracer.frame("compute.thread", "grpc.send"): + with self.tracer.frame("compute.thread", "grpc.send") as f: + f.set("node", self._instance_name) nxt = last_layer + 1 if nxt >= self.model_metadata.num_layers: # End of model try: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index d498dd89..1f5d4b94 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -193,6 +193,7 @@ def __init__( # Discovery self.discovery = AsyncDnetP2P("lib/dnet-p2p/lib") + self._instance_name = "" # Background tasks self.background_tasks: List[asyncio.Task] = [] @@ -259,7 +260,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse and (self.model_path != req.model_path or self.assigned_layers != req.layers) ): logger.info("Node %s: Unloading current model to load new configuration", self.node_id) - with self.tracer.frame("memory", "model.unload"): + with self.tracer.frame("memory", "model.unload") as f: + f.set("node", self._instance_name) await self.unload_model() # Load model metadata @@ -367,7 +369,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Initialize weight cache - with self.tracer.frame("memory", "weight_cache.init"): + with self.tracer.frame("memory", "weight_cache.init") as f: + f.set("node", self._instance_name) self.weight_cache = WeightCache( self.assigned_layers, self.model_metadata, @@ -380,7 +383,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Load the model - with self.tracer.frame("memory", "model.load"): + with self.tracer.frame("memory", "model.load") as f: + f.set("node", self._instance_name) self.model = get_ring_model( self.model_metadata.model_type, 
self.model_metadata.model_config, @@ -404,7 +408,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse logger.warning("[QUANT] apply failed: %s", e) self.model.eval() - with self.tracer.frame("memory", "make_cache"): + with self.tracer.frame("memory", "make_cache") as f: + f.set("node", self._instance_name) self.cache = make_cache( self.model, kv_mode=self.config.kv_cache.mode, @@ -444,7 +449,8 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - with self.tracer.frame("network", "connect.next_node"): + with self.tracer.frame("network", "connect.next_node") as f: + f.set("node", self._instance_name) if self.next_node: await self._connect_next_node() else: @@ -617,6 +623,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): with self.tracer.frame("network.rx", "connect_next_node") as f: f.set("nonce", request.nonce) + f.set("node", self._instance_name) await self._connect_next_node() with self.tracer.frame("network.rx", "process_activation") as f: @@ -639,6 +646,8 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): t_alloc = time.perf_counter() if "|" in activation.dtype: with self.tracer.frame("grpc.receive", "decompress") as fr: + fr.set("nonce", request.nonce) + fr.set("node", self._instance_name) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -695,6 +704,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): with self.tracer.frame("network.rx", "alloc.buffer") as fr: fr.set("nonce", request.nonce) + fr.set("node", self._instance_name) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, @@ -719,6 +729,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): if activation.dtype == "tokens": with self.tracer.frame("network.rx", "token_stream") as fr: fr.set("nonce", request.nonce) + fr.set("node", self._instance_name) try: tokens = np.frombuffer(request.activation.data, dtype=np.int32) shp = (int(len(tokens)), ) @@ -746,6 +757,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: with self.tracer.frame("network.ex", "default") as fr: + fr.set("node", self._instance_name) fr.set("nonce", request.nonce) # Safety: byte length must match shape*dtype try: @@ -822,6 +834,7 @@ async def _ingress_worker(self): finally enqueues for compute or forwards to the next shard. 
""" while self.running: + logger.error(f"NODE_ID {self.node_id}") with self.tracer.frame("network.rx", "wait"): # NOTE: bad counter try: req = await self.ingress_q.get() @@ -833,6 +846,7 @@ async def _ingress_worker(self): with self.tracer.frame("network", "rx") as f: f.set("inwait", time.perf_counter() - req.rx_enq_t) f.set("inflight", req.rx_inflight_t) + f.set("node", self._instance_name) f.set("nonce", req.nonce) try: @@ -859,6 +873,7 @@ async def _ingress_worker(self): if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) with self.tracer.frame("grpc.ingress", "prepare") as fr: + fr.set("node", self._instance_name) fr.set("nonce", req.nonce) loop = asyncio.get_running_loop() try: @@ -892,6 +907,7 @@ async def _ingress_worker(self): # Enqueue for compute with self.tracer.frame("network.rx", "enque") as fr: + fr.set("node", self._instance_name) fr.set("nonce", req.nonce) while self.running: try: @@ -982,6 +998,7 @@ def _prepare_activation_message_blocking( if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool with self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("node", self._instance_name) f.set("nonce", request.nonce) try: deq = decompress_tensor_from_protobuf_data( @@ -1019,6 +1036,7 @@ def _prepare_activation_message_blocking( elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("node", self._instance_name) f.set("nonce", request.nonce) try: tokens = np.frombuffer(activation.data, dtype=np.int32) @@ -1049,6 +1067,7 @@ def _prepare_activation_message_blocking( else: # Dense path: validate size and copy raw bytes view into pool buffer with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("node", self._instance_name) f.set("nonce", request.nonce) try: expected = ( @@ -1121,6 +1140,7 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats f.set("nonce", activation_msg.nonce) + f.set("node", self._instance_name) if activation_msg.ex_enq_t == 0.0: # FIXME float comparison f.set("inwait", 0.0) else: @@ -1130,6 +1150,7 @@ def _compute_worker(self) -> None: f.set("lm_head", True) self._process_activation(activation_msg) + f.set("t0", time.perf_counter()) except Empty: continue @@ -1227,6 +1248,7 @@ async def _start_discovery(self) -> None: hostname = gethostname() # TODO: optionally take shard name from CLI instance = f"shard-{token_hex(4)}-{hostname}" + self._instance_name = instance self.discovery.create_instance( instance, self.http_port, @@ -1540,6 +1562,7 @@ async def load_model_endpoint( ) self.tracer.mark("model", {"model": req.model_path, "ts": time.perf_counter()}) # Record model name with self.tracer.frame("memory", "model.load") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) result = await self.load_model(req) return result @@ -1558,6 +1581,7 @@ async def unload_model_endpoint() -> ShardUnloadModelResponse: try: logger.info("HTTP /unload_model") with self.tracer.frame("memory", "model.unload") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) result = await self.unload_model() return result @@ -1574,6 +1598,7 @@ async def warm(request: Request) -> JSONResponse: try: # FIXME: Append warmup config? 
Or something to distinguish with self.tracer.frame("memory", "model.warm") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) f.set("nonce", request.nonce) body = await request.json() start = int(body.get("start", -1)) @@ -1640,6 +1665,7 @@ async def _profile_device(self, repo_id: str, max_batch_exp: int) -> dict: Device profile information as a plain dict """ with self.tracer.frame("startup", "profile.device") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("node", self._instance_name) profile_dict = profile_device_via_subprocess( repo_id, max_batch_exp=max_batch_exp, debug=0 ) From c47a351925e063ed3232b1cbd6757cdb0ad8b22c Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 01:23:35 -0700 Subject: [PATCH 137/172] fix trace annotate --- src/repl.py | 68 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/src/repl.py b/src/repl.py index f0e9fb6d..231cecb1 100644 --- a/src/repl.py +++ b/src/repl.py @@ -393,7 +393,7 @@ def handle_start_worker(self): # ===== Handle API server async def _api_main(self) -> None: # main thread loop - logging.disable(logging.CRITICAL) + #logging.disable(logging.CRITICAL) self._api_loop = asyncio.get_running_loop() self._api_shutdown_e = asyncio.Event() self._node = RingApiNode( @@ -532,10 +532,12 @@ def do_trace(self, cmd): match cmd[1]: case s if s in ["on", "ON"]: + self._tracing.set() self._trace_cfg.enabled = True dprint("Tracing is now ON\n") case s if s in ["off", "OFF"]: + self._tracing.clear() self._trace_cfg.enabled = False dprint("Tracing is now OFF\n") @@ -562,7 +564,11 @@ def do_trace(self, cmd): dprint("Not implemented yet\n") case s if s == "annotate": - self.print_trace_annotate("NONE") + if len(self._trace_agg._req) < 1: + dprint("No trace frames captured. Is tracing enabled?\n") + return + last = list(self._trace_agg._req.keys())[-1] + self.print_trace_annotate(last) case _: dprint("Unknown trace command. Type 'help' for a list of available commands.\n") @@ -608,33 +614,37 @@ def print_trace_annotate( repeats: int = 0, ) -> List[Dict[str, Any]]: - rows = self._trace_agg.annotate(run_id) - headers = ["name", "total","max","mean","p50","p90","p99","samples"] - limits = {"name": 50,} - w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers} - w["name"] = max(w["name"], 35) - - line = " ".join(h.ljust(w[h]) for h in headers); sys.stdout.write("\n") - sys.stdout.write(line + "\n") - sys.stdout.write(" ".join("."*w[h] for h in headers)); sys.stdout.write("\n") - for r in rows: - name = str(r["name"]) - if len(name) > w["name"]: name = name[:w["name"]-1] + "..." 
- vals = {
- "name": r["name"],
- "total": r["total"],
- "max": r["max"],
- "mean": r["mean"],
- "p50": r["p50"],
- "p90": r["p90"],
- "p99": r["p99"],
- "samples": r["samples"],
- }
- sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]]))
- sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:]))
- sys.stdout.write("\n")
- sys.stdout.write("\n\n")
- sys.stdout.flush()
+ try:
+ rows = self._trace_agg.annotate(run_id)
+ logger.debug(f"rows")
+ headers = ["name", "total","max","mean","p50","p90","p99","samples"]
+ limits = {"name": 50,}
+ w = {h: max(len(h), min(limits.get(h, 8), max(len(str(r[h])) for r in rows))) for h in headers}
+ w["name"] = max(w["name"], 35)
+
+ line = " ".join(h.ljust(w[h]) for h in headers); sys.stdout.write("\n")
+ sys.stdout.write(line + "\n")
+ sys.stdout.write(" ".join("."*w[h] for h in headers)); sys.stdout.write("\n")
+ for r in rows:
+ name = str(r["name"])
+ if len(name) > w["name"]: name = name[:w["name"]-1] + "..."
+ vals = {
+ "name": r["name"],
+ "total": r["total"],
+ "max": r["max"],
+ "mean": r["mean"],
+ "p50": r["p50"],
+ "p90": r["p90"],
+ "p99": r["p99"],
+ "samples": r["samples"],
+ }
+ sys.stdout.write(" " + str(vals[headers[0]]).ljust(w[headers[0]]))
+ sys.stdout.write(" ".join(f"{vals[h]:8.2f}".rjust(w[h]) for h in headers[1:]))
+ sys.stdout.write("\n")
+ sys.stdout.write("\n\n")
+ sys.stdout.flush()
+ except Exception as e:
+ logger.error(f"{e}")

 def _print_nodes_table(self, rows: List[Any]) -> None:
 headers = ["name", "role", "addr", "http", "grpc", "status", "head"]

From f5057908f5db42d482f3940c231b9cdf70d965af Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sat, 25 Oct 2025 01:24:23 -0700
Subject: [PATCH 138/172] aggregate frame symbols into sub-groups for focusing
 and compute total time per subsection

---
 src/dnet/perf/utils/aggregators.py | 42 +++++++++++++++++++++++++++++-
 1 file changed, 41 insertions(+), 1 deletion(-)

diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py
index 51b4494d..50d7ddd5 100644
--- a/src/dnet/perf/utils/aggregators.py
+++ b/src/dnet/perf/utils/aggregators.py
@@ -122,6 +122,7 @@ def __init__(self) -> None:
 self._lock = threading.Lock()

 def enqueue(self, batch: Dict[str, Any]) -> None:
+ try:
 run_id = batch.get("run_id")
 node_id = batch.get("node_id")
 events = batch.get("events") or []
@@ -131,6 +132,8 @@
 agg = self._req.setdefault(run_id, RunAggregator())
 for ev in events:
 agg.ingest_event(node_id, ev)
+ except Exception as e:
+ logger.error(f"Trace aggregator enqueue error: {e}")

 def annotate(self, run_id: str, *, mapping: Optional[Dict[str, str]] = None, repeats: int = 0) -> List[Dict[str, Any]]:
 with self._lock:
@@ -260,6 +263,43 @@ def __init__(self) -> None:
 self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce)
 self._model_per_run: Dict[str, str] = {} # Track model per run_id

+ # Maps of frames to higher-level sub-systems
+ self._compute_set = [
+ "compute.forward",
+ "compute.thread.kvcache.init",
+ "compute.thread.weights.prepare",
+ "compute.thread.activations.process",
+ "compute.thread.activations.load",
+ "compute.thread.execute",
+ "compute.thread.execute.enqueue_prefetch",
+ "compute.thread.execute.evict_and_unload",
+ "compute.thread.cleanup",
+ "compute.thread.mdns.send",
+ ]
+
+ self._network_set = [
+ "network.tx",
+ "network.token_request",
+ "network.rx.prepare",
+ "network.rx.prepare_activation.tokens",
+ "network.rx.enque",
+ 
"network.send_activation.final", + "network.rx", + "network.connect.next_node", + "network.rx.prefetch", + ] + + self._memory_set = [ + "memory.model.load", + "memory.model.load_metadata", + "memory.warmup", + "memory.weight_cache.init", + "memory.prefetch", + "memory.memory_pools.init", + "memory.cache.reset", + "memory.make_cache", + ] + # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: run_id = data["run_id"] @@ -446,7 +486,7 @@ def stats( sys.stdout.write(" "*35 +f"{min(itl):.3f} (min), {max(itl):.3f} (max)\n") case "estimated_compute": - sys.stdout.write(f"UNKNOWN":rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") + sys.stdout.write(f"UNKNOWN".rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") case "compute_time_per_worker": pass From f509b55f65b50e6d968b9103337bccbfd7c35900 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 02:57:30 -0700 Subject: [PATCH 139/172] typo in activation timestamp --- src/dnet/ring/shard/node.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 1f5d4b94..589b4c42 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -815,8 +815,7 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: try: rx_t = time.perf_counter() request.rx_enq_t = rx_t - request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request_enq_prev_t - logger.error(f"rx_t {rx_t} --- tx_enq {request.tx_enq_prev_t}") + request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request.tx_enq_prev_t self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") @@ -834,7 +833,6 @@ async def _ingress_worker(self): finally enqueues for compute or forwards to the next shard. 
""" while self.running: - logger.error(f"NODE_ID {self.node_id}") with self.tracer.frame("network.rx", "wait"): # NOTE: bad counter try: req = await self.ingress_q.get() From 47274367dba01584f688bb977cd02b8f0c809d17 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sat, 25 Oct 2025 03:51:27 -0700 Subject: [PATCH 140/172] per-node info and restructured counting --- src/dnet/perf/utils/aggregators.py | 80 ++++++++++++++++++------------ 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 50d7ddd5..b5753589 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -232,8 +232,8 @@ class ReqStats: # Per-worker data compute_per_worker: Dict[str, float] = None - inwait_per_worker: Dict[str, float] = None - inflight_per_worker: Dict[str, float] = None + network_per_worker: Dict[str, float] = None + memory_per_worker: Dict[str, float] = None # Network info tx_bytes_per_node: Dict[str, int] = None # Volume of trafic per node @@ -263,6 +263,8 @@ def __init__(self) -> None: self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) self._model_per_run: Dict[str, str] = {} # Track model per run_id + self.nodes = [] # Keep track of active nodes + # Maps of frames to higher-level sub-systems self._compute_set = [ "compute.forward", @@ -353,13 +355,10 @@ def add(self, data: Dict[str, Any]) -> None: latency_per_shard={}, assignment=None, compute_per_worker={}, - inwait_per_worker={}, - inflight_per_worker={}, + network_per_worker={}, + memory_per_worker={}, ) - if e["name"] == "embedding": # Register new request - pass - # FIXME: We might receive other frames then "embed" from shards # so we need to handle the creation of this better if nonce not in self._running_stats: @@ -367,6 +366,31 @@ def add(self, data: Dict[str, Any]) -> None: stats = self._running_stats[nonce] + if "node" not in e["args"]: + if e["name"] == "chat.request.end": + print(f"{e}") + st_obj = self._running_stats[nonce] + st_obj.generated_tokens = e["args"]["generated_tokens"] + st_obj.total_tokens += e["args"]["generated_tokens"] + print("Adding to stats") + self._stats[nonce] = st_obj + del self._running_stats[nonce] + #del self._frames[node_id][nonce] + # TODO: Handle latency of transfer back to API + + else: + continue # Drop frames without "node" + + node_id = e["args"]["node"] + if node_id not in self.nodes: + self.nodes.append(node_id) + stats.compute_per_worker[node_id] = 0.0 + stats.network_per_worker[node_id] = 0.0 + stats.memory_per_worker[node_id] = 0.0 + + if e["name"] in self._memory_set: + stats.memory_per_worker[node_id] += e["args"]["ms"] + if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] #TODO: change shard in metadata @@ -378,29 +402,20 @@ def add(self, data: Dict[str, Any]) -> None: except Exception as e: print(f"{e}") - try: - if e["name"] == "chat.request.end": - st_obj = self._running_stats[nonce] - st_obj.generated_tokens = e["args"]["generated_tokens"] - st_obj.total_tokens += e["args"]["generated_tokens"] - print("Adding to stats") - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - + if e["name"] in self._compute_set: # Aggregate for compute total + stats.compute_per_worker[node_id] += e["args"]["ms"] - if e["name"] == "network.tx.send": - _cost = 
lambda e: e["args"]["inwait"] + e["args"]["ms"] # tx queue + sendoff + if e["name"] in self._network_set: + stats.network_per_worker[node_id] += e["args"]["ms"] - except Exception as e: - print(f"{e}") + if e["name"] in self._memory_set: + stats.memory_per_worker[node_id] += e["args"]["ms"] # Handle cost aggregation of frames def _handle_round(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any): try: - logger.error(f"TTFT: {e["args"]["t0"]} - {stats.ttft}") if self._nonce_prefill[nonce]: + logger.error(f"TTFT: {stats.ttft}") stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 self._nonce_prefill[nonce] = False else: @@ -429,11 +444,6 @@ def stats( (1, "inter_token_latency", "ms"), (0, -1, ""), (1, "estimated_compute", "GFLOPs"), - (1, "compute_time_per_worker", "ms"), - (1, "inwait_time_per_worker", "ms"), - (1, "inflight_time_per_worker", "ms"), - (0, -1, ""), - (1, "network_latency", "ms"), ] # FIXME: Allow filtering by these @@ -477,6 +487,7 @@ def stats( sys.stdout.write(f"\t# {statistics.median(stats.itl):.3f} s/tok\n") case "inter_token_latency": + assert len(stats.itl) > 1, "Not enough trace frames" itl = stats.itl[:-1] # FIXME: last element is super big median = statistics.median(itl) p90 = statistics.quantiles(itl, n=100)[89] @@ -488,12 +499,19 @@ def stats( case "estimated_compute": sys.stdout.write(f"UNKNOWN".rjust(20)+" GFLOPs".ljust(5)+"\testimated_flops\n") - case "compute_time_per_worker": - pass - case _: pass + for i, n in enumerate(self.nodes): + comp = stats.compute_per_worker[n] + net = stats.network_per_worker[n] + mem = stats.memory_per_worker[n] + total = comp + net + mem + sys.stdout.write(f"\n node{i} [{n}]\n") + sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") + sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") + sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + except Exception as e: logger.error(f"{e}") From f8495a0458e1d80cb109b1c9cdd196e7b5b6ce8b Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 03:59:25 -0700 Subject: [PATCH 141/172] reformat frames --- src/dnet/ring/api/node.py | 8 +++---- src/dnet/ring/shard/comms.py | 5 ++++- src/dnet/ring/shard/compute.py | 28 ++++++++++++++++++++++- src/dnet/ring/shard/node.py | 41 +++++++++++++++++----------------- src/dnet/ring/weight_cache.py | 6 ++--- 5 files changed, 58 insertions(+), 30 deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 36f590f7..c32a8f65 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -1410,12 +1410,12 @@ async def _handle_completion( t_first_token = None nonce = f"chatcmpl-{uuid.uuid4()}" - self.tracer.mark("chat.request.start", { + self.tracer.mark("request.start", { "tokenizer": "", "model": req.model, "temperature": req.temperature, "prompt_tokens": prompt.size, - "nonce": nonce, + "req_id": nonce, "t0": time.perf_counter(), }) @@ -1475,9 +1475,9 @@ async def _handle_completion( else detokenizer.text[: -len(stop_sequence_suffix)] ) - self.tracer.mark("chat.request.end", { + self.tracer.mark("request.end", { "generated_tokens": len(tokens), - "nonce": nonce, + "req_id": nonce, "t0": time.perf_counter(), }) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index b827bcf0..bca9bdc6 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -175,7 +175,7 @@ async def _send_worker(self): with 
self.tracer.frame("network", "tx") as f: if activation_msg.tx_enq_perf_t and self._profile: f.set("inwait", time.perf_counter() - activation_msg.tx_enq_t) - f.set("nonce", activation_msg.nonce) + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) q_wait_ms = ( time.perf_counter() - activation_msg.tx_enq_perf_t @@ -233,6 +233,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): logger.debug(f"Sending activation") if activation_msg.is_final: with self.tracer.frame("grpc", "send_activation.final") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) try: if self._mode == "offload" and self.window_size > 0: @@ -271,6 +272,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("reset_api") with self.tracer.frame("grpc", "token_request") as fr: + fr.set("req_id", activation_msg.nonce) fr.set("node", self._instance_name) try: req = shard_api_comm_pb2.TokenRequest( @@ -360,6 +362,7 @@ async def _send_activation(self, activation_msg: ActivationMessage): if self.next_node_stub: with self.tracer.frame("network", "send_activation.next") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) request = activation_msg.to_proto(data) request.timestamp = utc_epoch_now() diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 50744b55..f7554de0 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -77,11 +77,13 @@ def _process_activation(self, activation_msg: ActivationMessage): try: # per-nonce kvcache for concurrent requests with self.tracer.frame("compute.thread", "kvcache.init") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) kv = self._get_or_make_kv(activation_msg.nonce) # Get input activation from pool with self.tracer.frame("compute.thread", "activations.load") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) input_buffer = self.input_pool.get_buffer(activation_msg.pool_id) if input_buffer is None: @@ -90,7 +92,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Prepare input activation with self.tracer.frame("compute.thread", "activations.process") as f: - f.set("nonce", activation_msg.nonce) + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) if activation_msg.dtype == "tokens": # embed locally on start shard logger.debug(f"Embedding tokens.") @@ -127,6 +129,7 @@ def _process_activation(self, activation_msg: ActivationMessage): did_early_swap = False with self.tracer.frame("compute.thread", "weights.prepare") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) # Determine contiguous local window starting at current_layer @@ -222,6 +225,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Execute the window with self.tracer.frame("compute.thread", "execute") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) @@ -279,6 +283,7 @@ def _process_activation(self, activation_msg: ActivationMessage): pass with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) try: # Sliding-fit delta swap: maintain a single resident set by evicting @@ -375,6 +380,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Boundary reached — directly pass tensor to TX 
to avoid extra copy/sync with self.tracer.frame("compute.thread", "execute.enqueue_prefetch") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) try: @@ -390,6 +396,7 @@ def _process_activation(self, activation_msg: ActivationMessage): # Create and enqueue output message: either forward activations or finalize on end role with self.tracer.frame("compute.thread", "grpc.send") as f: + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) nxt = last_layer + 1 if nxt >= self.model_metadata.num_layers: # End of model @@ -506,10 +513,29 @@ def _process_activation(self, activation_msg: ActivationMessage): # Clean up input resources self.input_pool.release(activation_msg.pool_id) +<<<<<<< HEAD # Optional unload/evict after stage with self.tracer.frame("compute.thread", "cleanup"): if self._mode != "sliding_fit": if self._defer_unload: +======= + # Clean up input resources + with self.tracer.frame("compute.thread", "cleanup") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + self.input_pool.release(activation_msg.pool_id) + # After queuing TX, schedule prefetch and eviction in the background + # to avoid stalling the handoff to the next shard. + try: + self._prefetch_pause.set() + except Exception: + pass + next_window = self._next_local_layers(last_layer, self.window_size) + for nl in next_window: + self._prefetch_to_ram(nl) + self._enqueue_weight_prefetch(nl) + if getattr(self, "_defer_unload", False): +>>>>>>> 6c40e99 (reformat frames) try: while len(self._recent_windows) > max( 1, int(self._resident_windows) diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 589b4c42..7cf02fdc 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -260,12 +260,12 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse and (self.model_path != req.model_path or self.assigned_layers != req.layers) ): logger.info("Node %s: Unloading current model to load new configuration", self.node_id) - with self.tracer.frame("memory", "model.unload") as f: + with self.tracer.frame("memory.model", "unload") as f: f.set("node", self._instance_name) await self.unload_model() # Load model metadata - with self.tracer.frame("memory", "model.load_metadata"): + with self.tracer.frame("memory.model", "load_metadata"): self.model_metadata = get_model_metadata(req.model_path) self.assigned_layers = req.layers @@ -369,7 +369,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Initialize weight cache - with self.tracer.frame("memory", "weight_cache.init") as f: + with self.tracer.frame("memory.weights", "cache.init") as f: f.set("node", self._instance_name) self.weight_cache = WeightCache( self.assigned_layers, @@ -383,7 +383,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse ) # Load the model - with self.tracer.frame("memory", "model.load") as f: + with self.tracer.frame("memory.model", "load") as f: f.set("node", self._instance_name) self.model = get_ring_model( self.model_metadata.model_type, @@ -408,7 +408,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse logger.warning("[QUANT] apply failed: %s", e) self.model.eval() - with self.tracer.frame("memory", "make_cache") as f: + with self.tracer.frame("memory.cache", "make_cache") as f: f.set("node", self._instance_name) self.cache = make_cache( 
self.model, @@ -449,7 +449,7 @@ async def load_model(self, req: ShardLoadModelRequest) -> ShardLoadModelResponse self.total_layers = req.total_layers self.api_callback_address = req.api_callback_address - with self.tracer.frame("network", "connect.next_node") as f: + with self.tracer.frame("network.connect", "next_node") as f: f.set("node", self._instance_name) if self.next_node: await self._connect_next_node() @@ -622,12 +622,12 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return with self.tracer.frame("network.rx", "connect_next_node") as f: - f.set("nonce", request.nonce) + f.set("req_id", request.nonce) f.set("node", self._instance_name) await self._connect_next_node() with self.tracer.frame("network.rx", "process_activation") as f: - f.set("nonce", request.nonce) + f.set("req_id", request.nonce) try: activation = request.activation target_layer = activation.layer_id + 1 @@ -646,7 +646,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): t_alloc = time.perf_counter() if "|" in activation.dtype: with self.tracer.frame("grpc.receive", "decompress") as fr: - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) fr.set("node", self._instance_name) try: deq = decompress_tensor_from_protobuf_data( @@ -703,7 +703,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): return with self.tracer.frame("network.rx", "alloc.buffer") as fr: - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) fr.set("node", self._instance_name) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, @@ -728,7 +728,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": with self.tracer.frame("network.rx", "token_stream") as fr: - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) fr.set("node", self._instance_name) try: tokens = np.frombuffer(request.activation.data, dtype=np.int32) @@ -758,7 +758,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): else: with self.tracer.frame("network.ex", "default") as fr: fr.set("node", self._instance_name) - fr.set("nonce", request.nonce) + fr.set("req_id", request.nonce) # Safety: byte length must match shape*dtype try: expected = ( @@ -845,10 +845,9 @@ async def _ingress_worker(self): f.set("inwait", time.perf_counter() - req.rx_enq_t) f.set("inflight", req.rx_inflight_t) f.set("node", self._instance_name) - f.set("nonce", req.nonce) + f.set("req_id", req.nonce) try: - #with self.tracer.frame("network.ingress", "connect_next_node"): await self._connect_next_node() activation = req.activation @@ -863,7 +862,7 @@ async def _ingress_worker(self): logger.error(f"Unable to read length of data for {req.nonce}") payload_bytes = -1 - f.set("nonce", req.nonce) + fr.set("req_id", req.nonce) f.set("target", target_layer) f.set("payload_bytes", payload_bytes) f.event("received") @@ -905,8 +904,8 @@ async def _ingress_worker(self): # Enqueue for compute with self.tracer.frame("network.rx", "enque") as fr: + fr.set("req_id", req.nonce) fr.set("node", self._instance_name) - fr.set("nonce", req.nonce) while self.running: try: self.activation_recv_queue.put_nowait(activation_msg) @@ -996,8 +995,8 @@ def _prepare_activation_message_blocking( if "|" in activation.dtype: # Compressed path: decompress to MLX array and copy to pool with 
self.tracer.frame("network.rx.prepare_activation", "decompress") as f: + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, @@ -1034,8 +1033,8 @@ def _prepare_activation_message_blocking( elif activation.dtype == "tokens": # Tokens path: parse int32 token IDs and stage them with self.tracer.frame("network.rx.prepare_activation", "tokens") as f: + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) try: tokens = np.frombuffer(activation.data, dtype=np.int32) shp = (int(len(tokens)),) @@ -1065,8 +1064,8 @@ def _prepare_activation_message_blocking( else: # Dense path: validate size and copy raw bytes view into pool buffer with self.tracer.frame("network.rx.prepare_activation", "default") as f: + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) try: expected = ( int(np.prod(activation.shape)) @@ -1137,7 +1136,7 @@ def _compute_worker(self) -> None: # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats - f.set("nonce", activation_msg.nonce) + f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) if activation_msg.ex_enq_t == 0.0: # FIXME float comparison f.set("inwait", 0.0) @@ -1596,8 +1595,8 @@ async def warm(request: Request) -> JSONResponse: try: # FIXME: Append warmup config? Or something to distinguish with self.tracer.frame("memory", "model.warm") as f: # NOTE: Symbol hardcoded for runtime stats + f.set("req_id", request.nonce) f.set("node", self._instance_name) - f.set("nonce", request.nonce) body = await request.json() start = int(body.get("start", -1)) window = int(body.get("window", self.window_size)) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index fcb5f53a..2a9a22b0 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -76,7 +76,7 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s ) return data - with self.tracer.frame("weights.cache", "search") as f: + with self.tracer.frame("memory.weights", "cache.search") as f: with self.lock: if layer_id in self.cache: data, _ = self.cache[layer_id] @@ -100,7 +100,7 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s creator = False if creator: # Perform the blocking load without holding the cache lock - with self.tracer.frame("weights.cache", "load") as f: + with self.tracer.frame("memory.weights", "cache.load") as f: try: data = self.layer_manager.load_layer_to_gpu(layer_id) f.event("load") @@ -139,7 +139,7 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s return None else: # Not the creator: wait for the in-flight load to complete - with self.tracer.frame("weights.cache", "wait") as f: + with self.tracer.frame("memory.weights", "cache.wait") as f: try: inflight.result() # block until the creator completes except Exception as e: From ba431132234e6dd9f6e892bb91f71451df0e9d6e Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 04:00:02 -0700 Subject: [PATCH 142/172] better sorting --- src/dnet/perf/utils/aggregators.py | 258 ++++++++++++----------------- 1 file changed, 108 insertions(+), 150 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index b5753589..be627a0d 100644 --- a/src/dnet/perf/utils/aggregators.py +++ 
b/src/dnet/perf/utils/aggregators.py @@ -216,7 +216,7 @@ class ReqStats: model: str = "" # Model name tokenizer: str = "" # Tokenizer name run_id: str = "" # ID of session (for later mapping) - nonce: str = "" # List of serviced requests + req_id: str = "" # List of serviced requests ttft: float = 0.0 # Time to first token itl: List[float] = None # Inter-token latency per round prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) @@ -253,171 +253,129 @@ def __init__(self) -> None: self._max_inflight_req = 20 # per node FIXME: modify from repl # Frames are kept in here while in-flight, then remove the frame objects and append to _stats - self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per nonce, per node_id + self._frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Store frames per req_id, per node_id - self._nonces: List[str] = [] # Tracked nonces (either in-flight or done) - self._nonce_round_finish: Dict[str, bool] = {} # Track in-flight rounds - self._nonce_prefill: Dict[str, bool] = {} # Track if this round is prefill + self._req: List[str] = [] # Tracked requests (in-flight or done) + self._req_round_finish: Dict[str, bool] = {} # Track in-flight requests + self._req_prefill: Dict[str, bool] = {} # Track if this request round is prefill + self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per request) + + # Staging environment for events that arrive before + # the request.start of the request they belong to + self._staging: Dict[str, Any] = {} + + # Finished and running requests self._running_stats: Dict[str, ReqStats] = {} # Unfinished stat frames self._stats: Dict[str, ReqStats] = {} # Finished stat frames - self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per nonce) - self._model_per_run: Dict[str, str] = {} # Track model per run_id - - self.nodes = [] # Keep track of active nodes - - # Maps of frames to higher-level sub-systems - self._compute_set = [ - "compute.forward", - "compute.thread.kvcache.init", - "compute.thread.weights.prepare", - "compute.thread.activations.process", - "compute.thread.activations.load", - "compute.thread.execute", - "compute.thread.execute.enqueue_prefetch", - "compute.thread.execute.evict_and_unload", - "compute.thread.cleanup", - "compute.thread.mdns.send", - ] - self._network_set = [ - "network.tx", - "network.token_request", - "network.rx.prepare", - "network.rx.prepare_activation.tokens", - "network.rx.enque", - "network.send_activation.final", - "network.rx", - "network.connect.next_node", - "network.rx.prefetch", - ] - - self._memory_set = [ - "memory.model.load", - "memory.model.load_metadata", - "memory.warmup", - "memory.weight_cache.init", - "memory.prefetch", - "memory.memory_pools.init", - "memory.cache.reset", - "memory.make_cache", - ] + self.nodes = [] # Keep track of active nodes in the network # Ingest raw data from tracer def add(self, data: Dict[str, Any]) -> None: - run_id = data["run_id"] - node_id = data["node_id"] events = data["events"] or [] - - if not run_id or not node_id: - print("Dropped batch") - return # Drop the batch + if not events: return # Nothing to do with self._lock: - # Ensure we register workers and nodes - for i, ev in enumerate(events): - if "nonce" not in ev["args"]: ev["args"]["nonce"] = f"N_" - nonce = ev["args"]["nonce"] - - if node_id not in self._frames: - self._frames[node_id] = {} - - if nonce not in self._frames[node_id]: - self._frames[node_id][nonce] = {} - - if len(self._frames[node_id]) >= 
self._max_inflight_req: # remove oldest entry - del self._frames[self._nonces[0]] - del self._nonces[0] - - if nonce not in self._nonces: - self._nonces.append(nonce) # Update in-flight events or register new ones - for e in events: - nonce = e["args"]["nonce"] - assert nonce is not None, "" - - if not node_id or not nonce: return # Drop invalid frames - - if e["name"] == "chat.request.start": - self._open_frames[nonce] = {} - self._nonce_prefill[nonce] = True - self._running_stats[nonce] = ReqStats( - model=e["args"]["model"], - tokenizer=e["args"]["tokenizer"], - run_id=run_id, - nonce=nonce, - ttft= e["args"]["t0"], - itl=[ e["args"]["t0"], ], - prompt_tokens=e["args"]["prompt_tokens"], - total_tokens=e["args"]["prompt_tokens"], - latencies={}, - latency_per_layer={}, - latency_per_shard={}, - assignment=None, - compute_per_worker={}, - network_per_worker={}, - memory_per_worker={}, - ) - - # FIXME: We might receive other frames then "embed" from shards - # so we need to handle the creation of this better - if nonce not in self._running_stats: - continue - - stats = self._running_stats[nonce] - - if "node" not in e["args"]: - if e["name"] == "chat.request.end": - print(f"{e}") - st_obj = self._running_stats[nonce] - st_obj.generated_tokens = e["args"]["generated_tokens"] - st_obj.total_tokens += e["args"]["generated_tokens"] - print("Adding to stats") - self._stats[nonce] = st_obj - del self._running_stats[nonce] - #del self._frames[node_id][nonce] - # TODO: Handle latency of transfer back to API - - else: - continue # Drop frames without "node" - - node_id = e["args"]["node"] - if node_id not in self.nodes: - self.nodes.append(node_id) - stats.compute_per_worker[node_id] = 0.0 - stats.network_per_worker[node_id] = 0.0 - stats.memory_per_worker[node_id] = 0.0 - - if e["name"] in self._memory_set: - stats.memory_per_worker[node_id] += e["args"]["ms"] - - if e["name"] == "network.rx": # Time in transport, ingress queue and ingress_worker - _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"] - #TODO: change shard in metadata - - if e["name"] == "compute.forward": - try: - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] # compute queue + execution - self._handle_round(e, nonce, stats, _cost) - except Exception as e: - print(f"{e}") - - if e["name"] in self._compute_set: # Aggregate for compute total - stats.compute_per_worker[node_id] += e["args"]["ms"] - - if e["name"] in self._network_set: - stats.network_per_worker[node_id] += e["args"]["ms"] + for i, e in enumerate(events): + symbol = e["name"].split(".") + + req_id = e["args"].get("req_id") + if not req_id: + print(f"Dropping {e["name"]}: {e["args"]}") + continue # Drop anonymous frames + + if symbol[0] == "request": + if symbol[1] == "start": # Start request, setup buffers and ingest staged frames + self._req.append(req_id) + self._open_frames[req_id] = {} + self._req_prefill[req_id] = True + stats = ReqStats( + model=e["args"]["model"], + tokenizer=e["args"]["tokenizer"], + req_id=req_id, + ttft= e["args"]["t0"], + itl=[ e["args"]["t0"], ], + prompt_tokens=e["args"]["prompt_tokens"], + total_tokens=e["args"]["prompt_tokens"], + latencies={}, + latency_per_layer={}, + latency_per_shard={}, + assignment=None, + compute_per_worker={}, + network_per_worker={}, + memory_per_worker={}, + ) + self._running_stats[req_id] = stats + + # Process all events in staging + if req_id in self._staging: + for pe in self._staging[req_id]: + node_id = e["args"].get("node_id") + if not node_id: continue + 
self._process_frame(pe, req_id, node_id, stats)
+
+ del self._staging[req_id]
+ continue
+
+ elif symbol[1] == "end": # Finish processing request
+ st_obj = self._running_stats[req_id]
+ st_obj.generated_tokens = e["args"]["generated_tokens"]
+ st_obj.total_tokens += e["args"]["generated_tokens"]
+ self._stats[req_id] = st_obj
+ del self._running_stats[req_id]
+ # TODO: Handle latency of transfer back to API
+ continue
+
+ if req_id in self._stats: continue # Already finished processing request
+ if req_id not in self._req: # If unknown request, stage frames
+ if req_id not in self._staging:
+ self._staging[req_id] = []
+ self._staging[req_id].append(e)
+ continue
+
+ node_id = e["args"].get("node_id")
+ if not node_id: return # Drop unknown node
+
+ stats = self._running_stats[req_id]
+ self._process_frame(e, req_id, node_id, stats)
+
+
+ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats):
+ if node_id not in self.nodes:
+ self.nodes.append(node_id)
+ stats.compute_per_worker[node_id] = 0.0
+ stats.network_per_worker[node_id] = 0.0
+ stats.memory_per_worker[node_id] = 0.0
+
+ if symbol[0] == "compute":
+ if symbol[1] == "forward":
+ try:
+ _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"]
+ self._handle_round(e, req_id, stats, _cost) # compute queue + execution
+ except Exception as e:
+ print(f"{e}")
+ stats.compute_per_worker[node_id] += e["args"]["ms"]
+
+ elif symbol[0] == "network":
+ if symbol[1] == "rx": # Time in transport, ingress queue and ingress_worker
+ _cost = lambda e: e["args"]["inflight"] + e["args"]["inwait"] + e["args"]["ms"]
+ #TODO: change shard in metadata
+ stats.network_per_worker[node_id] += e["args"]["ms"]
+
+ elif symbol[0] == "memory":
+ stats.memory_per_worker[node_id] += e["args"]["ms"]
+ return

- if e["name"] in self._memory_set:
- stats.memory_per_worker[node_id] += e["args"]["ms"]

 # Handle cost aggregation of frames
- def _handle_round(self, e: Any, nonce, stats: ReqStats, _cost_fnc: Any):
+ def _handle_round(self, e: Any, req_id, stats: ReqStats, _cost_fnc: Any):
 try:
- if self._nonce_prefill[nonce]:
- logger.error(f"TTFT: {stats.ttft}")
+ if self._req_prefill[req_id]:
 stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0
- self._nonce_prefill[nonce] = False
+ self._req_prefill[req_id] = False
 else:
 if e["args"]["t0"] > 0.0:
 stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1])
@@ -457,7 +415,7 @@ def stats(
 return
 stats = self._stats[list(self._stats.keys())[-1]]
 #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n")
- sys.stdout.write(f"Performance stats for request '{stats.nonce}':\n\n")
+ sys.stdout.write(f"Performance stats for request '{stats.req_id}':\n\n")
 try:
 for tag, n, unit in fields:
 if tag == 0: # Native counter

From 9d6e4fffff48c9148ef2953804f6ea672c55e35b Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sun, 26 Oct 2025 11:56:39 -0700
Subject: [PATCH 143/172] use epoch for t0

---
 src/dnet/perf/trace.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py
index 48f6e49b..4037d987 100644
--- a/src/dnet/perf/trace.py
+++ b/src/dnet/perf/trace.py
@@ -51,7 +51,7 @@ def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]])
 self.attrs = dict(attrs or {})
 self._t0 = 0.0
 def __enter__(self):
- self._t0 = time.perf_counter()
+ self._t0 = time.time_ns() # cross-node timekeeping
 self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)})
 return self
 def __exit__(self, ex_type, ex, tb):
@@ -96,7 
+96,7 @@ def stop_aggregator(self, *, flush: bool = True, timeout: float = 5.0) -> None: if flush and self._events: try: self._agg_q.put_nowait({ - "run_id": (self._req_id or "run"), + "req_id": (self._req_id or "run"), "node_id": (self.config.node_id or "node"), "events": list(self._events), }) except queue.Full: From a05eb75a95c95d054de70419e62f6907c6b6fd96 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 11:57:18 -0700 Subject: [PATCH 144/172] add stats nodes and fix node registration --- src/dnet/perf/utils/aggregators.py | 33 ++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index be627a0d..e64deb88 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -216,12 +216,13 @@ class ReqStats: model: str = "" # Model name tokenizer: str = "" # Tokenizer name run_id: str = "" # ID of session (for later mapping) - req_id: str = "" # List of serviced requests + req_id: str = "" # List of serviced requests ttft: float = 0.0 # Time to first token itl: List[float] = None # Inter-token latency per round prompt_tokens: int = -1 # Number of prompt tokens per request (req_id: #) generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) total_tokens: int = -1 # Total number of tokens processed + nodes: List[str] = None # Nodes that participated in computation latencies: List[List[str, str, str, int]] = None # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} @@ -258,7 +259,7 @@ def __init__(self) -> None: self._req: List[str] = [] # Tracked requests (in-flight or done) self._req_round_finish: Dict[str, bool] = {} # Track in-flight requests self._req_prefill: Dict[str, bool] = {} # Track if this request round is prefill - self._open_frames: Dict[str, Dict[str, Any]] = {} # We got 'B' event but not 'E' (per request) + self._open_frames: Dict[str, Dict[str, Dict[str, Any]]] = {} # Staging environment for events that arrive before # the request.start of the request they belong to @@ -275,15 +276,24 @@ def add(self, data: Dict[str, Any]) -> None: events = data["events"] or [] if not events: return # Nothing to do + node_id = data.get("node_id") + if not node_id: return # Drop unknown node + with self._lock: # Update in-flight events or register new ones for i, e in enumerate(events): symbol = e["name"].split(".") + if e["type"] == 'B': + req_id = data.get("req_id") + if not req_id or not node_id: continue + self._open_frames[req_id][node_id][e["name"]] = e + continue + req_id = e["args"].get("req_id") if not req_id: - print(f"Dropping {e["name"]}: {e["args"]}") + #print(f"Dropping {e}") continue # Drop anonymous frames if symbol[0] == "request": @@ -306,6 +316,7 @@ def add(self, data: Dict[str, Any]) -> None: compute_per_worker={}, network_per_worker={}, memory_per_worker={}, + nodes=[], ) self._running_stats[req_id] = stats @@ -335,26 +346,27 @@ def add(self, data: Dict[str, Any]) -> None: self._staging[req_id].append(e) continue - node_id = e["args"].get("node_id") - if not node_id: return # Drop unknown node + #node_id = e["args"].get("node_id") + #if not node_id: return # Drop unknown node stats = self._running_stats[req_id] self._process_frame(e, req_id, node_id, stats) def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): - if node_id not in self.nodes: - self.nodes.append(node_id) + symbol = e["name"].split(".") + if 
node_id not in self.nodes: self.nodes.append(node_id) + if node_id not in stats.nodes: + stats.nodes.append(node_id) stats.compute_per_worker[node_id] = 0.0 stats.network_per_worker[node_id] = 0.0 stats.memory_per_worker[node_id] = 0.0 - + if symbol[0] == "compute": if symbol[1] == "forward": try: _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] self._handle_round(e, req_id, stats, _cost) # compute queue + execution - print(f"TTFT: {stats.ttft}") except Exception as e: print(f"{e}") stats.compute_per_worker[node_id] += e["args"]["ms"] @@ -375,11 +387,14 @@ def _handle_round(self, e: Any, req_id, stats: ReqStats, _cost_fnc: Any): try: if self._req_prefill[req_id]: stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 + print(f"TTFT: {stats.ttft}") self._req_prefill[req_id] = False else: if e["args"]["t0"] > 0.0: stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1]) + print(f"ITL: {e["args"]["t0"]} - {stats.itl[-1]}") stats.itl.append(e["args"]["t0"]) + print(f"ITL: {stats.itl[-1]}") except Exception as ex: print(f"{ex}") From d8ce1756e448f37dcf9055abbaabc6ebbefed9ae Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 12:01:06 -0700 Subject: [PATCH 145/172] always register t0 --- src/dnet/perf/trace.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 4037d987..84e90a60 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -52,11 +52,12 @@ def __init__(self, tracer: "Tracer", name: str, attrs: Optional[Dict[str, Any]]) self._t0 = 0.0 def __enter__(self): self._t0 = time.time_ns() # cross-node timekeeping + self.attrs.update({"t0": self._t0}) self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) return self def __exit__(self, ex_type, ex, tb): - dt_ms = (time.perf_counter() - self._t0) * 1000.0 - self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex)}) + dt_ms = (time.time_ns() - self._t0) + self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex), "t0": time.time_ns()}) self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) return False def event(self, name: str, **attrs): From 2b44848b035a6769bcc58de47b70f64290580dae Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 13:20:21 -0700 Subject: [PATCH 146/172] mark round --- src/dnet/ring/api/node.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index c32a8f65..6794492a 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -1437,6 +1437,8 @@ async def _handle_completion( ), # type: ignore arange(req.max_tokens or 0), ): + self.tracer.mark("request.round", {"req_id": nonce,"t0": time.time_ns()}) + if profile_enabled and t_first_token is None: t_first_token = time.perf_counter() detokenizer.add_token(token) From 563160c02217806d49a66c06a583752e05ca6345 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:09:16 -0700 Subject: [PATCH 147/172] don't track queue wait --- src/dnet/ring/shard/comms.py | 2 +- src/dnet/ring/shard/node.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index bca9bdc6..11a53fe5 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -174,7 +174,7 @@ async def _send_worker(self): activation_msg = await self.activation_computed_queue.get() with self.tracer.frame("network", "tx") as f: if activation_msg.tx_enq_perf_t and self._profile: - f.set("inwait", time.perf_counter() - 
activation_msg.tx_enq_t) + f.set("inwait", (time.perf_counter() - activation_msg.tx_enq_t)*1000) f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) q_wait_ms = ( diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 7cf02fdc..48c49ed8 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -1131,8 +1131,7 @@ def _compute_worker(self) -> None: while self.running: try: # Get activation from queue (blocks until available) - with self.tracer.frame("compute", "deque.wait"): - activation_msg = self.activation_recv_queue.get(timeout=1.0) + activation_msg = self.activation_recv_queue.get(timeout=1.0) # Process the activation with self.tracer.frame("compute", "forward") as f: # NOTE: Symbol hardcoded for runtime stats From 013ed18af744f3bf5006d6b8f18b939e517f900b Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:10:00 -0700 Subject: [PATCH 148/172] break down tx.enque into correct compute-network frames --- src/dnet/ring/shard/compute.py | 77 +++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index f7554de0..67f035aa 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -395,11 +395,13 @@ def _process_activation(self, activation_msg: ActivationMessage): pass # Create and enqueue output message: either forward activations or finalize on end role - with self.tracer.frame("compute.thread", "grpc.send") as f: + with self.tracer.frame("network.tx", "send") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model + + nxt = last_layer + 1 + if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "sampling") as f: try: with self._mlx_lock: y = self.model.normalize(x_cast) @@ -492,33 +494,35 @@ def _process_activation(self, activation_msg: ActivationMessage): except Exception: output_msg.tx_enq_perf_t = 0.0 # Enqueue to appropriate asyncio TX queue from compute thread - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue - ) - fut = asyncio.run_coroutine_threadsafe( - target_q.put(output_msg), self._loop + with self.tracer.frame("network.tx", "enque") as f: + output_msg.tx_enq_t = time.perf_counter() + try: + if self._loop is not None: + target_q = ( + self.activation_token_queue + if output_msg.is_final + else self.activation_computed_queue + ) + fut = asyncio.run_coroutine_threadsafe( + target_q.put(output_msg), self._loop + ) + fut.result() + else: + raise RuntimeError("Event loop not available for TX queue") + except Exception as e: + logger.error( + "Failed to queue computed activation for sending: %s", e ) - fut.result() - else: - raise RuntimeError("Event loop not available for TX queue") - except Exception as e: - logger.error( - "Failed to queue computed activation for sending: %s", e - ) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) -<<<<<<< HEAD # Optional unload/evict after stage with self.tracer.frame("compute.thread", "cleanup"): + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) if self._mode != "sliding_fit": if self._defer_unload: -======= # Clean up input resources with 
self.tracer.frame("compute.thread", "cleanup") as f: f.set("req_id", activation_msg.nonce) @@ -527,15 +531,28 @@ def _process_activation(self, activation_msg: ActivationMessage): # After queuing TX, schedule prefetch and eviction in the background # to avoid stalling the handoff to the next shard. try: - self._prefetch_pause.set() + while len(self._recent_windows) > max(1, int(getattr(self, "_resident_windows", 2))): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) + except Exception: + pass + except Exception: + pass + if getattr(self, "_resident_windows", 2) <= 1: + try: + evicted = self.weight_cache.evict_layers(window_layers) + if hasattr(self.mode, "unload_layers"): + self.model.unload_layers(window_layers) except Exception: pass - next_window = self._next_local_layers(last_layer, self.window_size) - for nl in next_window: - self._prefetch_to_ram(nl) - self._enqueue_weight_prefetch(nl) - if getattr(self, "_defer_unload", False): ->>>>>>> 6c40e99 (reformat frames) try: while len(self._recent_windows) > max( 1, int(self._resident_windows) From d7149c059df046e08ac742ffd8db16949fb4005b Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:10:49 -0700 Subject: [PATCH 149/172] aggregate compound subsytem metrics correctly per node --- src/dnet/perf/utils/aggregators.py | 51 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index e64deb88..44ae1b60 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -223,6 +223,7 @@ class ReqStats: generated_tokens: int = -1 # Number of generated tokens per request (req_id: #) total_tokens: int = -1 # Total number of tokens processed nodes: List[str] = None # Nodes that participated in computation + _rounds_t0: List[int] = None # Keep round times for post-processing latencies: List[List[str, str, str, int]] = None # List of inter-node latencies: [node0, node1, p50, 0.0] latency_per_layer: Dict[int, float] = None # Map of {layer: 0.0} @@ -305,8 +306,8 @@ def add(self, data: Dict[str, Any]) -> None: model=e["args"]["model"], tokenizer=e["args"]["tokenizer"], req_id=req_id, - ttft= e["args"]["t0"], - itl=[ e["args"]["t0"], ], + ttft=0.0, + itl=[], prompt_tokens=e["args"]["prompt_tokens"], total_tokens=e["args"]["prompt_tokens"], latencies={}, @@ -317,6 +318,7 @@ def add(self, data: Dict[str, Any]) -> None: network_per_worker={}, memory_per_worker={}, nodes=[], + _rounds_t0=[], ) self._running_stats[req_id] = stats @@ -339,6 +341,10 @@ def add(self, data: Dict[str, Any]) -> None: # TODO: Handle latency of transfer back to API continue + elif symbol[1] == "round": + stats = self._running_stats[req_id] + stats._rounds_t0.append(e["args"]["t0"]) + if req_id in self._stats: continue # Already finidhed processing request if req_id not in self._req: # If unknown request, stage frames if req_id not in self._staging: @@ -364,12 +370,11 @@ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): if symbol[0] == "compute": if symbol[1] == "forward": - try: - _cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] - self._handle_round(e, req_id, stats, _cost) # compute queue + execution - except Exception as e: - print(f"{e}") + pass + 
#_cost = lambda e: e["args"]["inwait"] + e["args"]["ms"] + # Defer stats compute until after we sort the times (async is kil) stats.compute_per_worker[node_id] += e["args"]["ms"] + print(f"COMPUTE_PER_WORKER: {e["name"]} : {stats.compute_per_worker}") elif symbol[0] == "network": if symbol[1] == "rx": # Time in transport, ingress queue and ingress_worker @@ -381,22 +386,17 @@ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats): stats.memory_per_worker[node_id] += e["args"]["ms"] return - - # Handle cost aggregation of frames - def _handle_round(self, e: Any, req_id, stats: ReqStats, _cost_fnc: Any): - try: - if self._req_prefill[req_id]: - stats.ttft = (e["args"]["t0"] - stats.ttft) * 1000.0 - print(f"TTFT: {stats.ttft}") - self._req_prefill[req_id] = False - else: - if e["args"]["t0"] > 0.0: - stats.itl[-1] = (e["args"]["t0"] - stats.itl[-1]) - print(f"ITL: {e["args"]["t0"]} - {stats.itl[-1]}") - stats.itl.append(e["args"]["t0"]) - print(f"ITL: {stats.itl[-1]}") - except Exception as ex: - print(f"{ex}") + def _compute_round_stats(self, stats): + rounds = stats._rounds_t0 + #rounds.sort() + assert len(rounds) > 2, "Not enough data." + stats.ttft = (rounds[1] - rounds[0]) * 1e-6 + stats.itl.append(rounds[1]) + for i in range(1, len(rounds)): + stats.itl[-1] = (rounds[i] - rounds[i-1]) * 1e-6 + stats.itl.append(rounds[i]) + stats.itl = stats.itl[:-1] + print(stats.itl) # Return data for total, per req, worker or model (maybe add per layer too?) def stats( @@ -429,6 +429,7 @@ def stats( print("No tracked stats in memory. Track a request first.\n") return stats = self._stats[list(self._stats.keys())[-1]] + self._compute_round_stats(stats) #sys.stdout.write(f"\n Loaded model '{stats.model}'.\n") sys.stdout.write(f"Performance stats for request '{stats.req_id}':\n\n") try: @@ -453,11 +454,11 @@ def stats( elif tag == 1: match n: case "tokens_per_second": - tps = [ 1 / rt for rt in stats.itl ] + tps = [ 1 / (rt/1000) for rt in stats.itl ] #median = statistics.median(tps) mean = sum(tps) / len(tps) sys.stdout.write(f"{mean:.2f}".rjust(20)+" tok/s".rjust(5)+" \ttokens_per_second") - sys.stdout.write(f"\t# {statistics.median(stats.itl):.3f} s/tok\n") + sys.stdout.write(f"\t# {statistics.median(stats.itl)/1000:.3f} s/tok\n") case "inter_token_latency": assert len(stats.itl) > 1, "Not enough trace frames" From a0780b12eb320ccc585791c4b1c7a606d4433c80 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:11:23 -0700 Subject: [PATCH 150/172] fix ms scaling --- src/dnet/perf/trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dnet/perf/trace.py b/src/dnet/perf/trace.py index 84e90a60..8afcca47 100644 --- a/src/dnet/perf/trace.py +++ b/src/dnet/perf/trace.py @@ -56,7 +56,7 @@ def __enter__(self): self.t._emit({"type": "B", "name": self.name, "args": dict(self.attrs)}) return self def __exit__(self, ex_type, ex, tb): - dt_ms = (time.time_ns() - self._t0) + dt_ms = (time.time_ns() - self._t0) * 1e-6 self.attrs.update({"ms": round(dt_ms, 3), "exc": bool(ex), "t0": time.time_ns()}) self.t._emit({"type": "E", "name": self.name, "args": self.attrs}) return False From 547543d4e8c2eb3dfdfaac687b6cd0270be1db0d Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:37:34 -0700 Subject: [PATCH 151/172] rename old grpc frames to network --- src/dnet/ring/shard/comms.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 11a53fe5..1725fd4f 
100644
--- a/src/dnet/ring/shard/comms.py
+++ b/src/dnet/ring/shard/comms.py
@@ -232,7 +232,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
         try:
             logger.debug(f"Sending activation")
             if activation_msg.is_final:
-                with self.tracer.frame("grpc", "send_activation.final") as f:
+                with self.tracer.frame("network", "send_activation.final") as f:
                     f.set("req_id", activation_msg.nonce)
                     f.set("node", self._instance_name)
                     try:
@@ -271,7 +271,7 @@ async def _send_activation(self, activation_msg: ActivationMessage):
                             self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( self.api_channel)
                             f.event("reset_api")
-                    with self.tracer.frame("grpc", "token_request") as fr:
+                    with self.tracer.frame("network", "token_request") as fr:
                         fr.set("req_id", activation_msg.nonce)
                         fr.set("node", self._instance_name)
                         try:
@@ -301,13 +301,16 @@ async def _send_activation(self, activation_msg: ActivationMessage):
             # FIXME: shaped var is a bit weird (is it np_array or mlx_array), @andthattoo shall check
             shaped = activation_msg.tensor

-            if shaped is None:
-                output_buffer = self.output_pool.get_buffer(activation_msg.pool_id)
-                if output_buffer is None:
-                    logger.error(
-                        "Failed to get output buffer %s", activation_msg.pool_id
-                    )
-                    return
+            with self.tracer.frame("grpc.send_activations.default", "get_buffer") as fr:
+                fr.set("req_id", activation_msg.nonce)
+                fr.set("node", self._instance_name)
+                if shaped is None:
+                    output_buffer = self.output_pool.get_buffer(activation_msg.pool_id)
+                    if output_buffer is None:
+                        logger.error(
+                            "Failed to get output buffer %s", activation_msg.pool_id
+                        )
+                        return

             if self._profile:
                 logger.info(

From 59243839903f7b1bd039030caa0fd9236cc8f297 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Sun, 26 Oct 2025 17:38:09 -0700
Subject: [PATCH 152/172] correctly aggregate global memory use per node

---
 src/dnet/perf/utils/aggregators.py | 31 ++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py
index 44ae1b60..23d8c5ee 100644
--- a/src/dnet/perf/utils/aggregators.py
+++ b/src/dnet/perf/utils/aggregators.py
@@ -261,6 +261,7 @@ def __init__(self) -> None:
         self._req_round_finish: Dict[str, bool] = {}  # Track in-flight requests
         self._req_prefill: Dict[str, bool] = {}  # Track if this request round is prefill
         self._open_frames: Dict[str, Dict[str, Dict[str, Any]]] = {}
+        self._global_memory_per_worker: Dict[str, float] = {}

         # Staging environment for events that arrive before
         # the request.start of the request they belong to
@@ -294,7 +295,10 @@ def add(self, data: Dict[str, Any]) -> None:
             req_id = e["args"].get("req_id")

             if not req_id:
-                #print(f"Dropping {e}")
+                if symbol[0] == "memory":  # Global memory frames are not request-based
+                    if node_id not in self._global_memory_per_worker:
+                        self._global_memory_per_worker[node_id] = 0.0
+                    self._global_memory_per_worker[node_id] += e["args"]["ms"]
                 continue  # Drop anonymous frames

             if symbol[0] == "request":
@@ -374,7 +378,6 @@ def _process_frame(self, e: Any, req_id: str, node_id: str, stats: ReqStats):
                 #_cost = lambda e: e["args"]["inwait"] + e["args"]["ms"]
                 # Defer stats compute until after we sort the times (async is kil)
                 stats.compute_per_worker[node_id] += e["args"]["ms"]
-                print(f"COMPUTE_PER_WORKER: {e["name"]} : {stats.compute_per_worker}")

             elif symbol[0] == "network":
                 if symbol[1] == "rx":  # Time in transport, ingress queue and ingress_worker
@@ -383,7 +386,10 @@ def _process_frame(self, e: Any, req_id: str, node_id: str,
stats: ReqStats): stats.network_per_worker[node_id] += e["args"]["ms"] elif symbol[0] == "memory": + print(f"MEMORY_PER_WORKER: {e["name"]} : {stats.memory_per_worker}") stats.memory_per_worker[node_id] += e["args"]["ms"] + else: + print(f"UNTRACKED: {e["name"]}") return def _compute_round_stats(self, stats): @@ -477,14 +483,19 @@ def stats( pass for i, n in enumerate(self.nodes): - comp = stats.compute_per_worker[n] - net = stats.network_per_worker[n] - mem = stats.memory_per_worker[n] - total = comp + net + mem - sys.stdout.write(f"\n node{i} [{n}]\n") - sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") - sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") - sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + try: + comp = stats.compute_per_worker[n] + net = stats.network_per_worker[n] + req_mem = stats.memory_per_worker[n] + g_mem = self._global_memory_per_worker[n] + mem = req_mem + g_mem + total = comp + net + mem + sys.stdout.write(f"\n node{i} [{n}]\n") + sys.stdout.write(f"{comp:.2f}".rjust(20)+"ms".ljust(5)+f"\tcompute_time # [{(comp/total)*100:0.2f}%]\n") + sys.stdout.write(f"{net:.2f}".rjust(20)+"ms".ljust(5)+f"\tnetwork_time # [{(net/total)*100:0.2f}%]\n") + sys.stdout.write(f"{mem:.2f}".rjust(20)+"ms".ljust(5)+f"\tmemory_time # [{(mem/total)*100:0.2f}%]\n") + except Exception as e: + print(f"{e}") except Exception as e: logger.error(f"{e}") From 561accddc836cab80acd53b6f6b07f9c5b8b3d38 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 17:41:31 -0700 Subject: [PATCH 153/172] request.round continue --- src/dnet/perf/utils/aggregators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index 23d8c5ee..a75518a1 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -348,6 +348,7 @@ def add(self, data: Dict[str, Any]) -> None: elif symbol[1] == "round": stats = self._running_stats[req_id] stats._rounds_t0.append(e["args"]["t0"]) + continue if req_id in self._stats: continue # Already finidhed processing request if req_id not in self._req: # If unknown request, stage frames From f13c0aad0b12bc7d41159d2c0602c8c92219cb07 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 21:04:09 -0700 Subject: [PATCH 154/172] update signature and unload --- src/dnet/ring/model/llama3.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index 81626d51..978a7fe5 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -18,7 +18,8 @@ def __init__( self, model_config: Any, assigned_layers: Optional[List[int]] = [], - is_api_layer: bool = False + is_api_layer: bool = False, + shard_config: Optional[Any] = None, ): super().__init__() @@ -154,7 +155,6 @@ def apply_single_layer( layer = self.layers[local_idx] ret = self.layers[local_idx](x, mask, cache[local_idx] if local_idx < len(cache) else None) - logger.debug(f"Executed layer:{layer_idx} with output shape: {ret.shape}") return ret def load_weights(self, weights, strict=False): @@ -199,3 +199,15 @@ def load_weights(self, weights, strict=False): logger.error(f"Failed to load weights: {e}") logger.error(f"Weight keys: {list(shard_weights.keys())}") raise + + def unload_layers(self, layers: List[int]): + for l in layers: + local = self.abs_to_local[l] 
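PATCH 154's `unload_layers` walks `self.abs_to_local`, a map from absolute layer ids to positions in the shard-local module list. A minimal sketch of how such a mapping can be built, assuming layers are assigned as an unordered list; `build_abs_to_local` is an illustrative name, not a function in this repo:

    from typing import Dict, List

    def build_abs_to_local(assigned_layers: List[int]) -> Dict[int, int]:
        # Absolute layer id -> index into the shard-local module list,
        # assuming modules are created in ascending layer order.
        return {abs_id: local_idx
                for local_idx, abs_id in enumerate(sorted(assigned_layers))}

    # e.g. layers 10..13 on this shard map to local slots 0..3
    assert build_abs_to_local([12, 10, 13, 11]) == {10: 0, 11: 1, 12: 2, 13: 3}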
+ for name, mod in self.layers[local].named_modules(): + if name in ['self_attn', 'mlp']: + for pname in mod.parameters(): + setattr(mod, pname, None) + logger.debug(f"Unloaded {pname}") + elif name in ['input_layernorm', 'post_attention_layernorm']: + mod.weight = None + logger.debug(f"Unloaded {name}") From 88ce17f1158059d669e02931e8f62a4584bbd9b8 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 26 Oct 2025 21:41:40 -0700 Subject: [PATCH 155/172] force quantization field in model config (mlx_lm doesn't have it) --- src/dnet/ring/model/llama3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dnet/ring/model/llama3.py b/src/dnet/ring/model/llama3.py index 978a7fe5..0ee26280 100644 --- a/src/dnet/ring/model/llama3.py +++ b/src/dnet/ring/model/llama3.py @@ -27,6 +27,7 @@ def __init__( raise RuntimeError(f"API Service doesn't handle layers") self.config = ModelArgs.from_dict(model_config) + self.config.quantization = model_config["quantization"] # lmao self.is_api_layer = is_api_layer self._converted_to_quantized = False @@ -87,13 +88,15 @@ def lm_project(self, x: mx.array): def quantize_layers(self): self.quantization = None + logger.debug(f"{self.config}") if hasattr(self.config, "quantization"): self.quantization = getattr(self.config, "quantization") elif hasattr(self.config, "quantization_config"): self.quantization = getattr(self.config, "quantization_config") + logger.debug(f"QUANTIZING {self.quantization}") if self.quantization is not None: - bits = int(self.quantization.get("bits", 8)) + bits = int(self.quantization.get("bits", 4)) group = int(self.quantization.get("group_size", 64)) try: from mlx.nn.layers.quantized import QuantizedEmbedding From 994ccfa90122f024c229fd15eb34907a6a90a647 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 00:50:29 -0700 Subject: [PATCH 156/172] add ShardConfig to __init__ --- src/dnet/ring/shard/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dnet/ring/shard/__init__.py b/src/dnet/ring/shard/__init__.py index ddc69862..00b76b81 100644 --- a/src/dnet/ring/shard/__init__.py +++ b/src/dnet/ring/shard/__init__.py @@ -2,5 +2,6 @@ from .node import RingShardNode from .servicer import ShardServicer +from .config import ShardConfig __all__ = ["RingShardNode", "ShardServicer"] From 9091031b7fb973606a73422e81f50e879edc2d04 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 01:15:30 -0700 Subject: [PATCH 157/172] rm old bench framework --- src/dnet/perf/bench.py | 144 ----------------------------------------- 1 file changed, 144 deletions(-) delete mode 100644 src/dnet/perf/bench.py diff --git a/src/dnet/perf/bench.py b/src/dnet/perf/bench.py deleted file mode 100644 index 0cdbadd2..00000000 --- a/src/dnet/perf/bench.py +++ /dev/null @@ -1,144 +0,0 @@ - -from __future__ import annotations - -import json -import os -import statistics -import time -from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional - -from dnet.perf.trace import Tracer - - -def _percentile(xs: List[float], q: float) -> float: - if not xs: - return 0.0 - ys = sorted(xs) - k = int(round(q * (len(ys) - 1))) - k = max(0, min(k, len(ys) - 1)) - return ys[k] - -def collect_stats(times_ms: List[float], *, bytes_total: float = 0.0, tokens_total: float = 0.0) -> Dict[str, Any]: - if not times_ms: - return { - "mean": 0.0, - "std": 0.0, - "min": 0.0, - "p50": 0.0, - "p90": 0.0, - "p99": 0.0, - "max": 0.0, - "samples": 0, - "mb_s": 0.0, - "tok_s": 0.0, - } - total_ms = sum(times_ms) - mean = 
total_ms / len(times_ms) - std = statistics.pstdev(times_ms) if len(times_ms) > 1 else 0.0 - total_s = max(total_ms / 1000.0, 1e-12) - return { - "mean": mean, - "std": std, - "min": min(times_ms), - "p50": _percentile(times_ms, 0.5), - "p90": _percentile(times_ms, 0.9), - "p99": _percentile(times_ms, 0.99), - "max": max(times_ms), - "samples": len(times_ms), - "mb_per_s": (bytes_total / 1_000_000.0) / total_s if bytes_total else 0.0, - "tokens_per_s": (tokens_total / total_s) if tokens_total else 0.0, - } - - -def _ensure_dir(path: str) -> None: - d = os.path.dirname(path) or "." - os.makedirs(d, exist_ok=True) - - -@dataclass -class BenchCounters: - values: Dict[str, float] = field(default_factory=dict) - - def add_time(self, key: str, dt_ms: float) -> None: - self.values[key] = self.values.get(key, 0.0) + float(dt_ms) - - def add_bytes(self, *, direction: str, n: int) -> None: - k = "bytes_in" if direction == "in" else "bytes_out" - self.values[k] = self.values.get(k, 0.0) + float(n) - - def inc(self, key: str, delta: float = 1.0) -> None: - self.values[key] = self.values.get(key, 0.0) + float(delta) - - def snapshot(self, *, run_id: str, node: str, role: str = "shard") -> Dict[str, Any]: - snap = { - "run_id": run_id, - "node": node, - "role": role, - "counters": dict(self.values), - } - return snap - - -class TimedSpan: - __slots__ = ("_tracer", "_name", "_attrs", "_t0", "_frame", "_counters", "_counter_key") - - def __init__( - self, - tracer: Optional[Tracer], - name: str, - counters: Optional[BenchCounters] = None, - counter_key: Optional[str] = None, - attrs: Optional[Dict[str, Any]] = None, - ) -> None: - self._tracer = tracer - self._name = name - self._attrs = attrs or {} - self._t0 = 0.0 - self._frame = None - self._counters = counters - self._counter_key = counter_key - - def __enter__(self): - self._t0 = time.perf_counter() - if self._tracer is not None: - self._frame = self._tracer.frame("bench", self._name, self._attrs) - self._frame.__enter__() - return self - - def __exit__(self, ex_type, ex, tb) -> bool: - dt_ms = (time.perf_counter() - self._t0) * 1000.0 - if self._frame is not None: - try: - self._frame.__exit__(ex_type, ex, tb) - except Exception: - pass - if self._counters is not None and self._counter_key: - self._counters.add_time(self._counter_key, dt_ms) - return False - - -def aggregate_annotate( - snapshots: Iterable[Dict[str, Any]], - *, - mapping: Optional[Dict[str, str]] = None, - repeats: int = 0, -) -> List[Dict[str, Any]]: - - sums: Dict[str, float] = {} - for snap in snapshots: - ctr = snap.get("counters") if isinstance(snap, dict) else None - if not isinstance(ctr, dict): - continue - for k, v in ctr.items(): - name = mapping.get(k, k) if mapping else k - try: - sums[name] = sums.get(name, 0.0) + float(v) - except Exception: - continue - - rows = [ {"name": name, "self_ms": val, "total_ms": val, "count": repeats or 0, "max_ms": None} - for name, val in sums.items() if val > 0.0] - rows.sort(key=lambda r: r["self_ms"], reverse=True) - return rows - From 1896d24be2f596f2f42b053edd131a5e1b1b5cfe Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 01:22:50 -0700 Subject: [PATCH 158/172] remove old memory frame path --- src/dnet/perf/utils/aggregators.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/dnet/perf/utils/aggregators.py b/src/dnet/perf/utils/aggregators.py index a75518a1..86e07f41 100644 --- a/src/dnet/perf/utils/aggregators.py +++ b/src/dnet/perf/utils/aggregators.py @@ -386,11 +386,6 @@ def _process_frame(self, e: Any, 
req_id: str, node_id: str, stats: ReqStats): #TODO: change shard in metadata stats.network_per_worker[node_id] += e["args"]["ms"] - elif symbol[0] == "memory": - print(f"MEMORY_PER_WORKER: {e["name"]} : {stats.memory_per_worker}") - stats.memory_per_worker[node_id] += e["args"]["ms"] - else: - print(f"UNTRACKED: {e["name"]}") return def _compute_round_stats(self, stats): @@ -403,7 +398,6 @@ def _compute_round_stats(self, stats): stats.itl[-1] = (rounds[i] - rounds[i-1]) * 1e-6 stats.itl.append(rounds[i]) stats.itl = stats.itl[:-1] - print(stats.itl) # Return data for total, per req, worker or model (maybe add per layer too?) def stats( From 20e39e14ef40534c96bf195bba15c1d606a1c282 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 27 Oct 2025 02:08:31 -0700 Subject: [PATCH 159/172] not-working chat interface --- src/repl.py | 58 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/src/repl.py b/src/repl.py index 231cecb1..f62d68e5 100644 --- a/src/repl.py +++ b/src/repl.py @@ -140,7 +140,8 @@ def loop(self): # Main tty loop self.print_mdns_nodes() continue elif cmd.startswith("load"): - self.load_model() + model = "mlx-community/llama-3.3-70b-instruct-4bit" + self.load_model(model) continue elif cmd.startswith(("trace", ".trace")): self.do_trace(cmd.split(" ")) @@ -212,7 +213,8 @@ def do_topo(self, cmd: List[str]) -> None: self.print_mdns_nodes() pass elif cmd[1] in ("auto", "build", "b"): - self.prepare_topo() + model = "mlx-community/llama-3.3-70b-instruct-4bit" + self.prepare_topo(model) pass elif cmd[1] == "setup": pass @@ -520,8 +522,11 @@ async def _await_then_set(): f.set_result(ret) except BaseException as e: f.set_exception(e) - self._api_loop.call_soon_threadsafe(runner) - return f.result(timeout) + try: + self._api_loop.call_soon_threadsafe(runner) + return f.result(timeout) + except Exception as e: + raise # ------- Trace aggregation helpers @@ -709,24 +714,55 @@ def print_topo(self, topo): sys.stdout.write(f"Devices: {topo.devices}\n\n") # TODO: Better print here - def prepare_topo(self): - req = PrepareTopologyRequest(model="Qwen/Qwen3-4B-MLX-4bit") + def prepare_topo(self, model): + req = PrepareTopologyRequest(model=model) try: - topo = self.api_call("_handle_prepare_topology", req, timeout=30) + topo = self.api_call("_handle_prepare_topology", req, timeout=120) except Exception as e: dprint(f"Unable to create topology: {e}\n\n") - return + return False self.state.topo = topo self.print_topo(topo) + return True - def load_model(self): - req = APILoadModelRequest(model="Qwen/Qwen3-4B-MLX-4bit") + def load_model(self, model): + req = APILoadModelRequest(model=model) try: res = self.api_call("_handle_load_model", req, timeout=30) + return True except Exception as e: dprint(f"Failed to load model: {e}\n\n") - return + return False + # ===== Handle chat + + def do_chat(self, cmd): + model = "mlx-community/llama-3.3-70b-instruct-4bit" + if len(cmd) < 2: + if not self.state.model or self.state.model == "": + self.prompt_model() + if not self.state.topology: + if not self._prepare_topo(self.state.model): + raise RuntimeError("Unable to create topology.") + if not self.load_model(self.state.model): + raise RuntimeError("Unable to load model.") + + while True: + prompt = input("\n> ") + prompt = self.format_prompt(prompt) + messages = prompt + req = ChatRequest( + messages=messages, + max_tokens=100, + temperature=0.7, + stream=True, + ) + + self.api_call("_handle_completion", req) + + # Start default chat with selected 
model + pass + pass # ===== Handle shutdown From ba93e0e7fda9efa3a4b531a81428d49f0403402b Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 18:25:17 -0800 Subject: [PATCH 160/172] fix indent after rebase --- src/dnet/ring/shard/compute.py | 242 +++++++++------------ src/dnet/ring/shard/node.py | 372 +++++++++++++++++---------------- 2 files changed, 285 insertions(+), 329 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 67f035aa..e0cc1a64 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -276,7 +276,6 @@ def _process_activation(self, activation_msg: ActivationMessage): window_layers, (t_comp_done - t_comp) * 1000.0, ) - """ for lid in window_layers: #self.weight_cache.decrease_reference(lid) @@ -338,37 +337,37 @@ def _process_activation(self, activation_msg: ActivationMessage): self._bound_versions.pop(lid, None) except Exception: pass - else: - if not self._defer_unload: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers( - old - ) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - if self._profile: + else: + if not self._defer_unload: + while len(self._recent_windows) > max( + 1, int(self._resident_windows) + ): + old = self._recent_windows.pop(0) try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, + evicted_cnt = self.weight_cache.evict_layers( + old ) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) except Exception: pass + if self._profile: + try: + logger.info( + "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", + self.node_id, + activation_msg.nonce, + old, + evicted_cnt, + self._resident_windows, + ) + except Exception: + pass except Exception: pass @@ -434,48 +433,6 @@ def _process_activation(self, activation_msg: ActivationMessage): is_final=True, token_id=token_id, ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] - else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - is_final=True, - 
token_id=token_id, - ) else: output_msg = ActivationMessage( nonce=activation_msg.nonce, @@ -490,69 +447,75 @@ def _process_activation(self, activation_msg: ActivationMessage): tensor=x_cast, ) try: - output_msg.tx_enq_perf_t = time.perf_counter() - except Exception: - output_msg.tx_enq_perf_t = 0.0 - # Enqueue to appropriate asyncio TX queue from compute thread - with self.tracer.frame("network.tx", "enque") as f: - output_msg.tx_enq_t = time.perf_counter() - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue - ) - fut = asyncio.run_coroutine_threadsafe( - target_q.put(output_msg), self._loop - ) - fut.result() - else: - raise RuntimeError("Event loop not available for TX queue") - except Exception as e: - logger.error( - "Failed to queue computed activation for sending: %s", e + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + is_final=True, + token_id=token_id, + ) + else: + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, + ) + try: + output_msg.tx_enq_perf_t = time.perf_counter() + except Exception: + output_msg.tx_enq_perf_t = 0.0 + # Enqueue to appropriate asyncio TX queue from compute thread + with self.tracer.frame("network.tx", "enque") as f: + output_msg.tx_enq_t = time.perf_counter() + try: + if self._loop is not None: + target_q = ( + self.activation_token_queue + if output_msg.is_final + else self.activation_computed_queue ) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) - # Optional unload/evict after stage - with self.tracer.frame("compute.thread", "cleanup"): - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - if self._mode != "sliding_fit": - if self._defer_unload: - # Clean up input resources - with self.tracer.frame("compute.thread", "cleanup") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - self.input_pool.release(activation_msg.pool_id) - # After queuing TX, schedule prefetch and eviction in the background - # to avoid stalling the handoff to the next shard. 
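The block this re-indent restores implements a retire-old-windows policy: after the TX hand-off, any layer window older than `_resident_windows` is evicted from the weight cache and unloaded from the model. A condensed sketch of that policy, with `evict` standing in for the cache/model unload calls (names are illustrative, not the repo's API):

    from typing import Callable, List

    def retire_old_windows(recent: List[List[int]], resident: int,
                           evict: Callable[[List[int]], None]) -> None:
        # Keep at most `resident` recent layer windows; evict oldest first.
        while len(recent) > max(1, resident):
            evict(recent.pop(0))

    windows = [[0, 1], [2, 3], [4, 5]]
    retire_old_windows(windows, resident=2, evict=lambda w: print("evict", w))
    # -> evicts [0, 1]; windows is now [[2, 3], [4, 5]]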
- try: - while len(self._recent_windows) > max(1, int(getattr(self, "_resident_windows", 2))): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - except Exception: - pass - if getattr(self, "_resident_windows", 2) <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.mode, "unload_layers"): - self.model.unload_layers(window_layers) - except Exception: - pass + # Optional unload/evict after stage + with self.tracer.frame("compute.thread", "cleanup"): + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + if self._mode != "sliding_fit": + if self._defer_unload: + # Clean up input resources + with self.tracer.frame("compute.thread", "cleanup") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) try: while len(self._recent_windows) > max( 1, int(self._resident_windows) @@ -569,23 +532,14 @@ def _process_activation(self, activation_msg: ActivationMessage): self._bound_versions.pop(lid, None) except Exception: pass - if self._profile: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s (post-stage)", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) except Exception: pass - if self._resident_windows <= 1: - try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] + if self._resident_windows <= 1: + try: + evicted = self.weight_cache.evict_layers(window_layers) + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(window_layers) # type: ignore[attr-defined] if self._profile: logger.info( "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 48c49ed8..4f899f9d 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -611,7 +611,9 @@ async def reset_cache(self) -> None: kv_bits=self.config.kv_cache.bits, kv_group=self.config.kv_cache.group_size, ) - logger.info("Node %s: Cache reset successfully", self.node_id) + logger.info("Node %s: Cache reset successfully", self.node_id) + except Exception as e: + logger.error("Node %s: Error resetting cache: %s", self.node_id, e) # FIXME This seems to still be dead code @@ -631,148 +633,84 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): try: activation = request.activation target_layer = activation.layer_id + 1 + + # Detect new sequence per node: initialize per-nonce KV + if request.nonce != self._active_nonce: + self._active_nonce = request.nonce + try: + payload_bytes = len(activation.data) + except Exception: + payload_bytes = -1 + f.event("process_payload") - # Detect new sequence per node: initialize per-nonce KV - if request.nonce != self._active_nonce: - self._active_nonce = request.nonce - try: - payload_bytes = len(activation.data) - except Exception: - payload_bytes = -1 - f.event("process_payload") - - if target_layer in self._assigned_set: - # Allocate input pool and copy payload (with optional decompression) - t_alloc = time.perf_counter() - if "|" in activation.dtype: - with self.tracer.frame("grpc.receive", "decompress") as fr: - 
fr.set("req_id", request.nonce) - fr.set("node", self._instance_name) - try: - deq = decompress_tensor_from_protobuf_data( - tensor_data=activation.data, - shape=list(activation.shape), - dtype_with_metadata=activation.dtype, - ) - except Exception as e: - logger.error( - "Decompression failed for nonce %s: %s", request.nonce, e - ) - return - - with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape)), - ) - if pool_id is None: - logger.warning( - "Failed to allocate input pool buffer for nonce %s", - request.nonce, - ) - return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - flat = deq.reshape(-1) - buffer[: flat.size] = flat - alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 - logger.info( - "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", - self.node_id, - request.nonce, - alloc_copy_ms, - ) - - # Update activation message with true dtype/shape - new_dtype_str = str(deq.dtype) - activation_msg = ActivationMessage.from_proto(request, pool_id) - activation_msg.dtype = new_dtype_str - activation_msg.shape = tuple(deq.shape) - else: - # Special token stream support: dtype='tokens' carries int32 token IDs - if activation.dtype == "tokens": - with self.tracer.frame("grpc.receive", "token_stream") as fr: + if target_layer in self._assigned_set: + # Allocate input pool and copy payload (with optional decompression) + t_alloc = time.perf_counter() + if "|" in activation.dtype: + with self.tracer.frame("grpc.receive", "decompress") as fr: + fr.set("req_id", request.nonce) + fr.set("node", self._instance_name) try: deq = decompress_tensor_from_protobuf_data( tensor_data=activation.data, shape=list(activation.shape), - dtype_with_metadata=activation.dtype) + dtype_with_metadata=activation.dtype, + ) except Exception as e: - logger.error("Decompression failed for nonce %s: %s", request.nonce, e) + logger.error( + "Decompression failed for nonce %s: %s", request.nonce, e + ) return - with self.tracer.frame("network.rx", "alloc.buffer") as fr: - fr.set("req_id", request.nonce) - fr.set("node", self._instance_name) + with self.tracer.frame("grpc.receive", "alloc.buffer") as fr: pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, dtype=deq.dtype, - shape=cast(tuple[int, ...], tuple(deq.shape))) - + shape=cast(tuple[int, ...], tuple(deq.shape)), + ) if pool_id is None: - logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + logger.warning( + "Failed to allocate input pool buffer for nonce %s", + request.nonce, + ) return - buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: flat = deq.reshape(-1) buffer[: flat.size] = flat + alloc_copy_ms = (time.perf_counter() - t_alloc) * 1000.0 + logger.info( + "[PROFILE][RX] node=%s nonce=%s alloc_copy_ms=%.3f (decompressed)", + self.node_id, + request.nonce, + alloc_copy_ms, + ) # Update activation message with true dtype/shape new_dtype_str = str(deq.dtype) activation_msg = ActivationMessage.from_proto(request, pool_id) activation_msg.dtype = new_dtype_str activation_msg.shape = tuple(deq.shape) - - else: # Special token stream support: dtype='tokens' carries int32 token IDs + else: + # Special token stream support: dtype='tokens' carries int32 token IDs if activation.dtype == "tokens": - with self.tracer.frame("network.rx", "token_stream") as fr: - fr.set("req_id", request.nonce) - 
fr.set("node", self._instance_name) + with self.tracer.frame("grpc.receive", "token_stream") as fr: try: - tokens = np.frombuffer(request.activation.data, dtype=np.int32) - shp = (int(len(tokens)), ) + deq = decompress_tensor_from_protobuf_data( + tensor_data=activation.data, + shape=list(activation.shape), + dtype_with_metadata=activation.dtype) except Exception as e: - logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e,) - return - - pool_id = self.input_pool.allocate_for_layer( - layer_id=activation.layer_id, - dtype=mx.int32, - shape=cast(tuple[int, ...], shp)) - - if pool_id is None: - logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + logger.error("Decompression failed for nonce %s: %s", request.nonce, e) return - buffer = self.input_pool.get_buffer(pool_id) - if buffer is not None: - buffer[: len(tokens)] = tokens - activation_msg = ActivationMessage.from_proto(request, pool_id) - - # Ensure dtype reflects token payload for compute path - activation_msg.dtype = "tokens" - activation_msg.shape = shp - - else: - with self.tracer.frame("network.ex", "default") as fr: - fr.set("node", self._instance_name) + with self.tracer.frame("network.rx", "alloc.buffer") as fr: fr.set("req_id", request.nonce) - # Safety: byte length must match shape*dtype - try: - expected = ( - int(np.prod(activation.shape)) - * np.dtype(dtype_map[activation.dtype]).itemsize - ) - actual = len(request.activation.data) - except Exception: - pass - + fr.set("node", self._instance_name) pool_id = self.input_pool.allocate_for_layer( layer_id=activation.layer_id, - dtype=mlx_dtype_map[activation.dtype], - shape=cast(tuple[int, ...], activation.shape)) + dtype=deq.dtype, + shape=cast(tuple[int, ...], tuple(deq.shape))) if pool_id is None: logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) @@ -780,26 +718,90 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): buffer = self.input_pool.get_buffer(pool_id) if buffer is not None: - data = request.activation.data - input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype]) - buffer[: len(input_data)] = input_data + flat = deq.reshape(-1) + buffer[: flat.size] = flat + + # Update activation message with true dtype/shape + new_dtype_str = str(deq.dtype) + activation_msg = ActivationMessage.from_proto(request, pool_id) + activation_msg.dtype = new_dtype_str + activation_msg.shape = tuple(deq.shape) + + else: # Special token stream support: dtype='tokens' carries int32 token IDs + if activation.dtype == "tokens": + with self.tracer.frame("network.rx", "token_stream") as fr: + fr.set("req_id", request.nonce) + fr.set("node", self._instance_name) + try: + tokens = np.frombuffer(request.activation.data, dtype=np.int32) + shp = (int(len(tokens)), ) + except Exception as e: + logger.error("Failed to parse tokens for nonce %s: %s", request.nonce, e,) + return + + pool_id = self.input_pool.allocate_for_layer( + layer_id=activation.layer_id, + dtype=mx.int32, + shape=cast(tuple[int, ...], shp)) + + if pool_id is None: + logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce) + return + + buffer = self.input_pool.get_buffer(pool_id) + if buffer is not None: + buffer[: len(tokens)] = tokens + activation_msg = ActivationMessage.from_proto(request, pool_id) + + # Ensure dtype reflects token payload for compute path + activation_msg.dtype = "tokens" + activation_msg.shape = shp - activation_msg = ActivationMessage.from_proto(request, pool_id) 
-                        activation_msg.dtype = new_dtype_str
-                        activation_msg.shape = tuple(deq.shape)
-
-            # Queue for processing — non-blocking back-off loop (cancellable)
-            while self.running:
-                try:
-                    self.activation_recv_queue.put_nowait(activation_msg)
-                    activatino_msg.ex_enq_t = time.perf_counter()
-                    logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce)
-                    break
-                except Full:
-                    await asyncio.sleep(0)
-            else:
-                logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce)
-                self.input_pool.release(pool_id)
+            else:
+                with self.tracer.frame("network.ex", "default") as fr:
+                    fr.set("node", self._instance_name)
+                    fr.set("req_id", request.nonce)
+                    # Safety: byte length must match shape*dtype
+                    try:
+                        expected = (
+                            int(np.prod(activation.shape))
+                            * np.dtype(dtype_map[activation.dtype]).itemsize
+                        )
+                        actual = len(request.activation.data)
+                    except Exception:
+                        pass
+
+                    pool_id = self.input_pool.allocate_for_layer(
+                        layer_id=activation.layer_id,
+                        dtype=mlx_dtype_map[activation.dtype],
+                        shape=cast(tuple[int, ...], activation.shape))
+
+                    if pool_id is None:
+                        logger.warning("Failed to allocate input pool buffer for nonce %s", request.nonce)
+                        return
+
+                    buffer = self.input_pool.get_buffer(pool_id)
+                    if buffer is not None:
+                        data = request.activation.data
+                        input_data = np.frombuffer(data, dtype=dtype_map[activation.dtype])
+                        buffer[: len(input_data)] = input_data
+
+                    activation_msg = ActivationMessage.from_proto(request, pool_id)
+                    activation_msg.dtype = new_dtype_str
+                    activation_msg.shape = tuple(deq.shape)
+
+            # Queue for processing — non-blocking back-off loop (cancellable)
+            while self.running:
+                try:
+                    self.activation_recv_queue.put_nowait(activation_msg)
+                    activation_msg.ex_enq_t = time.perf_counter()
+                    logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce)
+                    break
+                except Full:
+                    await asyncio.sleep(0)
+            else:
+                logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce)
+                self.input_pool.release(pool_id)
         else:
             # Forward to next node (not our layer)
             logger.debug("Forwarding activation (layer %s) to next node, nonce: %s", target_layer, request.nonce)
@@ -853,43 +855,26 @@ async def _ingress_worker(self):
             activation = req.activation
             target_layer = activation.layer_id + 1

-            # Detect new sequence per node: initialize per-nonce KV
-            if req.nonce != self._active_nonce:
-                self._active_nonce = req.nonce
-                try:
-                    payload_bytes = len(activation.data)
-                except Exception:
-                    logger.error(f"Unable to read length of data for {req.nonce}")
-                    payload_bytes = -1
+            # Detect new sequence per node: initialize per-nonce KV
+            if req.nonce != self._active_nonce:
+                self._active_nonce = req.nonce
+                try:
+                    payload_bytes = len(activation.data)
+                except Exception:
+                    logger.error(f"Unable to read length of data for {req.nonce}")
+                    payload_bytes = -1

-            fr.set("req_id", req.nonce)
-            f.set("target", target_layer)
-            f.set("payload_bytes", payload_bytes)
-            f.event("received")
+            fr.set("req_id", req.nonce)
+            f.set("target", target_layer)
+            f.set("payload_bytes", payload_bytes)
+            f.event("received")

-            if target_layer in self._assigned_set:
-                # Heavy prep in executor (alloc/copy/decompress)
-                with self.tracer.frame("grpc.ingress", "prepare") as fr:
-                    fr.set("node", self._instance_name)
-                    fr.set("nonce", req.nonce)
-                    loop = asyncio.get_running_loop()
-                    try:
-                        activation_msg = await loop.run_in_executor(
-                            self.executor,
-                            self._prepare_activation_message_blocking,
-                            req,
-                        )
-                    except Exception as e:
-                        logger.error("Activation prepare failed
for nonce %s: %s", req.nonce, e) - continue - if activation_msg is None: - continue - if self._profile: - activation_msg.recv_perf_t = t_recv - - # Enqueue for compute (cancellable back-off) - with self.tracer.frame("grpc.ingress", "queue") as fr: - while self.running: + if target_layer in self._assigned_set: + # Heavy prep in executor (alloc/copy/decompress) + with self.tracer.frame("network.ingress", "prepare") as fr: + fr.set("node", self._instance_name) + fr.set("nonce", req.nonce) + loop = asyncio.get_running_loop() try: activation_msg = await loop.run_in_executor( self.executor, @@ -901,29 +886,46 @@ async def _ingress_worker(self): continue if activation_msg is None: continue + if self._profile: + activation_msg.recv_perf_t = t_recv - # Enqueue for compute - with self.tracer.frame("network.rx", "enque") as fr: - fr.set("req_id", req.nonce) - fr.set("node", self._instance_name) + # Enqueue for compute (cancellable back-off) + with self.tracer.frame("network.ingress", "queue") as fr: while self.running: try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", - activation_msg.nonce, + activation_msg = await loop.run_in_executor( + self.executor, + self._prepare_activation_message_blocking, + req, ) - break - except Full: - await asyncio.sleep(0) - else: - logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) - try: - if self.input_pool: - # FIXME: !!! - self.input_pool.release(activation_msg.pool_id) - except Exception: - pass + except Exception as e: + logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) + continue + if activation_msg is None: + continue + + # Enqueue for compute + with self.tracer.frame("network.rx", "enque") as fr: + fr.set("req_id", req.nonce) + fr.set("node", self._instance_name) + while self.running: + try: + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, + ) + break + except Full: + await asyncio.sleep(0) + else: + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) + try: + if self.input_pool: + # FIXME: !!! 
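The enqueue loop in this hunk never blocks the event loop: `put_nowait` plus `await asyncio.sleep(0)` retries until the queue has room, and the `while ... else` arm handles the node-stopping case. A stripped-down sketch of the pattern, with a plain `queue.Queue` standing in for the activation queue (illustrative only):

    import asyncio
    import queue

    async def enqueue_with_backoff(q: "queue.Queue", item, running) -> bool:
        # Retry put_nowait, yielding to the event loop between attempts,
        # so shutdown (running() -> False) can always interrupt the wait.
        while running():
            try:
                q.put_nowait(item)
                return True
            except queue.Full:
                await asyncio.sleep(0)  # yield; never block the loop
        return False  # node stopping: caller releases the pool buffer

    q = queue.Queue(maxsize=1)
    print(asyncio.run(enqueue_with_backoff(q, "activation", lambda: True)))  # True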
+ self.input_pool.release(activation_msg.pool_id) + except Exception: + pass else: # Forward to next node (not our layer) logger.debug( From af49e361ee7f8740be38a1be921f6040c5c8dcc7 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:15:05 -0800 Subject: [PATCH 161/172] cleanup weight_cache --- src/dnet/ring/shard/models.py | 2 +- src/dnet/ring/weight_cache.py | 125 +++------------------------------- 2 files changed, 12 insertions(+), 115 deletions(-) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index 6a5efcc5..a44ad608 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,7 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" - api_address: Optional[str] = Field( ..., description="API Address" ) + #api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 2a9a22b0..251834aa 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -64,40 +64,19 @@ def __init__( def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" - # First, fast path under lock for cache hit or in-flight load - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - # refresh LRU timestamp - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - return data with self.tracer.frame("memory.weights", "cache.search") as f: - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp - if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) - logger.debug("Cache hit for layer %s, ref=%d inc=%d", - layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) - return data - - inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock - if inflight is None: - need_evict = len(self.cache) >= self.max_weights - if need_evict: # Prepare eviction decision now to avoid overfilling once loaded - self._evict_lru() # Evict under lock, then proceed to load - fut = Future() # Install a new future marker so others wait - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False if creator: # Perform the blocking load without holding the cache lock with self.tracer.frame("memory.weights", "cache.load") as f: @@ -118,93 +97,11 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) else: self.reference_counts.setdefault(layer_id, 0) - - try: # Resolve future and clear from tracking - fut = 
self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) return data except Exception as e: - with self.lock: # Signal failure to any waiters - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None - - else: # Not the creator: wait for the in-flight load to complete - with self.tracer.frame("memory.weights", "cache.wait") as f: - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) return None - with self.lock: # Return from cache - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: - logger.error("Wait for layer %s load failed: data not in cache", layer_id) - return None - - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) - else: - self.reference_counts.setdefault(layer_id, 0) - # Resolve future and clear from tracking - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.info( - "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", - layer_id, - dt_ms, - (total_bytes / 1_048_576), - ) - return data - except Exception as e: - # Signal failure to any waiters - with self.lock: - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None - else: - # Not the creator: wait for the in-flight load to complete - t0w = time.perf_counter() - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) - return None - wait_ms = (time.perf_counter() - t0w) * 1000.0 - logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) - # Return from cache (now populated) and update ref/LRU - with self.lock: - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: - return None - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - else: - self.reference_counts.setdefault(layer_id, 0) - return data def decrease_reference(self, layer_id: int): """Decrease reference count for layer""" From cc977c564737d799c535792a6f9f5b33afc61eb9 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:16:04 -0800 Subject: [PATCH 162/172] Revert "cleanup weight_cache" This reverts commit a12eefb6f3807ff9f1812cd755743bc4664a8714. 
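PATCH 161 and PATCH 163 both target the same mechanism: a single-flight weight load, where the first caller installs a `Future` under the lock and materializes the layer outside it, while concurrent callers wait on that future instead of loading the same layer twice. A condensed sketch of the idea, assuming a blocking `load_fn` rather than the repo's exact cache API:

    import threading
    from concurrent.futures import Future

    class SingleFlightCache:
        def __init__(self, load_fn):
            self._load_fn, self._lock = load_fn, threading.Lock()
            self._cache, self._inflight = {}, {}

        def get(self, key):
            with self._lock:
                if key in self._cache:
                    return self._cache[key]       # fast path: already loaded
                fut = self._inflight.get(key)
                if fut is None:                   # we are the creator
                    fut, creator = Future(), True
                    self._inflight[key] = fut
                else:
                    creator = False
            if creator:
                try:
                    value = self._load_fn(key)    # heavy load outside the lock
                    with self._lock:
                        self._cache[key] = value
                        self._inflight.pop(key, None)
                    fut.set_result(value)
                    return value
                except Exception as e:
                    with self._lock:
                        self._inflight.pop(key, None)
                    fut.set_exception(e)
                    raise
            return fut.result()                   # waiters block on the creator

    cache = SingleFlightCache(lambda k: k * 2)
    assert cache.get(21) == 42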
--- src/dnet/ring/shard/models.py | 2 +- src/dnet/ring/weight_cache.py | 125 +++++++++++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index a44ad608..6a5efcc5 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,7 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" - #api_address: Optional[str] = Field( ..., description="API Address" ) + api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 251834aa..2a9a22b0 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -64,19 +64,40 @@ def __init__( def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" + # First, fast path under lock for cache hit or in-flight load + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + # refresh LRU timestamp + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) + return data with self.tracer.frame("memory.weights", "cache.search") as f: - inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock - if inflight is None: - need_evict = len(self.cache) >= self.max_weights - if need_evict: # Prepare eviction decision now to avoid overfilling once loaded - self._evict_lru() # Evict under lock, then proceed to load - fut = Future() # Install a new future marker so others wait - self.loading_futures[layer_id] = fut - inflight = fut - creator = True - else: - creator = False + with self.lock: + if layer_id in self.cache: + data, _ = self.cache[layer_id] + self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + logger.debug("Cache hit for layer %s, ref=%d inc=%d", + layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) + return data + + inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock + if inflight is None: + need_evict = len(self.cache) >= self.max_weights + if need_evict: # Prepare eviction decision now to avoid overfilling once loaded + self._evict_lru() # Evict under lock, then proceed to load + fut = Future() # Install a new future marker so others wait + self.loading_futures[layer_id] = fut + inflight = fut + creator = True + else: + creator = False if creator: # Perform the blocking load without holding the cache lock with self.tracer.frame("memory.weights", "cache.load") as f: @@ -97,11 +118,93 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) else: self.reference_counts.setdefault(layer_id, 0) + + try: # Resolve future and clear from tracking + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_result(True) + except Exception: + self.loading_futures.pop(layer_id, None) return data except Exception as e: + with self.lock: # Signal failure to any waiters + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is 
not None and not fut.done(): + fut.set_exception(e) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.exception("Failed to load weight %s: %s", layer_id, e) + return None + + else: # Not the creator: wait for the in-flight load to complete + with self.tracer.frame("memory.weights", "cache.wait") as f: + try: + inflight.result() # block until the creator completes + except Exception as e: + logger.error("Wait for layer %s load failed: %s", layer_id, e) return None + with self.lock: # Return from cache + data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] + if data is None: + logger.error("Wait for layer %s load failed: data not in cache", layer_id) + return None + + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + else: + self.reference_counts.setdefault(layer_id, 0) + # Resolve future and clear from tracking + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_result(True) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.info( + "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", + layer_id, + dt_ms, + (total_bytes / 1_048_576), + ) + return data + except Exception as e: + # Signal failure to any waiters + with self.lock: + try: + fut = self.loading_futures.pop(layer_id, None) + if fut is not None and not fut.done(): + fut.set_exception(e) + except Exception: + self.loading_futures.pop(layer_id, None) + logger.exception("Failed to load weight %s: %s", layer_id, e) + return None + else: + # Not the creator: wait for the in-flight load to complete + t0w = time.perf_counter() + try: + inflight.result() # block until the creator completes + except Exception as e: + logger.error("Wait for layer %s load failed: %s", layer_id, e) + return None + wait_ms = (time.perf_counter() - t0w) * 1000.0 + logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) + # Return from cache (now populated) and update ref/LRU + with self.lock: + data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] + if data is None: + return None + self.cache[layer_id] = (data, time.time()) + if inc_ref: + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) + else: + self.reference_counts.setdefault(layer_id, 0) + return data def decrease_reference(self, layer_id: int): """Decrease reference count for layer""" From 0bc3253778fbf72b43ff0a6d915de6a63b69fa09 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:22:31 -0800 Subject: [PATCH 163/172] remove double code in weight_cache from rebase --- src/dnet/ring/weight_cache.py | 90 +++++++---------------------------- 1 file changed, 18 insertions(+), 72 deletions(-) diff --git a/src/dnet/ring/weight_cache.py b/src/dnet/ring/weight_cache.py index 2a9a22b0..3749522d 100644 --- a/src/dnet/ring/weight_cache.py +++ b/src/dnet/ring/weight_cache.py @@ -65,26 +65,16 @@ def __init__( def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[str, mx.array]]: """Get weight from cache""" # First, fast path under lock for cache hit or in-flight load - with self.lock: - if layer_id in self.cache: - data, _ = self.cache[layer_id] - # refresh LRU timestamp - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - return data - with self.tracer.frame("memory.weights", "cache.search") as f: - with 
self.lock: - if layer_id in self.cache: + with self.lock: + if layer_id in self.cache: data, _ = self.cache[layer_id] - self.cache[layer_id] = (data, time.time()) # refresh LRU timestamp + # refresh LRU timestamp + self.cache[layer_id] = (data, time.time()) if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) - logger.debug("Cache hit for layer %s, ref=%d inc=%d", - layer_id, self.reference_counts.get(layer_id, 0), int(inc_ref)) + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) return data inflight = self.loading_futures.get(layer_id) # If a load is in-flight, wait on it outside the lock @@ -137,74 +127,30 @@ def get_weight(self, layer_id: int, *, inc_ref: bool = False) -> Optional[Dict[s self.loading_futures.pop(layer_id, None) logger.exception("Failed to load weight %s: %s", layer_id, e) return None - - else: # Not the creator: wait for the in-flight load to complete + else: + # Not the creator: wait for the in-flight load to complete with self.tracer.frame("memory.weights", "cache.wait") as f: + t0w = time.perf_counter() try: inflight.result() # block until the creator completes except Exception as e: logger.error("Wait for layer %s load failed: %s", layer_id, e) return None - - with self.lock: # Return from cache + wait_ms = (time.perf_counter() - t0w) * 1000.0 + logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) + # Return from cache (now populated) and update ref/LRU + with self.lock: data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] if data is None: - logger.error("Wait for layer %s load failed: data not in cache", layer_id) return None - self.cache[layer_id] = (data, time.time()) if inc_ref: - self.reference_counts[layer_id] = (self.reference_counts.get(layer_id, 0) + 1) + self.reference_counts[layer_id] = ( + self.reference_counts.get(layer_id, 0) + 1 + ) else: self.reference_counts.setdefault(layer_id, 0) - # Resolve future and clear from tracking - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_result(True) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.info( - "[PROFILE][MATERIALIZE] layer=%s ms=%.2f bytes=%.2fMB", - layer_id, - dt_ms, - (total_bytes / 1_048_576), - ) - return data - except Exception as e: - # Signal failure to any waiters - with self.lock: - try: - fut = self.loading_futures.pop(layer_id, None) - if fut is not None and not fut.done(): - fut.set_exception(e) - except Exception: - self.loading_futures.pop(layer_id, None) - logger.exception("Failed to load weight %s: %s", layer_id, e) - return None - else: - # Not the creator: wait for the in-flight load to complete - t0w = time.perf_counter() - try: - inflight.result() # block until the creator completes - except Exception as e: - logger.error("Wait for layer %s load failed: %s", layer_id, e) - return None - wait_ms = (time.perf_counter() - t0w) * 1000.0 - logger.info("[PROFILE][WAIT-WEIGHT] layer=%s ms=%.2f", layer_id, wait_ms) - # Return from cache (now populated) and update ref/LRU - with self.lock: - data, _ = self.cache.get(layer_id, (None, 0.0)) # type: ignore[assignment] - if data is None: - return None - self.cache[layer_id] = (data, time.time()) - if inc_ref: - self.reference_counts[layer_id] = ( - self.reference_counts.get(layer_id, 0) + 1 - ) - else: - self.reference_counts.setdefault(layer_id, 0) - return data + return data def decrease_reference(self, layer_id: int): """Decrease reference count 
for layer""" From d186593fe6bd2d4aa92aedf1f5381cfda6544e37 Mon Sep 17 00:00:00 2001 From: Octavian Date: Sun, 2 Nov 2025 19:24:53 -0800 Subject: [PATCH 164/172] comment message field for compatibility --- src/dnet/ring/shard/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dnet/ring/shard/models.py b/src/dnet/ring/shard/models.py index 6a5efcc5..a44ad608 100644 --- a/src/dnet/ring/shard/models.py +++ b/src/dnet/ring/shard/models.py @@ -53,7 +53,7 @@ class ShardUnloadModelResponse(BaseModel): class ShardProfileRequest(BaseModel): """Request to profile device and measure latencies.""" - api_address: Optional[str] = Field( ..., description="API Address" ) + #api_address: Optional[str] = Field( ..., description="API Address" ) devices: Dict[str, DnetDeviceProperties] = Field( ..., description="Device information mapping" ) From 4ce0405032cf97a83bf2bef28279fe2593dea53c Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 3 Nov 2025 01:02:38 -0800 Subject: [PATCH 165/172] fix indent and duplicates from rebase --- src/dnet/ring/shard/compute.py | 481 ++++++++++++++------------------- 1 file changed, 209 insertions(+), 272 deletions(-) diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index e0cc1a64..6b64a48a 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -94,15 +94,14 @@ def _process_activation(self, activation_msg: ActivationMessage): with self.tracer.frame("compute.thread", "activations.process") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) + if activation_msg.dtype == "tokens": # embed locally on start shard - logger.debug(f"Embedding tokens.") numel = int(np.prod(activation_msg.shape)) tok_view = input_buffer[:numel].reshape(activation_msg.shape) toks = mx.array(np.array(tok_view, dtype=np.int32), dtype=mx.int32) x = self.model.embed(toks[None]) - # NOTE: Used to track start of request in perf stats - self.tracer.mark("embedding", { + self.tracer.mark("embedding", { # NOTE: Used to track start of request in perf stats "nonce": activation_msg.nonce, "prompt_tokens": toks.size, }) @@ -125,12 +124,13 @@ def _process_activation(self, activation_msg: ActivationMessage): current_layer = activation_msg.layer_id + 1 last_layer = current_layer - 1 while True: + start_time = time.perf_counter() processed = 0 did_early_swap = False - with self.tracer.frame("compute.thread", "weights.prepare") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) + with self.tracer.frame("compute.thread", "weights.prepare") as fr: + fr.set("req_id", activation_msg.nonce) + fr.set("node", self._instance_name) # Determine contiguous local window starting at current_layer window_layers: List[int] = [] @@ -174,23 +174,19 @@ def _process_activation(self, activation_msg: ActivationMessage): # Ensure weights for the window are resident and bind only if arrays changed # if model fits and we've already bound these layers, skip the scan entirely. 
- fast_fit = ( - self._mode == "fit" - and len(self._assigned_sorted) <= self.window_size - ) - skip_scan = fast_fit and all( - (wl in self._bound_versions) for wl in window_layers - ) + fast_fit = (self._mode == "fit" and len(self._assigned_sorted) <= self.window_size) + skip_scan = fast_fit and all( (wl in self._bound_versions) for wl in window_layers) + to_bind: Dict[str, mx.array] = {} if not skip_scan: + t_w_ready = time.perf_counter() for wl in window_layers: weights = self.weight_cache.get_weight(wl) if weights is None: logger.error("Failed to load weights for layer %s", wl) self.input_pool.release(activation_msg.pool_id) return - try: - # Use identity of first array as a cheap version/fingerprint + try: # Use identity of first array as a cheap version/fingerprint first_arr = next(iter(weights.values())) version = id(first_arr) except StopIteration: @@ -211,118 +207,84 @@ def _process_activation(self, activation_msg: ActivationMessage): t_w_ms, ) - # Opportunistically schedule prefetch for the next window to overlap with compute - try: - next_win_pre = self._next_local_layers( - (window_layers[-1] if window_layers else (activation_msg.layer_id)), - self.window_size, - ) - for nl in next_win_pre: - self._prefetch_to_ram(nl) - self._enqueue_weight_prefetch(nl) - except Exception: - pass - # Execute the window with self.tracer.frame("compute.thread", "execute") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - self._beyond_cursor = window_layers[-1] if window_layers else (activation_msg.layer_id) - try: # Prevent prefetch touching during encode/compute to minimize UMA pressure + if to_bind: # Block prefetch-touch during binding and serialize MLX ops self._compute_busy.set() - except Exception: - pass - t_bind = time.perf_counter() - with self._mlx_lock: - self.model.load_weights(list(to_bind.items()), strict=False) - bind_ms = (time.perf_counter() - t_bind) * 1000.0 + t_bind = time.perf_counter() + with self._mlx_lock: + self.model.load_weights(list(to_bind.items()), strict=False) + bind_ms = (time.perf_counter() - t_bind) * 1000.0 + if self._profile: + logger.info( + "[PROFILE][BIND] node=%s nonce=%s layers=%s tensors=%s bind_ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + len(to_bind), + bind_ms, + ) + t_comp = time.perf_counter() + self._compute_busy.set() + for i, lyr in enumerate(window_layers): + with self._mlx_lock: + x = self.model.apply_single_layer(lyr, x, cache=kv) + try: + if str(x.dtype) != str(self._wire_mx_dtype): + x = x.astype(self._wire_mx_dtype) + except Exception: + pass + + last_layer = ( window_layers[-1] if window_layers else activation_msg.layer_id) + mx.eval(x) + if self._profile: + t_comp_done = time.perf_counter() logger.info( - "[PROFILE][BIND] node=%s nonce=%s layers=%s tensors=%s bind_ms=%.3f", + "[PROFILE][WINDOW] node=%s nonce=%s layers=%s compute_ms=%.3f", self.node_id, activation_msg.nonce, window_layers, - len(to_bind), - bind_ms, + (t_comp_done - t_comp) * 1000.0, ) - t_comp = time.perf_counter() - try: - self._compute_busy.set() - except Exception: - pass - for i, lyr in enumerate(window_layers): - with self._mlx_lock: - x = self.model.apply_single_layer(lyr, x, cache=kv) - try: - if str(x.dtype) != str(self._wire_mx_dtype): - x = x.astype(self._wire_mx_dtype) - except Exception: - pass - - last_layer = ( - window_layers[-1] if window_layers else activation_msg.layer_id - ) - try: - mx.eval(x) - except Exception: - pass - if self._profile: - t_comp_done = time.perf_counter() - logger.info( - 
"[PROFILE][WINDOW] node=%s nonce=%s layers=%s compute_ms=%.3f", - self.node_id, - activation_msg.nonce, - window_layers, - (t_comp_done - t_comp) * 1000.0, - ) - for lid in window_layers: - #self.weight_cache.decrease_reference(lid) - pass with self.tracer.frame("compute.thread", "execute.evict_and_unload") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - try: - # Sliding-fit delta swap: maintain a single resident set by evicting + + for lid in window_layers: + self.weight_cache.decrease_reference(lid) + + try: # Sliding-fit delta swap: maintain a single resident set by evicting # only what's needed to fit the next window into the budget. Prefer # keeping the tail of the previous window so we don't thrash weights # that are likely to be reused. if self._mode == "sliding_fit": if int(self._resident_windows) <= 1: - if did_early_swap: - # Early delta-swap already trimmed resident set for this window + if did_early_swap: # Early delta-swap already trimmed resident set for this window pass - elif not self._recent_windows: - # First window in token: seed resident set + elif not self._recent_windows: # First window in token: seed resident set self._recent_windows.append(list(window_layers)) else: prev = self._recent_windows.pop(0) - self._delta_swap_eviction(window_layers, prev, activation_msg, early=False) + self._delta_swap_eviction( + window_layers, prev, activation_msg, early=False + ) budget = max(1, int(self.window_size or 1)) curr = list(window_layers) prev_only = [x for x in prev if x not in curr] keep_quota = max(0, budget - len(curr)) - keep_tail = prev_only[-keep_quota:] if keep_quota > 0 else [] + keep_tail = ( + prev_only[-keep_quota:] if keep_quota > 0 else [] + ) combined = list(keep_tail) + curr self._recent_windows.append(combined) - else: - prev = self._recent_windows.pop(0) - self._delta_swap_eviction( - window_layers, prev, activation_msg, early=False - ) - budget = max(1, int(self.window_size or 1)) - curr = list(window_layers) - prev_only = [x for x in prev if x not in curr] - keep_quota = max(0, budget - len(curr)) - keep_tail = ( - prev_only[-keep_quota:] if keep_quota > 0 else [] - ) - combined = list(keep_tail) + curr - self._recent_windows.append(combined) - else: - # Original eviction policy for other modes + + else: # Original eviction policy for other modes self._recent_windows.append(list(window_layers)) if int(self._resident_windows) <= 1: old = self._recent_windows.pop(0) @@ -330,13 +292,11 @@ def _process_activation(self, activation_msg: ActivationMessage): evicted_cnt = self.weight_cache.evict_layers(old) except Exception: evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass + + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) else: if not self._defer_unload: while len(self._recent_windows) > max( @@ -349,76 +309,85 @@ def _process_activation(self, activation_msg: ActivationMessage): ) except Exception: evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) if 
self._profile: - try: - logger.info( - "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", - self.node_id, - activation_msg.nonce, - old, - evicted_cnt, - self._resident_windows, - ) - except Exception: - pass + logger.info( + "[PROFILE][UNLOAD-WINDOW] node=%s nonce=%s old_layers=%s evicted=%s keep_windows=%s", + self.node_id, + activation_msg.nonce, + old, + evicted_cnt, + self._resident_windows, + ) except Exception: pass - # If next layer is still local, continue without staging/tx - nxt = last_layer + 1 - if nxt in self._assigned_set: - current_layer = nxt - continue + computation_time = time.perf_counter() - start_time + self._prof.info( + "[PROFILE][COMPUTE] node=%s nonce=%s window_layers=%s total_ms=%.3f", + self.node_id, + activation_msg.nonce, + window_layers, + computation_time * 1000.0, + ) + self._prof.info( + "Completed layers up to %s in %.3fs, nonce: %s, result: %s %s", + last_layer, + computation_time, + activation_msg.nonce, + x.shape, + x.dtype, + ) + + # If next layer is still local, continue without staging/tx + nxt = last_layer + 1 + if nxt in self._assigned_set: + current_layer = nxt + continue # Boundary reached — directly pass tensor to TX to avoid extra copy/sync - with self.tracer.frame("compute.thread", "execute.enqueue_prefetch") as f: + with self.tracer.frame("compute.thread", "staging") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - x_cast = x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype) - try: - self._compute_busy.clear() - except Exception: - pass - try: - for lid in list(self._prefetch_pending): - self._prefetch_pending.discard(lid) - self._enqueue_weight_prefetch(lid) - except Exception: - pass - # Create and enqueue output message: either forward activations or finalize on end role - with self.tracer.frame("network.tx", "send") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) + t_stage = time.perf_counter() + x_cast = ( x if x.dtype == self._wire_mx_dtype else x.astype(self._wire_mx_dtype)) + self._compute_busy.clear() - nxt = last_layer + 1 - if nxt >= self.model_metadata.num_layers: # End of model - with self.tracer.frame("compute.thread", "sampling") as f: - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - #self.tracer.mark("lm_head", {"nonce": actication_msg.nonce}) # NOTE: canonical stats end - - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] - else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return + if self._profile: + logger.info( + "[PROFILE][STAGE-DIRECT] node=%s nonce=%s layer_tail=%s stage_ms=%.3f shape=%s dtype=%s", + self.node_id, + activation_msg.nonce, + last_layer, + (time.perf_counter() - t_stage) * 1000.0, + tuple(x_cast.shape), + str(self._wire_mx_dtype), + ) + + if nxt >= self.model_metadata.num_layers: # End of model + with self.tracer.frame("compute.thread", "sampling") as fr: + try: + with self._mlx_lock: + y = self.model.normalize(x_cast) + y = self.model.lm_project(y) + self.tracer.mark("lm_head", {"nonce": activation_msg.nonce}) # NOTE: canonical stats end + + # Greedy sample last position + if y.ndim == 3: + logits_2d = y[:, -1, :] + elif y.ndim == 2: + logits_2d = y[-1:, :] + else: + logits_2d = y.reshape(1, -1) + tok = mx.argmax(logits_2d, axis=-1) + 
token_id = int(tok.item()) + except Exception as e: + logger.error("End-shard sampling failed: %s", e) + return output_msg = ActivationMessage( nonce=activation_msg.nonce, @@ -433,123 +402,91 @@ def _process_activation(self, activation_msg: ActivationMessage): is_final=True, token_id=token_id, ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - with self._mlx_lock: - y = self.model.normalize(x_cast) - y = self.model.lm_project(y) - # Greedy sample last position - if y.ndim == 3: - logits_2d = y[:, -1, :] - elif y.ndim == 2: - logits_2d = y[-1:, :] else: - logits_2d = y.reshape(1, -1) - tok = mx.argmax(logits_2d, axis=-1) - token_id = int(tok.item()) - except Exception as e: - logger.error("End-shard sampling failed: %s", e) - return - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - is_final=True, - token_id=token_id, - ) - else: - output_msg = ActivationMessage( - nonce=activation_msg.nonce, - layer_id=last_layer, - pool_id=-1, - shape=cast(tuple[int, ...], x.shape), - batch_size=activation_msg.batch_size, - timestamp=utc_epoch_now(), - node_origin=f"node_{self.node_id}", - dtype=str(self._wire_mx_dtype), - callback_url=activation_msg.callback_url, - tensor=x_cast, - ) - try: - output_msg.tx_enq_perf_t = time.perf_counter() - except Exception: - output_msg.tx_enq_perf_t = 0.0 - # Enqueue to appropriate asyncio TX queue from compute thread - with self.tracer.frame("network.tx", "enque") as f: - output_msg.tx_enq_t = time.perf_counter() - try: - if self._loop is not None: - target_q = ( - self.activation_token_queue - if output_msg.is_final - else self.activation_computed_queue + output_msg = ActivationMessage( + nonce=activation_msg.nonce, + layer_id=last_layer, + pool_id=-1, + shape=cast(tuple[int, ...], x.shape), + batch_size=activation_msg.batch_size, + timestamp=utc_epoch_now(), + node_origin=f"node_{self.node_id}", + dtype=str(self._wire_mx_dtype), + callback_url=activation_msg.callback_url, + tensor=x_cast, ) - # Clean up input resources - self.input_pool.release(activation_msg.pool_id) + # Clean up input resources + self.input_pool.release(activation_msg.pool_id) - # Optional unload/evict after stage - with self.tracer.frame("compute.thread", "cleanup"): - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - if self._mode != "sliding_fit": - if self._defer_unload: - # Clean up input resources - with self.tracer.frame("compute.thread", "cleanup") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) - try: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - evicted_cnt = self.weight_cache.evict_layers(old) - except Exception: - evicted_cnt = 0 - try: - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(old) # type: ignore[attr-defined] - for lid in old: - self._bound_versions.pop(lid, None) - except Exception: - pass - except Exception: - pass + try: + 
output_msg.tx_enq_perf_t = time.perf_counter() + except Exception: + output_msg.tx_enq_perf_t = 0.0 - if self._resident_windows <= 1: + # Enqueue to appropriate asyncio TX queue from compute thread + with self.tracer.frame("network.tx", "enque") as fr: + output_msg.tx_enq_t = time.perf_counter() try: - evicted = self.weight_cache.evict_layers(window_layers) - if hasattr(self.model, "unload_layers"): - self.model.unload_layers(window_layers) # type: ignore[attr-defined] - if self._profile: - logger.info( - "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", - self.node_id, - activation_msg.nonce, - window_layers, - evicted, + if self._loop is not None: + target_q = ( + self.activation_token_queue + if output_msg.is_final + else self.activation_computed_queue ) - except Exception: - pass + fut = asyncio.run_coroutine_threadsafe( + target_q.put(output_msg), self._loop + ) + fut.result() + else: + raise RuntimeError("Event loop not available for TX queue") + except Exception as e: + logger.error( + "Failed to queue computed activation for sending: %s", e + ) + + # Clean up input resources + with self.tracer.frame("compute.thread", "cleanup") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) + try: + while len(self._recent_windows) > max( + 1, int(self._resident_windows) + ): + old = self._recent_windows.pop(0) + try: + while len(self._recent_windows) > max(1, int(getattr(self, "_resident_windows", 2))): + old = self._recent_windows.pop(0) + try: + evicted_cnt = self.weight_cache.evict_layers(old) + except Exception: + evicted_cnt = 0 + try: + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(old) # type: ignore[attr-defined] + for lid in old: + self._bound_versions.pop(lid, None) + except Exception: + pass + except Exception: + pass + + if self._resident_windows <= 1: + try: + evicted = self.weight_cache.evict_layers(window_layers) + if hasattr(self.model, "unload_layers"): + self.model.unload_layers(window_layers) # type: ignore[attr-defined] + if self._profile: + logger.info( + "[PROFILE][EVICT] node=%s nonce=%s layers=%s evicted=%s", + self.node_id, + activation_msg.nonce, + window_layers, + evicted, + ) + except Exception: + pass return except Exception as e: logger.exception("Error processing activation: %s", e) + From a1765622f066dc20d605d6380fcf3af4718e27eb Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 3 Nov 2025 02:14:56 -0800 Subject: [PATCH 166/172] small rebase fixes --- src/dnet/ring/shard/comms.py | 232 +++++++++++++++++--------------- src/dnet/ring/shard/compute.py | 1 + src/dnet/ring/shard/node.py | 82 +++++------ src/dnet/ring/shard/servicer.py | 1 + 4 files changed, 159 insertions(+), 157 deletions(-) diff --git a/src/dnet/ring/shard/comms.py b/src/dnet/ring/shard/comms.py index 1725fd4f..8d55fe7e 100644 --- a/src/dnet/ring/shard/comms.py +++ b/src/dnet/ring/shard/comms.py @@ -251,8 +251,10 @@ async def _send_activation(self, activation_msg: ActivationMessage): ) except Exception: pass + cb = activation_msg.callback_url or "" parsed = urlparse(cb) if cb else None + t_rpc = time.perf_counter() if parsed and parsed.scheme == "grpc": addr = parsed.netloc if not addr: @@ -282,15 +284,25 @@ async def _send_activation(self, activation_msg: ActivationMessage): tx_enq_prev_t=time.perf_counter(), ) resp = await self.api_stub.SendToken(req) # type: ignore + rpc_ms = (time.perf_counter() - t_rpc) * 1000.0 if not resp.success: logger.error( "API SendToken failed for %s: %s", activation_msg.nonce, resp.message, ) + if 
self._profile: + logger.info( + "[PROFILE][TX-TOKEN][gRPC] node=%s nonce=%s token=%s rpc_ms=%.2f", + self.node_id, + activation_msg.nonce, + int(getattr(activation_msg, "token_id", -1)), + rpc_ms, + ) except Exception as e: logger.exception("Error sending token via gRPC: %s", e) else: + logger.error(activation_msg) logger.error( "No valid gRPC callback for token; cannot deliver nonce=%s", activation_msg.nonce, @@ -301,51 +313,49 @@ async def _send_activation(self, activation_msg: ActivationMessage): # FIXME: shaped var is a bit weird (is it np_array or mlx_array), @andthattoo shall check shaped = activation_msg.tensor - with self.tracer.frame("gprc.send_activations.default", "get_buffer") as fr: - fr.set("req_id", activation_msg.nonce) - fr.set("node", self._instance_name) + with self.tracer.frame("network.send_activations.default", "get_buffer") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) if shaped is None: output_buffer = self.output_pool.get_buffer(activation_msg.pool_id) if output_buffer is None: - logger.error( - "Failed to get output buffer %s", activation_msg.pool_id - ) + logger.error("Failed to get output buffer %s", activation_msg.pool_id) return - if self._profile: - logger.info( - "[PROFILE][SER-START] node=%s nonce=%s", - self.node_id, - activation_msg.nonce, - ) - t_ser = time.perf_counter() - t_cast = t_ser - _len_bytes = int(getattr(shaped, "nbytes", 0)) - _do_compress = bool( - self._compress and _len_bytes >= self._compress_min_bytes - ) - if _do_compress: - # Skip compression for decode. - _do_compress = False - try: - wire_np_dtype = dtype_map[self._wire_dtype_str] - except Exception: - wire_np_dtype = np.float16 # reasonable default fallback + if self._profile: + logger.info( + "[PROFILE][SER-START] node=%s nonce=%s", + self.node_id, + activation_msg.nonce, + ) - if isinstance(shaped, np.ndarray): - logger.warning("Activation tensor is a numpy array!!!") - if shaped.dtype != wire_np_dtype: - # FIXME: numpy vs mx array here - shaped = shaped.astype(wire_np_dtype, copy=False) + with self.tracer.frame("network.tx", "cast") as f: + t_ser = time.perf_counter() + t_cast = t_ser + _len_bytes = int(getattr(shaped, "nbytes", 0)) + _do_compress = bool( + self._compress and _len_bytes >= self._compress_min_bytes + ) + if _do_compress: + # Skip compression for decode. 
+ _do_compress = False + try: + wire_np_dtype = dtype_map[self._wire_dtype_str] + except Exception: + wire_np_dtype = np.float16 # reasonable default fallback - else: # MLX array -> cast to desired wire dtype - if str(shaped.dtype) != self._wire_dtype_str: - shaped = shaped.astype(self._wire_mx_dtype) + if isinstance(shaped, np.ndarray): + logger.warning("Activation tensor is a numpy array!!!") + if shaped.dtype != wire_np_dtype: + # FIXME: numpy vs mx array here + shaped = shaped.astype(wire_np_dtype, copy=False) - activation_msg.dtype = self._wire_dtype_str - t_cast = time.perf_counter() + else: # MLX array -> cast to desired wire dtype + if str(shaped.dtype) != self._wire_dtype_str: + shaped = shaped.astype(self._wire_mx_dtype) - with self.tracer.frame("grpc", "send_activations.cast_to_dtype") as f: + activation_msg.dtype = self._wire_dtype_str + t_cast = time.perf_counter() if isinstance(shaped, np.ndarray): # Cast to target dtype if shaped.dtype != wire_np_dtype: @@ -358,15 +368,15 @@ async def _send_activation(self, activation_msg: ActivationMessage): f.event("mxarray.cast") data = tensor_to_bytes(shaped) - activation_msg.dtype = self._wire_dtype_str + activation_msg.dtype = self._wire_dtype_str - nxt = activation_msg.layer_id + 1 - if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): - if self.next_node_stub: + with self.tracer.frame("memory", "prepare.window") as f: + f.set("req_id", activation_msg.nonce) + f.set("node", self._instance_name) - with self.tracer.frame("network", "send_activation.next") as f: - f.set("req_id", activation_msg.nonce) - f.set("node", self._instance_name) + nxt = activation_msg.layer_id + 1 + if (nxt < self.model_metadata.num_layers) and (nxt not in self._assigned_set): + if self.next_node_stub: request = activation_msg.to_proto(data) request.timestamp = utc_epoch_now() if self._mode == "offload" and self.window_size > 0: @@ -476,82 +486,84 @@ async def _send_activation(self, activation_msg: ActivationMessage): ring_timeout, ring_retries, ) - t0 = time.perf_counter() - last_exc: Optional[Exception] = None - for attempt in range(1, ring_retries + 1): - try: - # FIXME: use response here? - _ = await self.next_node_stub.SendActivation( - request, timeout=ring_timeout - ) # type: ignore - break - except grpc.aio.AioRpcError as e: # type: ignore - last_exc = e - code = e.code() - if code in { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.CANCELLED, - grpc.StatusCode.DEADLINE_EXCEEDED, - }: - logger.warning( - "SendActivation attempt %s/%s failed (%s); reconnecting...", - attempt, - ring_retries, - code.name, - ) - await self._reconnect_next_node() - await asyncio.sleep(min(0.25 * attempt, 1.0)) - continue - raise - else: - raise last_exc # type: ignore - rpc_ms = (time.perf_counter() - t0) * 1000.0 - logger.info( - "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", - self.node_id, - activation_msg.nonce, - activation_msg.layer_id + 1, - (len(data) / 1024), - ser_ms, - rpc_ms, - cast_ms, - ) - else: - logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") + t0 = time.perf_counter() + last_exc: Optional[Exception] = None + for attempt in range(1, ring_retries + 1): + try: + # FIXME: use response here? 
+ _ = await self.next_node_stub.SendActivation( + request, timeout=ring_timeout + ) # type: ignore + break + except grpc.aio.AioRpcError as e: # type: ignore + last_exc = e + code = e.code() + if code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.DEADLINE_EXCEEDED, + }: + logger.warning( + "SendActivation attempt %s/%s failed (%s); reconnecting...", + attempt, + ring_retries, + code.name, + ) + await self._reconnect_next_node() + await asyncio.sleep(min(0.25 * attempt, 1.0)) + continue + raise + else: + raise last_exc # type: ignore + rpc_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "[PROFILE][TX] node=%s nonce=%s next_layer=%s payload_kb=%.1f serialize_ms=%.3f rpc_ms=%.2f cast_ms=%.3f", + self.node_id, + activation_msg.nonce, + activation_msg.layer_id + 1, + (len(data) / 1024), + ser_ms, + rpc_ms, + cast_ms, + ) + else: + logger.error("Cannot forward activation - no next node configured; end shard should sample inline.") - # Final layer not annotated with 'is_final' - else: - logger.error( - "Final activation reached send path unexpectedly; sampling should occur on end shard." - ) - # Clear scheduling at request end - # Sequential offload: prefetch state is unused + # Final layer not annotated with 'is_final' + else: + logger.error( + "Final activation reached send path unexpectedly; sampling should occur on end shard." + ) + # Clear scheduling at request end + # Sequential offload: prefetch state is unused - # Optional: explicitly end the per-nonce stream on request completion - # Enable by setting RING_EXPLICIT_EOR=1 when you emit a true end-of-request signal. - try: - if self._explicit_eor: - if ( - hasattr(self, "_streams") - and activation_msg.nonce in self._streams - ): - await self._end_stream(activation_msg.nonce, eor=True) - except Exception: - pass + # Optional: explicitly end the per-nonce stream on request completion + # Enable by setting RING_EXPLICIT_EOR=1 when you emit a true end-of-request signal. + try: + if self._explicit_eor: + if ( + hasattr(self, "_streams") + and activation_msg.nonce in self._streams + ): + await self._end_stream(activation_msg.nonce, eor=True) + except Exception: + pass - # Release resources at end of send - try: - activation_msg.tensor = None - except Exception: - pass - if used_pool: + # Release resources at end of send try: - self.output_pool.release(activation_msg.pool_id) + activation_msg.tensor = None except Exception: pass + if used_pool: + try: + self.output_pool.release(activation_msg.pool_id) + except Exception: + pass except Exception as e: logger.exception("Error sending activation: %s", e) + + async def _connect_next_node(self) -> bool: """Connect to next node in ring. 
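The bounded retry loop added to `_send_activation` above is the whole transport-recovery story for ring sends. Stripped of the profiling counters it reduces to a small reusable pattern; the sketch below is illustrative only, assuming `grpc.aio` error semantics (the `send_with_retry` helper and its callback parameters are not part of the patch):

    import asyncio
    from typing import Awaitable, Callable

    import grpc

    # Status codes treated as transient transport failures worth a retry.
    TRANSIENT = {
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.DEADLINE_EXCEEDED,
    }

    async def send_with_retry(
        rpc: Callable[[], Awaitable[object]],
        reconnect: Callable[[], Awaitable[None]],
        retries: int = 3,
    ) -> object:
        """Run an RPC; on a transient failure, reconnect, back off, retry."""
        last_exc: Exception = RuntimeError("no attempts made")
        for attempt in range(1, retries + 1):
            try:
                return await rpc()
            except grpc.aio.AioRpcError as e:
                last_exc = e
                if e.code() not in TRANSIENT:
                    raise  # non-transient codes propagate immediately
                await reconnect()
                # Capped linear backoff, mirroring the patch above.
                await asyncio.sleep(min(0.25 * attempt, 1.0))
        raise last_exc

Anything outside UNAVAILABLE, CANCELLED, or DEADLINE_EXCEEDED is assumed to be a caller error and surfaces on the first attempt.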
diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 6b64a48a..489e6b24 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -74,6 +74,7 @@ def _process_activation(self, activation_msg: ActivationMessage): logger.error("Node %s: Cannot process activation - model not loaded", self.node_id) return + logger.error(f"PROCESS_ACTIVATION: {activation_msg.callback_url}") try: # per-nonce kvcache for concurrent requests with self.tracer.frame("compute.thread", "kvcache.init") as f: diff --git a/src/dnet/ring/shard/node.py b/src/dnet/ring/shard/node.py index 4f899f9d..3f551a62 100644 --- a/src/dnet/ring/shard/node.py +++ b/src/dnet/ring/shard/node.py @@ -619,6 +619,7 @@ async def reset_cache(self) -> None: # FIXME This seems to still be dead code async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): """Receive activation from previous node and queue for local compute or forward.""" + logger.debug("RECEIVE ACTIVATION") if self.input_pool is None: logger.error("Node %s: Cannot receive activation - input pool not initialized", self.node_id) return @@ -793,6 +794,7 @@ async def receive_activation(self, request: dnet_ring_pb2.ActivationRequest): # Queue for processing — non-blocking back-off loop (cancellable) while self.running: try: + logger.error(f"NETWORK RX: {activation_msg.callback_url}") self.activation_recv_queue.put_nowait(activation_msg) activatino_msg.ex_enq_t = time.perf_counter() logger.debug("Queued activation for processing: nonce %s", activation_msg.nonce) @@ -819,6 +821,7 @@ async def admit_frame(self, request: dnet_ring_pb2.ActivationRequest) -> None: request.rx_enq_t = rx_t request.rx_inflight_t = 0.0 if request.tx_enq_prev_t == 0.0 else rx_t - request.tx_enq_prev_t + logger.error(f"ADMIT_FRAME: {request.callback_url}") self.ingress_q.put_nowait(request) logger.debug(f"[ENQUE] Enqueued activation request") return @@ -840,6 +843,7 @@ async def _ingress_worker(self): req = await self.ingress_q.get() logger.debug(f"[DEQUE]Dequeued activation for processing") except asyncio.CancelledError: + logger.error("Error while waiting ingress worker.") break # Trace processing of request, in-flight and in-wait times @@ -850,30 +854,27 @@ async def _ingress_worker(self): f.set("req_id", req.nonce) try: - await self._connect_next_node() - activation = req.activation target_layer = activation.layer_id + 1 + try: + payload_bytes = len(activation.data) + except Exception: + payload_bytes = -1 + # Detect new sequence per node: initialize per-nonce KV if req.nonce != self._active_nonce: self._active_nonce = req.nonce try: - payload_bytes = len(activation.data) + self._get_or_make_kv(req.nonce) except Exception: - logger.error(f"Unable to read length of data for {req.nonce}") - payload_bytes = -1 - - fr.set("req_id", req.nonce) - f.set("target", target_layer) - f.set("payload_bytes", payload_bytes) - f.event("received") + pass if target_layer in self._assigned_set: # Heavy prep in executor (alloc/copy/decompress) with self.tracer.frame("network.ingress", "prepare") as fr: - fr.set("node", self._instance_name) - fr.set("nonce", req.nonce) + #fr.set("node", self._instance_name) + #fr.set("nonce", req.nonce) loop = asyncio.get_running_loop() try: activation_msg = await loop.run_in_executor( @@ -886,46 +887,32 @@ async def _ingress_worker(self): continue if activation_msg is None: continue - if self._profile: - activation_msg.recv_perf_t = t_recv + #if self._profile: + # activation_msg.recv_perf_t = t_recv # Enqueue for compute 
(cancellable back-off) - with self.tracer.frame("network.ingress", "queue") as fr: + with self.tracer.frame("network.rx", "enque") as fr: + fr.set("req_id", req.nonce) + fr.set("node", self._instance_name) while self.running: try: - activation_msg = await loop.run_in_executor( - self.executor, - self._prepare_activation_message_blocking, - req, + logger.error(f"NETWORK RX: {activation_msg.callback_url}") + self.activation_recv_queue.put_nowait(activation_msg) + logger.debug( + "Queued activation for processing: nonce %s", + activation_msg.nonce, ) - except Exception as e: - logger.error("Activation prepare failed for nonce %s: %s", req.nonce, e) - continue - if activation_msg is None: - continue - - # Enqueue for compute - with self.tracer.frame("network.rx", "enque") as fr: - fr.set("req_id", req.nonce) - fr.set("node", self._instance_name) - while self.running: - try: - self.activation_recv_queue.put_nowait(activation_msg) - logger.debug( - "Queued activation for processing: nonce %s", - activation_msg.nonce, - ) - break - except Full: - await asyncio.sleep(0) - else: - logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) - try: - if self.input_pool: - # FIXME: !!! - self.input_pool.release(activation_msg.pool_id) - except Exception: - pass + break + except Full: + await asyncio.sleep(0) + else: + logger.error("Failed to queue activation %s (node stopping)", activation_msg.nonce,) + try: + if self.input_pool: + # FIXME: !!! + self.input_pool.release(activation_msg.pool_id) + except Exception: + logger.error("Unable to release from input pool") else: # Forward to next node (not our layer) logger.debug( @@ -1497,6 +1484,7 @@ async def profile(req: ShardProfileRequest) -> ShardProfileResponse: device_profile = await self._profile_device( req.repo_id, req.max_batch_exp ) + logger.debug(device_profile) return ShardProfileResponse(profile=device_profile) except Exception as e: diff --git a/src/dnet/ring/shard/servicer.py b/src/dnet/ring/shard/servicer.py index 93e8627e..006f38a8 100644 --- a/src/dnet/ring/shard/servicer.py +++ b/src/dnet/ring/shard/servicer.py @@ -35,6 +35,7 @@ async def SendActivation( request.activation.layer_id, ) + logger.error(f"SERVICER: {request.callback_url}") await self.node.admit_frame(request) return ActivationResponse( From c6b1f5875ef804b4c6f10bb75732b0262f4b5191 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 3 Nov 2025 02:15:22 -0800 Subject: [PATCH 167/172] change order of elements so callback_url is position 4 again --- src/dnet/protos/dnet_ring.proto | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index 8009601f..d1b3b33a 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -32,11 +32,11 @@ message ActivationRequest { string nonce = 1; Activation activation = 2; int64 timestamp = 3; - float rx_enq_t = 4; - float tx_enq_prev_t = 5; - float rx_inflight_t = 6; - string node_origin = 7; - string callback_url = 8; + string node_origin = 4; + string callback_url = 5; + float rx_enq_t = 6; + float tx_enq_prev_t = 7; + float rx_inflight_t = 8; } // Response message for activation sending From 41b932c4893ca8725db93302ffe0e7440ee53710 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 10 Nov 2025 11:27:57 -0800 Subject: [PATCH 168/172] fix missing if conditions in compute thread --- src/dnet/ring/api/node.py | 2 +- src/dnet/ring/shard/compute.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 
deletions(-) diff --git a/src/dnet/ring/api/node.py b/src/dnet/ring/api/node.py index 6794492a..2c08c627 100644 --- a/src/dnet/ring/api/node.py +++ b/src/dnet/ring/api/node.py @@ -6,7 +6,7 @@ import json from dataclasses import asdict from io import StringIO -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable import httpx import mlx.core as mx diff --git a/src/dnet/ring/shard/compute.py b/src/dnet/ring/shard/compute.py index 489e6b24..c92bee94 100644 --- a/src/dnet/ring/shard/compute.py +++ b/src/dnet/ring/shard/compute.py @@ -450,13 +450,10 @@ def _process_activation(self, activation_msg: ActivationMessage): with self.tracer.frame("compute.thread", "cleanup") as f: f.set("req_id", activation_msg.nonce) f.set("node", self._instance_name) - try: - while len(self._recent_windows) > max( - 1, int(self._resident_windows) - ): - old = self._recent_windows.pop(0) - try: - while len(self._recent_windows) > max(1, int(getattr(self, "_resident_windows", 2))): + if self._mode != "sliding_fit": + if self._defer_unload: + try: + while len(self._recent_windows) > max(1, int(self._resident_windows)): old = self._recent_windows.pop(0) try: evicted_cnt = self.weight_cache.evict_layers(old) From de8e5eee115f66fffa4ff87378dc1e709086c1a8 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 10 Nov 2025 11:28:13 -0800 Subject: [PATCH 169/172] remove prepare topology response object --- src/repl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/repl.py b/src/repl.py index f62d68e5..6e6d8654 100644 --- a/src/repl.py +++ b/src/repl.py @@ -38,7 +38,6 @@ from dnet.ring.api.models import ( PrepareTopologyManualRequest, PrepareTopologyRequest, - PrepareTopologyResponse, APILoadModelRequest, APILoadModelResponse, ) From c7a13563322a14d0fd294d26ac79afbef1f4d867 Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 10 Nov 2025 20:56:56 -0800 Subject: [PATCH 170/172] refactor the model management system --- src/repl.py | 257 +++++++++++++++++++++++++++------------------------- 1 file changed, 135 insertions(+), 122 deletions(-) diff --git a/src/repl.py b/src/repl.py index 6e6d8654..44217673 100644 --- a/src/repl.py +++ b/src/repl.py @@ -52,6 +52,7 @@ hf_errors = import_module("huggingface_hub.utils") GatedRepoError = getattr(hf_errors, "GatedRepoError", Exception) HfHubHTTPError = getattr(hf_errors, "HfHubHTTPError", Exception) +LocalEntryNotFoundError = getattr(hf_errors, "LocalEntryNotFoundError", Exception) def dprint(msg): @@ -60,7 +61,7 @@ def dprint(msg): @dataclass class REPLState: - model: str = "NULL" + model: str = None model_info: ModelMetadata = None, num_local_nodes: int = 1 running_port = 50501 @@ -75,7 +76,7 @@ class REPL(cmd.Cmd): PS1 = "dnet > " WELCOME = "\nDNET Distributed Inference Engine, v0.1\nExperimental software. Type 'help' for usage hints.\n\n" - def __init__(self, model="NULL", nodes=1): + def __init__(self, model=None, nodes=1): assert nodes >= 1 and nodes < 10, "Invalid number of local nodes. Must be 0 < num < 10." 
super().__init__()
@@ -119,109 +120,69 @@ def loop(self): # Main tty loop
         sys.stdout.write(self.WELCOME)
         while True:
             dprint(self.PS1)
-            cmd = sys.stdin.readline().strip()
-
-            if cmd == "":
-                #self.print_state()
-                continue
-            elif cmd in [".exit", "exit", "quit"]:
-                self.handle_terminate_signal()
-            elif cmd in [".help", "help", "h"]:
-                self.print_help()
-
-            elif cmd.startswith(("api", ".api")):
-                self.do_api(cmd.split(" "))
-                continue
-            elif cmd.startswith("search"):
-                self.do_search(cmd.split(" "))
-                continue
-            elif cmd.startswith("nodes"):
-                self.print_mdns_nodes()
-                continue
-            elif cmd.startswith("load"):
-                model = "mlx-community/llama-3.3-70b-instruct-4bit"
-                self.load_model(model)
-                continue
-            elif cmd.startswith(("trace", ".trace")):
-                self.do_trace(cmd.split(" "))
-                continue
-            elif cmd.startswith(("perf", ".perf")):
-                self.do_perf(cmd.split(" "))
-                continue
-            elif cmd.startswith(("topo", ".topo", "t ")):
-                self.do_topo(cmd.split(" "))
-                continue
-            elif cmd.startswith((".model", "model", "m ")):
-                cmd.split(" ")
-                path = self._handle_model_pull(cmd[1])
-                if path:
-                    self.state.model = path
-
+            cmd = sys.stdin.readline().strip().split(" ")
+
+            match cmd[0]:
+                case "": continue
+                case s if s in ["exit", "quit", "q"]: self.handle_terminate_signal()
+                case s if s in ["help", "h"]: self.print_help()
+                case s if s in ["api", "server"]: self.do_api(cmd)
+                case s if s in ["search", "s"]: self.do_search(cmd)
+                case s if s in ["nodes", "n"]: self.print_mdns_nodes()
+                case s if s in ["load", "l"]: self.load_model(self.state.model)
+                case s if s in ["trace", ".trace"]: self.do_trace(cmd)
+                case s if s in ["perf", ".perf", "p"]: self.do_perf(cmd)
+                case s if s in ["topo", "topology", "t"]: self.do_topo(cmd)
+                case s if s in ["model", "m"]: self.do_model(cmd)
+
     def do_api(self, cmd: List[str]) -> None:
         if len(cmd) < 2:
             dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
             return
-        if cmd[1] in ["start", "run"]:
-            http_port, grpc_port = None, None
-            try:
-                http_port = cmd[2];
-                grpc_port = cmd[3]
-            except:
-                pass
-            self.start_api(
-                http_port or self.state.api_http_port,
-                grpc_port or self.state.api_grpc_port
-            )
-            self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0)
-
-        elif cmd[1] == "stop":
-            self.stop_api()
-        elif cmd[1] == "status":
-            dprint("Running\n" if self._api_running else "Stopped.\n")
-        elif cmd[1] == "log":
-            dprint("Log print is not yet supported.\n")
-        else:
-            dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
+        match cmd[1]: # TODO Maybe allow kwargs here? > api start grpc_port=8080
+            case s if s in ["start", "run"]:
+                http_port, grpc_port = None, None
+                if len(cmd) > 2:
+                    try:
+                        http_port = cmd[2];
+                        grpc_port = cmd[3]
+                    except:
+                        pass
+                self.start_api( http_port or self.state.api_http_port, grpc_port or self.state.api_grpc_port)
+                self.api_call("set_trace_ingest_callback", self.__trace_cb, timeout=2.0)
+
+            case "stop": self.stop_api()
+            case "status": dprint("Running\n" if self._api_running else "Stopped.\n")
+            case "log": dprint("Log print is not yet supported.\n")
+            case _: dprint("Invalid API command. Type 'help' for a list of valid commands.\n")
         return
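The dispatch refactor above leans on Python 3.10+ structural pattern matching with guard clauses in place of `startswith`/`elif` chains. A minimal self-contained sketch of the same style (the command names here are illustrative, not dnet's actual set):

    def dispatch(line: str) -> str:
        cmd = line.strip().split(" ")
        match cmd[0]:
            case "":
                return "noop"
            case s if s in ["exit", "quit", "q"]:
                return "terminate"
            case s if s in ["help", "h"]:
                return "help"
            case _:
                return f"unknown: {cmd[0]}"

    # Guards are evaluated top to bottom.
    assert dispatch("q") == "terminate"
    assert dispatch("help me") == "help"

One caveat worth keeping in mind with this style: because guards are tried top to bottom, the same alias must not appear in two arms; the first arm always wins and the later one becomes unreachable.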
    def do_search(self, cmd: List[str]) -> None:
        if len(cmd) != 2:
            dprint("mDNS search is " + ("ON\n\n" if self._api_searching else "OFF\n\n"))
            return
-        if cmd[1] == "on":
-            if self._api_searching:
-                return
-            if not self._api_ready and self._api_running:
-                dprint("Starting API Server thread.\n")
-                self.start_api()
-            self.api_call("_start_discovery", timeout=10)
-            self._api_searching.set()
-            dprint("Starting mDNS search for worker nodes.\n")
-        elif cmd[1] == "off":
-            dprint("Stop discovery not yet implemented in the API node.\n")
-            pass
-        else:
-            dprint("Invalid topology command. Start searchign with 'search on'.\n")
-            return
+        match cmd[1]: # NOTE on by default
+            case "on":
+                if self._api_searching:
+                    return
+                if not self._api_ready and self._api_running:
+                    dprint("Starting API Server thread.\n")
+                    self.start_api()
+                self.api_call("_start_discovery", timeout=10)
+                self._api_searching.set()
+                dprint("Starting mDNS search for worker nodes.\n")
+            case "off": dprint("Stop discovery not yet implemented in the API node.\n")
+            case _: dprint("Invalid topology command. Start searching with 'search on'.\n")

    def do_topo(self, cmd: List[str]) -> None:
        if len(cmd) < 2:
            dprint("Invalid topology command. Type 'help' for a list of valid commands.\n")
            return
-        if cmd[1] == "search":
-            self.print_mdns_nodes()
-            pass
-        elif cmd[1] in ("auto", "build", "b"):
-            model = "mlx-community/llama-3.3-70b-instruct-4bit"
-            self.prepare_topo(model)
-            pass
-        elif cmd[1] == "setup":
-            pass
-        elif cmd[1] == "add":
-            pass
-        elif cmd[1] in ["remove", "rm"]:
-            pass
-        return
+        match cmd[1]:
+            case "search": self.print_mdns_nodes()
+            case s if s in ["auto", "build", "b"]: self.prepare_topo(self.state.model)
+            case "add": dprint("Not implemented.\n")
+            case s if s in ["remove", "rm"]: dprint("Not implemented.\n")

    # TODO: standardize ANSI escape codes for easy use
    def print_help(self):
@@ -234,15 +195,15 @@ def _print_hf(cmd, desc, examples=[""]):

        sys.stdout.write("\033[1m\nAvailable commands:\n\033[0m")
        dprint("\033[1m\n  Common:\n\033[0m")
-        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository",
-            ["Examples  > model meta-llama/Meta-Llama-3-8B"])
+        _print_hf("model list ", "List locally available models.")
+        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository")
        _print_hf("nodes list ", "List mDNS discovered nodes.")
        _print_hf("log [LEVEL]", "Set the logging level.")
        dprint("\033[1m\n  Controlling the API Server:\n\033[0m")
-        _print_hf("api start [http_port=8080] [grpc_port=50500]", "Start the API server in a separate thread. Use provided ports if given.")
-        _print_hf("api stop ", "Signal clean shutdown of the API server.")
+        _print_hf("api start [HTTP] [GRPC]", "Start the API server in a separate thread. Use provided ports if given.")
+        _print_hf("api stop ", "Stop the API server.")
        _print_hf("api status ", "Prints the status of the API server.")
-        _print_hf("api log ", "Print latest logs to the current terminal.")
+        _print_hf("api log ", "Output live logs to current terminal.")
        dprint("\033[1m\n  Topology construction:\n\033[0m")
        _print_hf("search ", "Returns the current state of mDNS search.")
        _print_hf("search [on/off] ", "Toggle mDNS search across the local network.")
        _print_hf("nodes list ", "List all nodes in the current topology (including local ones).")
        _print_hf("nodes all ", "List all nodes (including local ones).")
        _print_hf("nodes ", "List mDNS discovered nodes.")
        _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
        _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
@@ -297,39 +258,98 @@ def prompt_model(self):

        except Exception as e:
            dprint(f"Unable to load model {model}. Target needs to be a valid HF repository. Try again:{e}\n")

-    # Read HF access token
-    def _resolve_hf_token(self):
-        dprint("Ener the HuggingFace access token > ")
-        tok = sys.stdin.readline().strip()
-        return tok
+    def do_model(self, cmd):
+        if len(cmd) < 2:
+            if self.state.model is None: dprint("No target model.\n")
+            else: dprint(f"Target model: {self.state.model}.\n")
+            return
+
+        match cmd[1]:
+            case "list": # List locally available models
+                lists = self._list_local_models()
+                dprint("\nLocally available weights:\n")
+                for x in lists[0]:
+                    dprint(f"  {x}\n")
+                dprint("\nMetadata only:\n")
+                for x in lists[1]:
+                    dprint(f"  {x}\n")
+                dprint("\n")
+            case _: # Treat unknown commands as model repos
+                self._handle_model_pull(cmd[1])
+
+    def _are_weights_local(self, repo_id, revision="main"):
+        try:
+            import re, os
+            from pathlib import Path
+            repo_id = repo_id.replace("/", "--") # HF cache folders are named models--org--name
+            root = Path.home() / ".cache/huggingface/hub"
+            if os.getenv("HF_HOME") is not None: root = Path(f"{os.getenv('HF_HOME')}/huggingface/hub")
+            if os.getenv("HF_HUB_CACHE") is not None: root = Path(f"{os.getenv('HF_HUB_CACHE')}/huggingface/hub")
+            files = os.scandir(root / f"models--{repo_id}/snapshots")
+            commit_hash = next(files).name
+            files = os.scandir(root / f"models--{repo_id}/snapshots/{commit_hash}")
+            files = [x.name for x in files]
+        except Exception as e:
+            dprint(f"Failed to search files in local model cache: {e}\n")
+            return False
+        PATTERNS = [
+            re.compile(r".*?model.*\.safetensors$"),
+            re.compile(r"^(model|pytorch_model)([._-]\d+([-_]of[-_]\d+)?)?[._-]?\.safetensors$"),
+            re.compile(r"pytorch_model[-_]\d{3,}[.]bin$"),
+            re.compile(r"model[.]safetensors$"),
+            re.compile(r"pytorch_model[.]bin$"),
+            re.compile(r"tf_model[.]h5$"),
+            re.compile(r"flax_model[.]msgpack$"),
+        ]
+        for pat in PATTERNS:
+            if any(pat.search(f) for f in files):
+                return True
+        return False
+
+    def _list_local_models(self):
+        import os
+        from pathlib import Path
+        root = Path.home() / ".cache/huggingface/hub"
+        if os.getenv("HF_HOME") is not None: root = Path(f"{os.getenv('HF_HOME')}/huggingface/hub")
+        if os.getenv("HF_HUB_CACHE") is not None: root = Path(f"{os.getenv('HF_HUB_CACHE')}/huggingface/hub")
+        models = [x.name.replace("models--", "") for x in os.scandir(root) if x.name.startswith("models--")]
+        weights_local = sorted([x for x in models if self._are_weights_local(x)], key=len)
+        config_local = sorted([x for x in models if x not in weights_local], key=len)
+        return [weights_local, config_local]
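The cache walk above re-derives Hugging Face's `models--org--name` snapshot layout by hand. For comparison, a sketch of the same listing built on `huggingface_hub`'s own cache scanner, which resolves `HF_HOME`/`HF_HUB_CACHE` and the snapshot layout itself; illustrative only, and `list_cached_models` is not part of the patch:

    from huggingface_hub import scan_cache_dir

    def list_cached_models() -> list[str]:
        # scan_cache_dir() walks the same hub cache the hand-rolled code
        # parses, and raises CacheNotFound if no cache directory exists yet.
        info = scan_cache_dir()
        return sorted(repo.repo_id for repo in info.repos if repo.repo_type == "model")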
    # Require a HF access token for restricted repositories
    # Ask user for HF access token until they have a valid one
    def _handle_model_pull(self, repo_path):
+        local = self._are_weights_local(repo_path)
        try:
-            path = try_to_load_from_cache(repo_path)
-            if path is None:
-                dprint(f"Model {repo_path} not found in local cache\n")
-                path = get_model_path(repo_path)
+            if not local:
+                dprint(f"Weights for {repo_path} not found in local cache. Downloading.\n")
+                path = snapshot_download(repo_path)
+                dprint("Download complete\n")
+            else:
+                dprint("Target model found in local registry.\n")
+
+            self.state.model = repo_path
-            return path
-        except hb.errors.HTTPError:
-            dprint(f"Repository {repo_path} not found in Hugging Face registry.")
-            return Null
+            return repo_path
+
        except GatedRepoError as e:
-            dprint("Restricted model.\n")
            tok = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
            while True:
-                tok = self._resolve_hf_token()
-                print(tok)
+                dprint("\nRestricted model. Enter the HuggingFace access token > ")
+                tok = sys.stdin.readline().strip()
                try:
                    ret = snapshot_download(repo_id=repo_path, token=tok)
+                    self.state.model = repo_path
                    return ret
                except GatedRepoError as e:
                    print(e)
                    continue
                except Exception as e:
                    raise RuntimeError(f"Unknown error during HF snapshot_download")
+
+        except HfHubHTTPError as e:
+            dprint(f"Repository {repo_path} not found in Hugging Face registry: {e}")
+            return None
+
        except Exception as e:
            raise RuntimeError(f"Unable to pull model {repo_path} locally")

@@ -480,13 +500,8 @@ def stop_api(self, timeout: float = 5.0) -> None:
        self._api_running.clear()
        self._api_ready.clear()

-    def api_call( # Call an API function from the REPL thread
-        self,
-        method: str,
-        *args: Any,
-        timeout: float=30.0,
-        **kwargs: Any
-    ) -> Any:
+    # Call an API function from the REPL thread
+    def api_call( self, method: str, *args: Any, timeout: float=30.0, **kwargs: Any) -> Any:
        if not self._api_loop or not self._node:
            raise RuntimeError("API Thread not set up correctly.")

@@ -500,10 +515,8 @@
            f = asyncio.run_coroutine_threadsafe(coroutine, self._api_loop)
            return f.result(timeout)

-        # method is sync
-        f = concurrent.futures.Future()
+        f = concurrent.futures.Future() # method is sync

-        # TODO: this is a mess lol
        def runner():
            try:
                ret = target(*args, **kwargs)
@@ -736,7 +749,7 @@
    # ===== Handle chat

    def do_chat(self, cmd):
-        model = "mlx-community/llama-3.3-70b-instruct-4bit"
+        model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
        if len(cmd) < 2:
            if not self.state.model or self.state.model == "":
                self.prompt_model()

From 6415a13b3ca1592312ca34a25331fd37e2946b95 Mon Sep 17 00:00:00 2001
From: Octavian
Date: Mon, 10 Nov 2025 22:11:45 -0800
Subject: [PATCH 171/172] refactor help

---
 src/repl.py | 64 +++++++++++++++++++++++++++++------------------------
 1 file changed, 35 insertions(+), 29 deletions(-)

diff --git a/src/repl.py b/src/repl.py
index 44217673..155832c2 100644
--- a/src/repl.py
+++ b/src/repl.py
@@ -190,52 +190,58 @@ def _print_hf(cmd, desc, examples=[""]):
            pcmd = "   " + cmd.ljust(30, '.')
            sys.stdout.write(f"{pcmd} {desc}\n")
            for e in examples:
-                pex = e.rjust(len(e)+35)+"\n" if e != "" else ""
+                pex = e.rjust(len(e)+37)+"\n" if e != "" else ""
                sys.stdout.write(f"{pex}")

        sys.stdout.write("\033[1m\nAvailable commands:\n\033[0m")
        dprint("\033[1m\n  Common:\n\033[0m")
        _print_hf("model list ", "List locally available models.")
        _print_hf("model [REPO]", "Set the target model. [REPO] must be a valid repository")
-        _print_hf("nodes list ", "List mDNS discovered nodes.")
-        _print_hf("log [LEVEL]", "Set the logging level.")
+        _print_hf("api start [HTTP] [GRPC]", "Start the API server in a separate thread. Use provided ports if given.")
+        _print_hf("topo auto ", "Automatically optimize topology from available nodes.")
+        _print_hf("load ", "Load model into memory on all nodes.")
+        _print_hf("api log [LEVEL]", "Output live logs to current terminal.")
        dprint("\033[1m\n  Controlling the API Server:\n\033[0m")
        _print_hf("api start [HTTP] [GRPC]", "Start the API server in a separate thread. 
         _print_hf("api stop ", "Stop the API server.")
-        _print_hf("api status ", "Prints the status of the API server.")
-        _print_hf("api log ", "Output live logs to current terminal.")
-        dprint("\033[1m\n  Topology construction:\n\033[0m")
-        _print_hf("search ", "Returns the current state of mDNS search.")
-        _print_hf("search [on/off] ", "Toggle mDNS search across the local network.")
+        _print_hf("api log [LEVEL]", "Output live logs to current terminal.")
+        dprint("\033[1m\n  Topology management:\n\033[0m")
+        _print_hf("search [on/off] ", "Toggle mDNS search across the local network. Default is ON.")
         _print_hf("nodes list ", "List all nodes in the current topology (including local ones).")
-        _print_hf("nodes all ", "List all nodes (including local ones).")
-        _print_hf("nodes ", "List mDNS discovered nodes.")
-        _print_hf("topo [AUTO/SETUP]", "Toggle between automatic and manual topology creation.")
+        _print_hf("nodes all ", "List all mDNS discovered nodes (including local ones).")
+        _print_hf("topo auto ", "Automatically optimize topology from available nodes.")
         _print_hf("topo add [NODE]", "Add [NODE] to the topology.")
-        _print_hf("topo remove [NODE]", "Remove [NODE] from the topology.")
+        _print_hf("topo [remove|rm] [NODE]", "Remove [NODE] from the topology.")
+        _print_hf("topo assign [NODE] [LAYERS] [ROUNDS] ", "Assign [LAYERS] to [NODE]. e.g.:",
+                  ["> topo assign dnet0 1-10"])
        sys.stdout.write("\033[1m\n  Scheduling:\n\033[0m")
         _print_hf("sched auto ", "Automatic search for best schedule given the active topology and the loaded model.")
-        _print_hf("sched assign [INDEX] [NODE]", "Assign the layer range between [START] and [END] to [NODE].",
-                  ["Example: > sched assign 10 benny_234",
-                   "         > sched assign 0-12 benny_234"])
+        _print_hf("sched assign [INDEX] [NODE]", "Assign the layer at [INDEX] (or a range) to [NODE]. e.g.:",
+                  ["> sched assign 10 benny_234",
+                   "> sched assign 0-12 benny_234"])
        sys.stdout.write("\033[1m\n  Benchmarking, Tracing and Profiling:\n\033[0m")
-        _print_hf("trace [ON|OFF][PATH][SYSTEM] ", "Trace [SYSTEM] and output to file at [PATH].")
-        _print_hf("trace status ", "See status of the trace, eg. number of frames captured")
-        _print_hf("trace focus [SUBSYSTEM] ", "Focus the trace on [SUBSYSTEM]. Do 'trace focus' for a list of available subsystems.")
-        _print_hf("trace stream [ON|OFF] ", "Stream the trace spans to current terminal.")
-        _print_hf("trace set [BUDGET] ", "Set the maximum amount of recoded events.")
         _print_hf("perf ", "Prints the current state of runtime performance tracking.")
-        _print_hf("perf stat [REQ_ID | WORKER_ID | MODEL] ", "Prints the runtime statistics of target system.")
-        _print_hf("bench [REPO]", "Benchmark the system using the model from [REPO]")
-        _print_hf("bench [KERNEL]", "Behcnmark the system using base kernel [KERNEL]")
-        _print_hf("bench [NODE]", "Behcnmark the network latency between the current system and [NODE]")
-        _print_hf("bench [NODE0] [NODE1]", "Behcnmark the network latency between [NODE0] and [NODE11")
+        _print_hf("perf stat [-n node] [-r req] ", "Prints the runtime statistics of the target system.")
+        _print_hf("perf vm [interval]", "Prints virtual memory information for the worker thread at [interval].")
+        _print_hf("bench [repo_id]", "Benchmark the system using the model from [repo_id]")
+        _print_hf("bench [kernel]", "Benchmark the system using base kernel [kernel]")
+        _print_hf("bench [node]", "Benchmark the network latency between the current system and [node]")
+        _print_hf("bench [node0, node1]", "Benchmark the network latency between [node0] and [node1]")
+        _print_hf("bench ", "Benchmark the system using base library kernels")
+        _print_hf("trace [on|off] ", "Toggle system tracing.")
+        _print_hf("trace [-o path] [-p [probes]]", "Trace [probes] and output to file at [path]. e.g.:",
+                  ["> trace on -p memory,network",
+                   "> trace on -o ./trace_output.txt"])
+        _print_hf("trace list ", "List available trace probes.")
+        _print_hf("trace add [probe] ", "Activate [probe].")
+        _print_hf("trace stream [on|off] ", "Stream the trace events to the current terminal.")
+        _print_hf("trace [budget|b] [limit]", "Set the maximum budget for recorded events. Default is 1000.")
+        _print_hf("trace stat ", "See the status of the trace, e.g. number of frames captured.")
        sys.stdout.write("\033[1m\n  System control:\n\033[0m")
-        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource.",
-                  ["Examples > limit memory 12000 (MB)",
-                   "         > limit CPU_CORE_COUNT 4",
-                   "         > limit GPU_SM 128"])
+        _print_hf("limit [RESOURCE] [VALUE]", "Set a higher limit for a system resource. e.g.:",
e.g:", + ["> limit SYSMEM 12000 (MB)", + "> limit CPU_CORE_COUNT 4", + "> limit GPU_SM 128"]) sys.stdout.write("\n") sys.stdout.flush() From bef2b1232a26b8fd70b18ce4c7d0021c004fecb0 Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 11 Nov 2025 02:32:00 -0800 Subject: [PATCH 172/172] basic chat interface --- src/repl.py | 119 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 92 insertions(+), 27 deletions(-) diff --git a/src/repl.py b/src/repl.py index 155832c2..c033ccb9 100644 --- a/src/repl.py +++ b/src/repl.py @@ -2,6 +2,7 @@ import io import os import sys +import json import logging import cmd import time @@ -40,6 +41,9 @@ PrepareTopologyRequest, APILoadModelRequest, APILoadModelResponse, + ChatParams, + ChatMessage, + ChatRequestModel, ) # Handle restricted repos @@ -59,6 +63,13 @@ def dprint(msg): sys.stdout.write(msg) sys.stdout.flush() +@dataclass +class ChatInterface: + id: str = "" + model: str = "" + params: ChatParams = None + messages: List[ChatMessage] = None + @dataclass class REPLState: model: str = None @@ -134,6 +145,7 @@ def loop(self): # Main tty loop case s if s in ["perf", ".perf", "p"]: self.do_perf(cmd) case s if s in ["topo", "topology", "t"]: self.do_topo(cmd) case s if s in ["model", "m"]: self.do_model(cmd) + case "chat": self.do_chat(cmd) def do_api(self, cmd: List[str]) -> None: if len(cmd) < 2: @@ -155,7 +167,6 @@ def do_api(self, cmd: List[str]) -> None: case "status": dprint("Running\n" if self._api_running else "Stopped.\n") case "log": dprint("Log print is not yet supported.\n") case _: dprint("Invalid API command. Type 'help' for a list of valid commands.\n") - return def do_search(self, cmd: List[str]) -> None: if len(cmd) != 2: @@ -244,12 +255,6 @@ def _print_hf(cmd, desc, examples=[""]): "> limit GPU_SM 128"]) sys.stdout.write("\n") sys.stdout.flush() - - def print_state(self): - dprint("Network state:\n") - dprint(f"{("Model".ljust(20)): >10}: {self.state.model}\n") - dprint(f"{("Local workers".ljust(20)): >10}: {self.state.num_local_nodes}\n") - # ===== Handle Model input and pull from server @@ -275,10 +280,10 @@ def do_model(self, cmd): lists = self._list_local_models() dprint("\nLocally available weights:\n") for x in lists[0]: - dprint(f" {x}\n") + dprint(f" {x.replace("--", "/")}\n") dprint("\nMetadata only:\n") for x in lists[1]: - dprint(f" {x}\n") + dprint(f" {x.replace("--", "/")}\n") dprint("\n") case _: # Treat unknown commands as model repos self._handle_model_pull(cmd[1]) @@ -325,7 +330,7 @@ def _list_local_models(self): # Require a HF access token for restricted repositories # Ask user for HF access token until they have a valid one def _handle_model_pull(self, repo_path): - local = self._are_weights_local(repo_path) + local = self._are_weights_local(repo_path.replace("/", "--")) try: if not local: dprint(f"Weights for {repo_path} not found in local cache. 
Downloading.\n") @@ -740,7 +745,7 @@ def prepare_topo(self, model): dprint(f"Unable to create topology: {e}\n\n") return False self.state.topo = topo - self.print_topo(topo) + #self.print_topo(topo) return True def load_model(self, model): @@ -753,31 +758,91 @@ def load_model(self, model): return False # ===== Handle chat + + def _p_msg(self, msg: str, role: str): + match role: + case "user": + ps = f"\n@\033[97;1m{role}\033[0m".rjust(10) + sys.stdout.write(f"{ps} > ") + sys.stdout.flush() + return + case "llm" | "think": + ps = f"@\033[97;1m{role}\033[0m".rjust(10) + sys.stdout.write(f"{ps} > ") + sys.stdout.flush() + return + case _: + ps = f"@\033[1m{role}\033[0m".rjust(10) + sys.stdout.write(f"{ps} > {msg}\n") + + def _chat_loop(self, ci: ChatInterface): + while True: + self._p_msg(" ", "user") + prompt = sys.stdin.readline().strip() + + match prompt: # Match meta commands, else prompt + case s if s in [".quit", ".q"]: + break + case _: + msg = ChatMessage(role="user", content=prompt) + ci.messages.append(msg) + + req = ChatRequestModel( + messages=ci.messages, + model=self.state.model, + max_tokens=5000, + ) + # eval async generator in api thread + agen = self.api_call("_stream_chat", req) + msg = [] + #self._p_msg("", "think") + self._p_msg("", "llm") + while True: + fut = asyncio.run_coroutine_threadsafe(agen.__anext__(), self._api_loop) + line = fut.result() + if not line.startswith("data: "): + sys.stdout.write(f"[stream error] Invalid data returned from API thread.\n") + break + payload = line[6:].strip() + if payload == "[DONE]": + dprint("\n") + break + obj = json.loads(payload) + choices = obj.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + text = delta.get("content", "") + match text: + case "": + sys.stdout.write("\033[90;3m") + case "": + sys.stdout.write("\033[0m") + case "\n": + pass + case _: + sys.stdout.write(text) + sys.stdout.flush() + msg.append(text) + + # TODO: Capture usage and metrics def do_chat(self, cmd): - model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit" if len(cmd) < 2: if not self.state.model or self.state.model == "": self.prompt_model() - if not self.state.topology: - if not self._prepare_topo(self.state.model): + sys.stdout.write(""+"-"*80+"\n\n") + if not hasattr(self.state, "topology"): + self._p_msg("Initializing topology", "system") + if not self.prepare_topo(self.state.model): raise RuntimeError("Unable to create topology.") + self._p_msg("Loading weights into memory", "system") if not self.load_model(self.state.model): raise RuntimeError("Unable to load model.") - while True: - prompt = input("\n> ") - prompt = self.format_prompt(prompt) - messages = prompt - req = ChatRequest( - messages=messages, - max_tokens=100, - temperature=0.7, - stream=True, - ) - - self.api_call("_handle_completion", req) - + self._p_msg("New session initialized. Welcome :3", "system") + ci = ChatInterface(messages=[]) + self._chat_loop(ci) + # Start default chat with selected model pass pass