diff --git a/test/python/nixl_ep_perf/README.md b/test/python/nixl_ep_perf/README.md new file mode 100644 index 0000000000..ffa91645d9 --- /dev/null +++ b/test/python/nixl_ep_perf/README.md @@ -0,0 +1,164 @@ +# NIXL EP Performance Tests + +Performance tests for NIXL EP Buffer: +- **Data Plane**: dispatch/combine throughput and latency +- **Control Plane**: init/connect/disconnect/destroy latency + +## Prerequisites + +- etcd running locally (`etcd &` or `source /workspace/nixl/examples/device/ep/scripts/reset_etcd.sh`) +- CUDA device with RDMA support + +## Environment Setup + +```bash +# For RDMA performance (recommended) +export UCX_TLS=rc_mlx5,dc_mlx5,tcp +export UCX_IB_AR_ENABLE=no # Disable Adaptive Routing for consistent performance +``` + +## Usage + +```bash +cd test/python/nixl_ep_perf + +# IPC/NVLink backend (default) +python3 test_data_plane.py --num-processes=8 --mode=e2e + +# RDMA-only (disable NVLink) +UCX_TLS=rc_mlx5,dc_mlx5,tcp UCX_IB_AR_ENABLE=no \ + python3 test_data_plane.py --num-processes=8 --mode=e2e --nvlink-backend none + +# Dispatch only (measures dispatch throughput) +python3 test_data_plane.py --num-processes=8 --mode=dispatch + +# Combine only (one dispatch, many combines) +python3 test_data_plane.py --num-processes=8 --mode=combine +``` + +## Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--num-processes` | 8 | Number of ranks/GPUs | +| `--mode` | e2e | Test mode: dispatch, combine, e2e | +| `--tokens` | 512 | Number of tokens | +| `--hidden` | 4096 | Hidden dimension | +| `--experts-per-rank` | 8 | Experts per rank | +| `--topk` | 2 | TopK value | +| `--nvlink-backend` | ipc | Backend: ipc, nixl, none (RDMA only) | +| `--warmup` | 10 | Warmup iterations | +| `--iters` | 100 | Measurement iterations | +| `--discover-nics` | false | Enable GPU-NIC topology discovery (default: disabled, UCX auto-selects) | +| `--use-etcd` | false | Use etcd for metadata exchange (default: TCPStore) | + +## Example Output + +``` +====================================================================== +NIXL EP Data Plane Performance Test +====================================================================== +Mode: e2e +Ranks: 8, Tokens: 128, Hidden: 7168 +Experts: 36/rank (288 total), TopK: 8 +Backend: none (RDMA forced) +Warmup: 10, Measure: 100 iterations +====================================================================== + +====================================================================== +Data Plane (e2e): 8/8 ranks passed +====================================================================== +Bandwidth (GB/s): avg=42.88, min=42.86, max=42.89 +Latency (μs): avg=519.3, min=519.1, max=519.5 +``` + +## Expected Performance (DFW cluster, RDMA, AR=no) + +| Mode | Bandwidth | Latency | +|------|-----------|---------| +| E2E | ~42.8 GB/s | ~520 μs | +| Dispatch | ~42.1 GB/s | ~180 μs | +| Combine | ~43.3 GB/s | ~340 μs | + +*Config: 128 tokens, 7168 hidden, topk=8, 288 experts (36/rank), 8 GPUs* + +## Control Plane Tests + +Measures latency of control plane operations (init, connect, disconnect, destroy). 
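+
+For reference, these operations roughly correspond to the following `nixl_ep.Buffer`
+calls. This is only a sketch based on the data plane test in this directory: variable
+names are placeholders, and the exact disconnect/reconnect calls live in
+`test_control_plane.py`.
+
+```python
+# init: create the buffer and register memory
+buffer = nixl_ep.Buffer(rank=rank, nvlink_backend="ipc",
+                        explicitly_destroy=True, enable_shrink=True,
+                        tcp_store_group=tcp_store)   # tcp_store may be None (etcd)
+buffer.update_memory_buffers(num_ranks=world_size,
+                             num_experts_per_rank=8,
+                             num_rdma_bytes=num_rdma_bytes)
+# connect: establish connections to all peer ranks
+buffer.connect_ranks([r for r in range(world_size) if r != rank])
+# ... disconnect / reconnect (see test_control_plane.py for the exact calls) ...
+# destroy: tear down connections and free resources
+buffer.destroy()
+```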
+ +### Single-Node (Default) + +```bash +# Full cycle (init → connect → disconnect → reconnect → destroy) +python3 test_control_plane.py --num-processes=8 + +# Specific expert counts +python3 test_control_plane.py --num-processes=8 --experts-per-rank=8,32 + +# Single operation +python3 test_control_plane.py --num-processes=8 --test=connect + +# Use etcd instead of TCPStore (if needed) +python3 test_control_plane.py --num-processes=8 --use-etcd +``` + +### Multi-Node Setup + +Use environment variables `WORLD_SIZE`, `RANK`, and `MASTER_ADDR` for multi-node testing: + +**Master Node (RANK=0):** +```bash +WORLD_SIZE=2 RANK=0 MASTER_ADDR=node0.example.com \ + python3 test_control_plane.py --num-processes=8 +``` + +**Worker Node (RANK=1):** +```bash +WORLD_SIZE=2 RANK=1 MASTER_ADDR=node0.example.com \ + python3 test_control_plane.py --num-processes=8 +``` + +**Or use CLI flags:** +```bash +# Master +python3 test_control_plane.py --num-processes=8 --world-size=2 --rank=0 --master-addr=node0 + +# Worker +python3 test_control_plane.py --num-processes=8 --world-size=2 --rank=1 --master-addr=node0 +``` + +**Key points:** +- `WORLD_SIZE` = number of nodes (not total ranks) +- `--num-processes` = GPUs per node +- Total ranks = WORLD_SIZE × num-processes +- RANK=0 is master (runs TCPStore/rank server) +- RANK>0 are workers (connect to master) +- TCPStore is used by default (no etcd dependency); use `--use-etcd` to switch to etcd + +### Example Output + +``` +====================================================================== +Control Plane: 8 experts/rank x 8 ranks = 64 total +====================================================================== +Operation Avg (ms) Min (ms) Max (ms) +---------------------------------------------------------------------- +init 150.23 148.15 152.31 +connect 245.67 242.33 248.91 +disconnect 12.45 11.23 13.67 +reconnect 198.34 195.12 201.56 +destroy 85.12 83.45 86.79 +---------------------------------------------------------------------- +TOTAL 691.81 +====================================================================== +``` + +## Files + +| File | Description | +|------|-------------| +| `test_data_plane.py` | Data plane test (dispatch/combine/e2e) | +| `test_control_plane.py` | Control plane test (init/connect/disconnect/destroy) | +| `mp_runner.py` | Multi-process test runner | +| `rank_server.py` | Coordination server for distributed tests | + diff --git a/test/python/nixl_ep_perf/__init__.py b/test/python/nixl_ep_perf/__init__.py new file mode 100644 index 0000000000..158379e9cc --- /dev/null +++ b/test/python/nixl_ep_perf/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""NIXL EP data plane performance tests.""" diff --git a/test/python/nixl_ep_perf/mp_runner.py b/test/python/nixl_ep_perf/mp_runner.py new file mode 100644 index 0000000000..521beac473 --- /dev/null +++ b/test/python/nixl_ep_perf/mp_runner.py @@ -0,0 +1,607 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Multi-process test runner for NIXL EP performance tests. + +Spawns worker processes with proper GPU assignment and UCX configuration. 
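+
+Typical usage (a minimal sketch; my_perf_test is a hypothetical test function,
+not part of this module):
+
+    from mp_runner import run_multiprocess_test
+
+    def my_perf_test(rank, world_size, local_rank, **kwargs):
+        # Run the workload on this rank and report pass/fail plus metrics.
+        return {"passed": True, "metrics": {"bandwidth_gbps": 0.0}}
+
+    results = run_multiprocess_test(my_perf_test, num_processes=8)
+    assert all(r.passed for r in results)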
+""" + +import logging +import os +import re +import subprocess +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import store_group +import torch +import torch.multiprocessing as mp +from rank_server import RankClient, start_server + +logger = logging.getLogger(__name__) + + +@dataclass +class TestResult: + """Result from a single rank's test execution.""" + + rank: int + test_name: str + passed: bool + error: Optional[str] = None + metrics: Optional[Dict[str, Any]] = None + duration_ms: float = 0.0 + + +# Cached topology (discovered once per process) +_GPU_NIC_TOPOLOGY: Optional[Dict[int, str]] = None +_RANK_SERVER_ADDR: Optional[str] = None +_RANK_SERVER_PORT: int = 9998 + + +def discover_gpu_nic_topology() -> Optional[Dict[int, str]]: + """Discover GPU-NIC topology using nvidia-smi topo -m.""" + try: + result = subprocess.run( + ["nvidia-smi", "topo", "-m"], capture_output=True, text=True, timeout=30 + ) + if result.returncode != 0: + return None + + lines = result.stdout.strip().split("\n") + + # Parse NIC legend (e.g., "NIC0: mlx5_0") + nic_legend = {} + for line in lines: + match = re.match(r"\s*(NIC\d+):\s*(\S+)", line) + if match: + nic_legend[match.group(1)] = match.group(2) + + if not nic_legend: + return None + + # Find header line with GPU0 and NIC0 + header_idx = None + for i, line in enumerate(lines): + if "GPU0" in line and "NIC0" in line: + header_idx = i + break + + if header_idx is None: + return None + + header = lines[header_idx].split() + nic_columns = {col: i for i, col in enumerate(header) if col.startswith("NIC")} + + if not nic_columns: + return None + + # Connection priority (best to worst) + priority = {"PIX": 0, "PXB": 1, "PHB": 2, "NODE": 3, "SYS": 4, "X": 99} + gpu_to_nic = {} + + for line in lines[header_idx + 1 :]: + parts = line.split() + if not parts or not parts[0].startswith("GPU"): + continue + if parts[0].startswith("NIC") or "Legend" in line: + break + + match = re.match(r"GPU(\d+)", parts[0]) + if not match: + continue + gpu_idx = int(match.group(1)) + + best_nic, best_priority = None, 100 + for nic_name, col_idx in nic_columns.items(): + data_col_idx = col_idx + 1 + if data_col_idx < len(parts): + p = priority.get(parts[data_col_idx], 50) + if p < best_priority: + best_priority = p + best_nic = nic_legend.get(nic_name) + + if best_nic: + gpu_to_nic[gpu_idx] = best_nic + + return gpu_to_nic if gpu_to_nic else None + + except Exception as e: + logger.warning("Failed to discover GPU-NIC topology: %s", e) + return None + + +def get_gpu_nic_mapping(local_rank: int) -> Optional[str]: + """Get UCX_NET_DEVICES string for a GPU. + + Format matches elastic.py: RDMA NIC + TCP fallback interfaces + """ + if _GPU_NIC_TOPOLOGY is None: + return None # Topology not set - let UCX auto-select + + if local_rank in _GPU_NIC_TOPOLOGY: + rdma_nic = f"cuda0-{_GPU_NIC_TOPOLOGY[local_rank]}:1" + + # Add TCP fallback interfaces (like elastic.py) for cross-node communication + # These are IPoIB (InfiniBand) interfaces used as TCP fallback + tcp_nics = ( + ",ibp26s0,ibp44s0,ibp64s0,ibp101s0,ibp156s0,ibp173s0,ibp192s0,ibp227s0" + ) + + return rdma_nic + tcp_nics + return None + + +def setup_worker_environment( + local_rank: int, + etcd_server: str = "http://127.0.0.1:2379", + use_tcp_store: bool = False, +): + """Set up GPU, UCX, and NIXL environment for a worker process. 
+ + Args: + local_rank: Local GPU index on this node (0-7), like elastic.py + etcd_server: etcd server URL (only used if not use_tcp_store) + use_tcp_store: If True, use TCPStore instead of etcd + """ + cuda_device = local_rank % 8 + os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device) + + # Set UCX_NET_DEVICES using local_rank (like elastic.py) + # Maps to the optimal RDMA NIC for this GPU + TCP fallback interfaces + ucx_devices = get_gpu_nic_mapping(local_rank) + if ucx_devices: + os.environ["UCX_NET_DEVICES"] = ucx_devices + + # Don't set UCX_TLS here - buffer.py will set it to "^cuda_ipc" when nvlink_backend != "nixl" + # which tells UCX to auto-detect all transports except cuda_ipc (including RDMA) + + # Only set NIXL_ETCD_ENDPOINTS when NOT using TCPStore (copy elastic.py pattern) + # This prevents C++ code from activating etcd path when we want TCPStore + if not use_tcp_store: + os.environ["NIXL_ETCD_ENDPOINTS"] = etcd_server + logger.info( + f"Worker local_rank={local_rank}: Set NIXL_ETCD_ENDPOINTS={etcd_server}" + ) + + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.cuda.set_device(0) + + +def worker_fn( + torch_rank: int, + num_processes: int, + test_fn: Callable, + result_queue: mp.Queue, + etcd_server: str, + rank_server_addr: str, + gpu_nic_topology: Dict[int, str], + extra_kwargs: Optional[Dict[Any, Any]], + rank_server_port: int, + use_tcp_store: bool, + world_size: int = 1, + node_rank: int = 0, +): + """Worker function executed by each spawned process.""" + global _GPU_NIC_TOPOLOGY, _RANK_SERVER_ADDR, _RANK_SERVER_PORT + + _GPU_NIC_TOPOLOGY = gpu_nic_topology + _RANK_SERVER_ADDR = rank_server_addr + _RANK_SERVER_PORT = rank_server_port + + if extra_kwargs is None: + extra_kwargs = {} + + # Pass node_rank to test function for logging prefix + extra_kwargs["node_rank"] = node_rank + + total_ranks = num_processes * world_size + + # Compute ranks deterministically based on node_rank and process index + # This ensures predictable assignment: + # Node 0: global ranks 0-7 + # Node 1: global ranks 8-15 + # etc. 
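+    # e.g. with num_processes=8: node_rank=1, torch_rank=3 -> global_rank = 1 * 8 + 3 = 11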
+ local_rank = torch_rank # Process index within this node (0-7) + global_rank = node_rank * num_processes + local_rank + + try: + # Setup environment using local_rank for GPU/NIC selection + setup_worker_environment(local_rank, etcd_server, use_tcp_store) + + start_time = time.perf_counter() + result = test_fn( + rank=global_rank, # Global rank for Buffer + world_size=total_ranks, + local_rank=local_rank, # Local rank for GPU index + **extra_kwargs, + ) + duration_ms = (time.perf_counter() - start_time) * 1000 + + if isinstance(result, bool): + test_result = TestResult( + rank=global_rank, + test_name=test_fn.__name__, + passed=result, + duration_ms=duration_ms, + ) + elif isinstance(result, dict): + test_result = TestResult( + rank=global_rank, + test_name=test_fn.__name__, + passed=result.get("passed", True), + error=result.get("error"), + metrics=result.get("metrics"), + duration_ms=duration_ms, + ) + else: + test_result = TestResult( + rank=global_rank, + test_name=test_fn.__name__, + passed=True, + metrics={"result": result}, + duration_ms=duration_ms, + ) + + result_queue.put(test_result) + + except Exception as e: + import traceback + + result_queue.put( + TestResult( + rank=global_rank, + test_name=test_fn.__name__, + passed=False, + error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}", + ) + ) + + +def wait_for_tcp_port( + host: str, + port: int, + timeout: float = 60.0, + poll_interval: float = 0.5, +) -> bool: + """Wait for a TCP port to accept connections. + + Args: + host: Hostname or IP to connect to + port: Port number + timeout: Maximum time to wait in seconds + poll_interval: Initial interval between connection attempts + + Returns: + True if port is ready, raises TimeoutError otherwise + """ + import socket + + start_time = time.time() + attempt = 0 + current_interval = poll_interval + + while time.time() - start_time < timeout: + attempt += 1 + try: + s = socket.create_connection((host, port), timeout=2.0) + s.close() + logger.info( + f"TCP port {host}:{port} is ready " + f"(attempt {attempt}, waited {time.time() - start_time:.1f}s)" + ) + return True + except (ConnectionRefusedError, socket.timeout, OSError): + if attempt == 1: + logger.info(f"Waiting for TCP port {host}:{port}...") + elif attempt % 10 == 0: + logger.info( + f"Still waiting for {host}:{port}... 
" + f"(attempt {attempt}, {time.time() - start_time:.1f}s)" + ) + time.sleep(current_interval) + current_interval = min(current_interval * 1.2, 2.0) + + raise TimeoutError(f"TCP port {host}:{port} not ready after {timeout}s") + + +def check_etcd_running(etcd_endpoints: str = "http://127.0.0.1:2379") -> bool: + """Check if etcd is running.""" + try: + result = subprocess.run( + ["pgrep", "-x", "etcd"], capture_output=True, text=True, timeout=2 + ) + if result.returncode == 0 and result.stdout.strip(): + return True + except Exception: + pass + + try: + env = os.environ.copy() + env["ETCDCTL_API"] = "3" + result = subprocess.run( + ["etcdctl", "--endpoints", etcd_endpoints, "endpoint", "health"], + capture_output=True, + text=True, + timeout=5, + env=env, + ) + if result.returncode == 0 and "is healthy" in result.stdout: + return True + except Exception: + pass + + return False + + +def clean_etcd_state(etcd_endpoints: str = "http://127.0.0.1:2379"): + """Clean all keys from etcd.""" + try: + env = os.environ.copy() + env["ETCDCTL_API"] = "3" + + # Delete all keys (empty prefix = all keys) + result = subprocess.run( + ["etcdctl", "--endpoints", etcd_endpoints, "del", "--prefix", ""], + capture_output=True, + text=True, + timeout=10, + env=env, + ) + if result.returncode == 0: + time.sleep(1.0) + except Exception: + pass + + +def run_multiprocess_test( + test_fn: Callable, + num_processes: int = 8, + etcd_server: str = "http://127.0.0.1:2379", + timeout: float = 120.0, + clean_etcd: bool = True, + rank_server_port: int = 9998, + tcp_store_port: int = 9999, + skip_nic_discovery: bool = True, + use_tcp_store: bool = True, + world_size: int = 1, + rank: int = 0, + master_addr: str = "127.0.0.1", + **kwargs, +) -> List[TestResult]: + """ + Run a test function across multiple GPU processes (single or multi-node). 
+ + Args: + test_fn: Function receiving (rank, world_size, local_rank, **kwargs) + num_processes: Number of processes to spawn per node + timeout: Timeout in seconds + use_tcp_store: If True (default), use TCPStore; if False, use etcd + tcp_store_port: Port for TCPStore server (default: 9999) + world_size: Total number of nodes (env: WORLD_SIZE, default: 1 for single-node) + rank: This node's rank 0=master (env: RANK, default: 0) + master_addr: Master node address (env: MASTER_ADDR, for TCPStore and rank server) + **kwargs: Passed to test_fn + + Returns: + List of TestResult, one per local rank on this node + """ + # Always use master_addr for etcd (works for both single-node and multi-node) + etcd_server = f"http://{master_addr}:2379" + + # Configure logger with node prefix for multi-node debugging + for handler in logging.root.handlers: + handler.setFormatter(logging.Formatter(f"[Node {rank}] %(message)s")) + + # Calculate total ranks and set master address + total_ranks = num_processes * world_size + os.environ["MASTER_ADDR"] = master_addr + os.environ["WORLD_SIZE"] = str(total_ranks) # Total ranks, not nodes + os.environ["RANK"] = str(rank) # This node's rank + is_master = rank == 0 + + logger.info( + f"etcd_server={etcd_server}, master_addr={master_addr}, " + f"world_size={world_size}, num_processes={num_processes}" + ) + + if world_size > 1: + logger.info( + "Multi-node mode: This is %s node (RANK=%d/%d, MASTER_ADDR=%s)", + "MASTER" if is_master else "WORKER", + rank, + world_size - 1, + master_addr, + ) + + # Start TCPStore server if requested (master node only) + tcp_store_process = None + if use_tcp_store: + if is_master: + logger.info(f"Starting TCPStore server on port {tcp_store_port}") + + def run_tcp_store_server(): + # Keep reference to prevent garbage collection + _store = store_group.create_master_store( # noqa: F841 + port=tcp_store_port + ) + # Keep server alive + import signal + + signal.pause() + + tcp_store_process = mp.Process(target=run_tcp_store_server, daemon=True) + tcp_store_process.start() + + # Wait for TCPStore to be ready (both master and worker nodes) + logger.info(f"Waiting for TCPStore at {master_addr}:{tcp_store_port}...") + wait_for_tcp_port(master_addr, tcp_store_port, timeout=60.0) + logger.info(f"✓ TCPStore ready at {master_addr}:{tcp_store_port}") + kwargs["tcp_store_port"] = tcp_store_port + else: + # Only check/clean etcd on master node when not using TCPStore + if is_master: + if not check_etcd_running(etcd_server): + raise RuntimeError(f"etcd is not running at {etcd_server}") + + if clean_etcd: + clean_etcd_state(etcd_server) + logger.info("Cleaned etcd state") + else: + logger.info("Skipping etcd check (master handles it)") + + # Pass use_tcp_store to the test function via kwargs + kwargs["use_tcp_store"] = use_tcp_store + + # Discover topology once (skipped by default unless --discover-nics is set) + gpu_nic_topology = None + if skip_nic_discovery: + logger.info("Skipping GPU-NIC discovery (default), UCX will auto-select") + else: + gpu_nic_topology = discover_gpu_nic_topology() + if gpu_nic_topology is None: + raise RuntimeError( + "Failed to discover GPU-NIC topology. " + "Ensure nvidia-smi is available and GPUs are present. " + "Or omit --discover-nics to let UCX auto-select (default)." 
+ ) + logger.info(f"Discovered GPU-NIC topology: {gpu_nic_topology}") + + # Start rank server (master node only) + server_process = None + if is_master: + logger.info(f"Starting rank server on port {rank_server_port}") + server_process = start_server(port=rank_server_port) + time.sleep(1.0) + + try: + client = RankClient(master_addr, rank_server_port) + client.clear_barriers() + client.reset() + except Exception as e: + raise RuntimeError(f"Failed to connect to rank server: {e}") + else: + # Worker node: wait for master's rank server to be ready + # NOTE: Do NOT call clear_barriers() here - only master should do that + # to avoid clearing barriers that master's processes are already using + logger.info(f"Waiting for rank server at {master_addr}:{rank_server_port}...") + client = RankClient(master_addr, rank_server_port) + client.wait_for_server(timeout=60.0) + logger.info( + f"✓ Master is alive! " + f"Connected to rank server at {master_addr}:{rank_server_port}" + ) + + spawn_ctx = mp.get_context("spawn") + result_queue = spawn_ctx.Queue() + + try: + ctx = mp.spawn( + worker_fn, + args=( + num_processes, + test_fn, + result_queue, + etcd_server, + master_addr, + gpu_nic_topology, + kwargs, + rank_server_port, + use_tcp_store, + world_size, + rank, + ), + nprocs=num_processes, + join=False, + daemon=False, + start_method="spawn", + ) + + deadline = time.time() + timeout + for p in ctx.processes: + remaining = max(0.1, deadline - time.time()) + p.join(timeout=remaining) + if p.is_alive(): + p.terminate() + + results = [] + while not result_queue.empty(): + try: + results.append(result_queue.get_nowait()) + except Exception: + break + + # Calculate expected global rank range for this node + start_rank = rank * num_processes + end_rank = start_rank + num_processes + expected_ranks = set(range(start_rank, end_rank)) + + result_ranks = {r.rank for r in results} + for expected_rank in expected_ranks: + if expected_rank not in result_ranks: + results.append( + TestResult( + rank=expected_rank, + test_name=test_fn.__name__, + passed=False, + error="Timeout or process died", + ) + ) + + results.sort(key=lambda r: r.rank) + return results + + finally: + if server_process and server_process.is_alive(): + server_process.terminate() + server_process.join(timeout=2) + if tcp_store_process and tcp_store_process.is_alive(): + tcp_store_process.terminate() + tcp_store_process.join(timeout=2) + + +# ============================================================================ +# Synchronization +# ============================================================================ + + +class DistributedBarrier: + """TCP-based barrier using rank_server.""" + + def __init__( + self, + world_size: int, + barrier_id: str, + server_addr: str = "127.0.0.1", + port: int = 9998, + ): + self.world_size = world_size + self.barrier_id = barrier_id + self.server_addr = server_addr + self.port = port + + def wait(self, rank: int, timeout: float = 60.0): + """Wait for all ranks to reach this barrier.""" + client = RankClient(self.server_addr, self.port) + return client.barrier_wait(self.barrier_id, rank, self.world_size, timeout) + + +def sync_all_ranks( + rank: int, + world_size: int, + barrier_name: str, + timeout: float = 60.0, + server_addr: Optional[str] = None, + port: Optional[int] = None, +): + """Synchronize all ranks at a named barrier point.""" + if server_addr is None: + server_addr = _RANK_SERVER_ADDR or os.environ.get("MASTER_ADDR", "127.0.0.1") + if port is None: + port = _RANK_SERVER_PORT + + assert server_addr is 
not None # Ensured by default above + barrier = DistributedBarrier(world_size, barrier_name, server_addr, port) + barrier.wait(rank, timeout) diff --git a/test/python/nixl_ep_perf/rank_server.py b/test/python/nixl_ep_perf/rank_server.py new file mode 100644 index 0000000000..fe56bc3084 --- /dev/null +++ b/test/python/nixl_ep_perf/rank_server.py @@ -0,0 +1,290 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Rank server for multi-process test coordination. + +Provides rank assignment and distributed barriers. +""" + +import multiprocessing as mp +import os +import socket +import time +from collections import defaultdict +from socketserver import StreamRequestHandler, ThreadingTCPServer +from threading import Lock +from typing import Any, Dict, Optional, Set, Tuple + + +class RankServerHandler(StreamRequestHandler): + """Handles GET_RANK, RELEASE_RANK, BARRIER, CLEAR_BARRIERS, RESET.""" + + _lock: Lock = Lock() + _counts: Dict[str, list] = defaultdict(list) + _rank_to_host: Dict[int, Tuple[str, int]] = {} + _all_global_ranks: Set[int] = set() + _removed_global_ranks: Set[int] = set() + _barriers: Dict[str, Dict[str, Any]] = {} + _completed_barriers: Set[str] = set() + + def handle(self): + try: + line = self.rfile.readline().strip().decode() + + with self._lock: + if line.startswith("BARRIER"): + self._handle_barrier(line) + elif line.startswith("RELEASE_RANK"): + self._handle_release(line) + elif line.startswith("CLEAR_BARRIERS"): + self._handle_clear_barriers() + elif line.startswith("RESET"): + self._handle_reset() + elif line.startswith("GET_RANK") or line: + self._handle_get_rank(line) + + except Exception as e: + try: + self.wfile.write(f"ERROR: {e}\n".encode()) + except Exception: + pass + + def _handle_barrier(self, line: str): + """BARRIER """ + parts = line.split() + if len(parts) < 4: + self.wfile.write(b"ERROR: BARRIER requires barrier_id rank world_size\n") + return + + barrier_id, rank, world_size = parts[1], int(parts[2]), int(parts[3]) + + if barrier_id in self._completed_barriers: + self.wfile.write(b"BARRIER_DONE\n") + return + + if barrier_id not in self._barriers: + self._barriers[barrier_id] = {"expected": world_size, "arrived": set()} + + self._barriers[barrier_id]["arrived"].add(rank) + + if len(self._barriers[barrier_id]["arrived"]) >= world_size: + self._completed_barriers.add(barrier_id) + del self._barriers[barrier_id] + self.wfile.write(b"BARRIER_DONE\n") + else: + arrived = len(self._barriers[barrier_id]["arrived"]) + self.wfile.write(f"BARRIER_WAIT {arrived}/{world_size}\n".encode()) + + def _handle_release(self, line: str): + """RELEASE_RANK """ + parts = line.split() + if len(parts) < 2: + self.wfile.write(b"ERROR: RELEASE_RANK requires rank\n") + return + + rank = int(parts[1]) + if rank in self._all_global_ranks: + self._all_global_ranks.discard(rank) + self._removed_global_ranks.add(rank) + if rank in self._rank_to_host: + host, local_rank = self._rank_to_host[rank] + if local_rank in self._counts[host]: + self._counts[host].remove(local_rank) + del self._rank_to_host[rank] + self.wfile.write(b"OK\n") + + def _handle_clear_barriers(self): + """CLEAR_BARRIERS""" + count = len(self._barriers) + len(self._completed_barriers) + self._barriers.clear() + self._completed_barriers.clear() + self.wfile.write(f"OK {count}\n".encode()) + + def _handle_reset(self): + """RESET""" + self._counts.clear() + self._rank_to_host.clear() + self._all_global_ranks.clear() + 
self._removed_global_ranks.clear() + self._barriers.clear() + self._completed_barriers.clear() + self.wfile.write(b"OK\n") + + def _handle_get_rank(self, line: str): + """GET_RANK [hostname]""" + if line.startswith("GET_RANK"): + parts = line.split(maxsplit=1) + host = parts[1] if len(parts) > 1 else os.uname().nodename + else: + host = line if line else os.uname().nodename + + used = set(self._counts[host]) + local = 0 + while local in used: + local += 1 + self._counts[host].append(local) + + if self._removed_global_ranks: + global_rank = min(self._removed_global_ranks) + self._removed_global_ranks.remove(global_rank) + else: + global_rank = len(self._all_global_ranks) + + self._all_global_ranks.add(global_rank) + self._rank_to_host[global_rank] = (host, local) + self.wfile.write(f"{local} {global_rank}\n".encode()) + + +class ReusableTCPServer(ThreadingTCPServer): + """TCP server with port reuse.""" + + allow_reuse_address = True + daemon_threads = True + + +class RankClient: + """Client for rank server communication.""" + + def __init__(self, server: str = "127.0.0.1", port: int = 9998): + self.server = server + self.port = port + self.global_rank: Optional[int] = None + self.local_rank: Optional[int] = None + + def wait_for_server( + self, + timeout: float = 60.0, + poll_interval: float = 0.5, + ) -> bool: + """Wait for the rank server to be ready. + + Polls the server until it responds, with exponential backoff. + + Args: + timeout: Maximum time to wait in seconds + poll_interval: Initial interval between connection attempts + + Returns: + True if server is ready, raises TimeoutError otherwise + """ + import logging + + logger = logging.getLogger(__name__) + + start_time = time.time() + attempt = 0 + current_interval = poll_interval + + while time.time() - start_time < timeout: + attempt += 1 + try: + # Try to connect and send a simple command + s = socket.create_connection((self.server, self.port), timeout=2.0) + s.close() + logger.info( + f"Rank server at {self.server}:{self.port} is ready " + f"(attempt {attempt}, waited {time.time() - start_time:.1f}s)" + ) + return True + except (ConnectionRefusedError, socket.timeout, OSError): + if attempt == 1: + logger.info( + f"Waiting for rank server at {self.server}:{self.port}..." + ) + elif attempt % 10 == 0: + logger.info( + f"Still waiting for rank server... 
" + f"(attempt {attempt}, {time.time() - start_time:.1f}s)" + ) + time.sleep(current_interval) + # Exponential backoff up to 2 seconds + current_interval = min(current_interval * 1.2, 2.0) + + raise TimeoutError( + f"Rank server at {self.server}:{self.port} not ready after {timeout}s" + ) + + def _send(self, command: str, timeout: float = 10.0) -> str: + """Send command and return response.""" + s = socket.create_connection((self.server, self.port), timeout=timeout) + try: + s.sendall(f"{command}\n".encode()) + return s.recv(4096).decode().strip() + finally: + s.close() + + def get_rank(self) -> Tuple[int, int]: + """Get (local_rank, global_rank) from server.""" + if self.global_rank is not None and self.local_rank is not None: + return (self.local_rank, self.global_rank) + + response = self._send(f"GET_RANK {os.uname().nodename}") + parts = response.split() + if len(parts) >= 2: + self.local_rank = int(parts[0]) + self.global_rank = int(parts[1]) + return (self.local_rank, self.global_rank) + raise RuntimeError(f"Unexpected response: {response}") + + def release_rank(self) -> bool: + """Release assigned rank.""" + if self.global_rank is None: + return True + response = self._send(f"RELEASE_RANK {self.global_rank}") + self.global_rank = None + self.local_rank = None + return response == "OK" + + def barrier_wait( + self, barrier_id: str, rank: int, world_size: int, timeout: float = 60.0 + ) -> bool: + """Wait at barrier until all ranks arrive.""" + deadline = time.time() + timeout + + while time.time() < deadline: + try: + response = self._send( + f"BARRIER {barrier_id} {rank} {world_size}", + timeout=min(5.0, deadline - time.time()), + ) + if response == "BARRIER_DONE": + return True + elif response.startswith("BARRIER_WAIT"): + time.sleep(0.05) + elif response.startswith("ERROR"): + raise RuntimeError(f"Barrier error: {response}") + except socket.timeout: + continue + except ConnectionRefusedError: + time.sleep(0.1) + continue + + raise TimeoutError(f"Barrier {barrier_id} timeout after {timeout}s") + + def reset(self) -> bool: + """Reset server state.""" + return self._send("RESET") == "OK" + + def clear_barriers(self) -> int: + """Clear pending barriers.""" + response = self._send("CLEAR_BARRIERS") + if response.startswith("OK"): + parts = response.split() + return int(parts[1]) if len(parts) > 1 else 0 + return 0 + + +def start_server(port: int = 9998) -> mp.Process: + """Start rank server in background process.""" + + def run(): + try: + server = ReusableTCPServer(("0.0.0.0", port), RankServerHandler) + server.serve_forever() + except OSError: + pass # Already running + + process = mp.Process(target=run, daemon=True) + process.start() + time.sleep(0.5) + return process diff --git a/test/python/nixl_ep_perf/store_group.py b/test/python/nixl_ep_perf/store_group.py new file mode 100644 index 0000000000..881c427441 --- /dev/null +++ b/test/python/nixl_ep_perf/store_group.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta + +import torch.distributed as dist + + +def create_master_store( + port: int = 9999, + timeout_sec: float = 300.0, +) -> dist.TCPStore: + return dist.TCPStore( + host_name="0.0.0.0", + port=port, + is_master=True, + wait_for_workers=False, + timeout=timedelta(seconds=timeout_sec), + ) + + +def create_client_store( + master_addr: str = "127.0.0.1", + port: int = 9999, + timeout_sec: float = 300.0, +) -> dist.TCPStore: + return dist.TCPStore( + host_name=master_addr, + port=port, + is_master=False, + wait_for_workers=False, + timeout=timedelta(seconds=timeout_sec), + ) diff --git a/test/python/nixl_ep_perf/test_data_plane.py b/test/python/nixl_ep_perf/test_data_plane.py new file mode 100644 index 0000000000..303cda1ab0 --- /dev/null +++ b/test/python/nixl_ep_perf/test_data_plane.py @@ -0,0 +1,441 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Data plane performance test for NIXL EP Buffer. + +Measures throughput and latency of dispatch/combine operations. + +Usage: + # Dispatch only (measure dispatch BW/latency) + python3 test_data_plane.py --num-processes=8 --mode=dispatch + + # Combine only (one dispatch, many combines) + python3 test_data_plane.py --num-processes=8 --mode=combine + + # End-to-end (dispatch + combine cycles) + python3 test_data_plane.py --num-processes=8 --mode=e2e + + # Custom configuration + python3 test_data_plane.py --num-processes=8 --mode=e2e \ + --tokens=2048 --hidden=7168 --experts-per-rank=32 --topk=8 +""" + +import argparse +import logging +import os +import sys +from typing import Any, Dict, List + +import store_group +from mp_runner import TestResult, run_multiprocess_test, sync_all_ranks + +# Setup logging +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + +# Defaults +DEFAULT_WARMUP = 10 +DEFAULT_ITERS = 100 + + +def _run_data_plane_test( + rank: int, + world_size: int, + local_rank: int = 0, + mode: str = "e2e", + num_experts_per_rank: int = 8, + num_tokens: int = 512, + hidden: int = 4096, + topk: int = 2, + nvlink_backend: str = "ipc", + warmup_iters: int = DEFAULT_WARMUP, + measure_iters: int = DEFAULT_ITERS, + use_tcp_store: bool = False, + node_rank: int = 0, + **kwargs, +) -> Dict[str, Any]: + """ + Run data plane performance test. 
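+
+    Bandwidth accounting, matching the byte counts computed below: each dispatched
+    token costs hidden + hidden // 128 * 4 + 16 bytes per selected expert (FP8
+    payload, per-128-element FP32 scales, and a small fixed overhead), while each
+    combined token costs hidden * 2 bytes (BF16). For hidden=7168 that is
+    7168 + 224 + 16 = 7408 bytes per dispatch selection and 14336 bytes per combine
+    selection; bandwidth is total bytes divided by the average CUDA-event time,
+    with the first measured iteration dropped.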
+ + Args: + mode: "dispatch" (only dispatch), "combine" (1 dispatch + N combines), + or "e2e" (N dispatch+combine cycles) + use_tcp_store: Use TCPStore for metadata exchange instead of etcd + node_rank: Node rank for log message prefix + """ + import nixl_ep + import numpy as np + import torch + + # Configure logger with node prefix + for handler in logging.root.handlers: + handler.setFormatter(logging.Formatter(f"[Node {node_rank}] %(message)s")) + + total_experts = num_experts_per_rank * world_size + + # Setup TCPStore if requested + tcp_store = None + if use_tcp_store: + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + tcp_store_port = kwargs.get("tcp_store_port", 9999) + tcp_store = store_group.create_client_store( + master_addr=master_addr, + port=tcp_store_port, + timeout_sec=60.0, + ) + + # Create buffer + num_rdma_bytes = nixl_ep.Buffer.get_rdma_size_hint( + num_tokens, hidden, world_size, total_experts + ) + buffer = nixl_ep.Buffer( + rank=rank, + nvlink_backend=nvlink_backend, + explicitly_destroy=True, + enable_shrink=True, + tcp_store_group=tcp_store, + ) + buffer.update_memory_buffers( + num_ranks=world_size, + num_experts_per_rank=num_experts_per_rank, + num_rdma_bytes=num_rdma_bytes, + ) + + sync_all_ranks(rank, world_size, f"{mode}_init") + + # Connect to all other ranks + other_ranks = [r for r in range(world_size) if r != rank] + if other_ranks: + buffer.connect_ranks(other_ranks) + + sync_all_ranks(rank, world_size, f"{mode}_connected") + + # Create test data + x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda") + topk_idx = torch.randint( + 0, total_experts, (num_tokens, topk), dtype=torch.int64, device="cuda" + ) + topk_weights = torch.rand(num_tokens, topk, dtype=torch.float32, device="cuda") + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + # Calculate bytes for BW measurement (FP8 dispatch, BF16 combine) + num_fp8_bytes = hidden + hidden // 128 * 4 + 16 + num_combine_bytes = hidden * 2 # BF16 + num_dispatch_comm_bytes = 0 + num_combine_comm_bytes = 0 + for i in range(num_tokens): + num_selections = (topk_idx[i] != -1).sum().item() + num_dispatch_comm_bytes += num_fp8_bytes * num_selections + num_combine_comm_bytes += num_combine_bytes * num_selections + + # Initial dispatch to get shape for combine + recv_x, recv_count, handle_init, event, hook = buffer.dispatch( + x=x, + topk_idx=topk_idx, + num_max_dispatch_tokens_per_rank=num_tokens, + num_experts=total_experts, + use_fp8=True, + async_finish=False, + ) + simulated_gemm_x = recv_x[0].to(torch.bfloat16).clone() + + # Define test functions based on mode + def dispatch_fn(): + return buffer.dispatch( + x=x, + topk_idx=topk_idx, + num_max_dispatch_tokens_per_rank=num_tokens, + num_experts=total_experts, + use_fp8=True, + async_finish=False, + ) + + def combine_fn(handle): + return buffer.combine( + x=simulated_gemm_x, + topk_idx=topk_idx, + topk_weights=topk_weights, + handle=handle, + use_logfmt=False, + ) + + # Flush L2 cache + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # For combine mode: do ONE dispatch, reuse handle for all combines + combine_handle = None + if mode == "combine": + _, _, combine_handle, _, _ = dispatch_fn() + + # Warmup + for _ in range(warmup_iters): + if mode == "dispatch": + dispatch_fn() + elif mode == "combine": + combine_fn(combine_handle) + else: # e2e + _, _, handle, _, _ = dispatch_fn() + combine_fn(handle) + + cache.zero_() # Flush L2 + sync_all_ranks(rank, world_size, 
f"{mode}_warmup") + + # Measure with CUDA events + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(measure_iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(measure_iters)] + + for i in range(measure_iters): + start_events[i].record() + + if mode == "dispatch": + dispatch_fn() + elif mode == "combine": + combine_fn(combine_handle) + else: # e2e + _, _, handle, _, _ = dispatch_fn() + combine_fn(handle) + + end_events[i].record() + + torch.cuda.synchronize() + + # Calculate times (skip first iteration) + times = np.array( + [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)] + )[1:] + + if mode == "combine": + comm_bytes = num_combine_comm_bytes + elif mode == "dispatch": + comm_bytes = num_dispatch_comm_bytes + else: # e2e + comm_bytes = num_dispatch_comm_bytes + num_combine_comm_bytes + + avg_t = np.average(times) + min_t = np.min(times) + max_t = np.max(times) + + # Calculate metrics + bandwidth_gbps = comm_bytes / 1e9 / avg_t + avg_latency_us = avg_t * 1e6 + tokens_per_sec = num_tokens / avg_t + + sync_all_ranks(rank, world_size, f"{mode}_measured") + + # Cleanup + buffer.destroy() + sync_all_ranks(rank, world_size, f"{mode}_cleanup") + + # Validate results + passed = True + error = None + if np.isnan(bandwidth_gbps) or bandwidth_gbps <= 0: + passed = False + error = f"Invalid bandwidth: {bandwidth_gbps}" + elif np.isnan(avg_t) or avg_t <= 0: + passed = False + error = f"Invalid timing: {avg_t}" + + return { + "passed": passed, + "error": error, + "metrics": { + "mode": mode, + "bandwidth_gbps": bandwidth_gbps, + "avg_latency_us": avg_latency_us, + "min_latency_us": min_t * 1e6, + "max_latency_us": max_t * 1e6, + "tokens_per_sec": tokens_per_sec, + "num_tokens": num_tokens, + "hidden": hidden, + "topk": topk, + "total_experts": total_experts, + "measure_iters": measure_iters, + }, + } + + +def log_results(test_name: str, results: List[TestResult]): + """Log formatted results.""" + passed = sum(1 for r in results if r.passed) + total = len(results) + + logger.info("=" * 70) + logger.info("%s: %d/%d ranks passed", test_name, passed, total) + logger.info("=" * 70) + + if passed == 0: + for r in results: + if r.error: + logger.info(" Rank %d: %s", r.rank, r.error[:200]) + return + + # Aggregate metrics + bw_values = [] + lat_values = [] + for r in results: + if r.passed and r.metrics: + bw_values.append(r.metrics.get("bandwidth_gbps", 0)) + lat_values.append(r.metrics.get("avg_latency_us", 0)) + + if bw_values: + logger.info( + "Bandwidth (GB/s): avg=%.2f, min=%.2f, max=%.2f", + sum(bw_values) / len(bw_values), + min(bw_values), + max(bw_values), + ) + if lat_values: + logger.info( + "Latency (μs): avg=%.1f, min=%.1f, max=%.1f", + sum(lat_values) / len(lat_values), + min(lat_values), + max(lat_values), + ) + + +def main(): + parser = argparse.ArgumentParser(description="NIXL EP Data Plane Performance Test") + parser.add_argument( + "--num-processes", type=int, default=8, help="Number of processes per node" + ) + parser.add_argument( + "--mode", + type=str, + default="e2e", + choices=["dispatch", "combine", "e2e"], + help="Test mode: dispatch, combine, or e2e", + ) + parser.add_argument("--tokens", type=int, default=512, help="Number of tokens") + parser.add_argument("--hidden", type=int, default=4096, help="Hidden dimension") + parser.add_argument("--experts-per-rank", type=int, default=8, help="Experts/rank") + parser.add_argument("--topk", type=int, default=2, help="TopK value") + parser.add_argument( + "--nvlink-backend", + 
type=str, + default="ipc", + choices=["nixl", "ipc", "none"], + help="NVLink backend (none forces RDMA)", + ) + parser.add_argument( + "--warmup", type=int, default=DEFAULT_WARMUP, help="Warmup iters" + ) + parser.add_argument( + "--iters", type=int, default=DEFAULT_ITERS, help="Measure iters" + ) + parser.add_argument("--timeout", type=int, default=300, help="Timeout (seconds)") + parser.add_argument( + "--discover-nics", + action="store_true", + help="Enable GPU-NIC topology discovery (default: disabled, UCX auto-selects)", + ) + parser.add_argument( + "--use-etcd", + action="store_true", + help="Use etcd for metadata exchange instead of TCPStore (default: TCPStore)", + ) + # Multi-node parameters + parser.add_argument( + "--world-size", + type=int, + default=None, + help="Total number of nodes (overrides WORLD_SIZE env var, default: 1)", + ) + parser.add_argument( + "--rank", + type=int, + default=None, + help="Rank of this node 0=master (overrides RANK env var, default: 0)", + ) + parser.add_argument( + "--master-addr", + type=str, + default=None, + help="Master node address (overrides MASTER_ADDR env var)", + ) + args = parser.parse_args() + + # Get multi-node configuration from environment or command line + world_size = ( + args.world_size + if args.world_size is not None + else int(os.environ.get("WORLD_SIZE", "1")) + ) + rank = args.rank if args.rank is not None else int(os.environ.get("RANK", "0")) + master_addr = ( + args.master_addr + if args.master_addr is not None + else os.environ.get("MASTER_ADDR", "127.0.0.1") + ) + + # Configure logger with node prefix for multi-node debugging + for handler in logging.root.handlers: + handler.setFormatter(logging.Formatter(f"[Node {rank}] %(message)s")) + + # Validation + if world_size < 1: + raise ValueError(f"WORLD_SIZE must be >= 1, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK must be in [0, {world_size - 1}], got {rank}") + if world_size > 1 and rank > 0 and master_addr == "127.0.0.1": + raise ValueError( + "MASTER_ADDR must be set (not 127.0.0.1) for worker nodes in multi-node setup. " + "Set MASTER_ADDR environment variable or use --master-addr flag." 
+ ) + + # Calculate total ranks + total_ranks = args.num_processes * world_size + total_experts = args.experts_per_rank * total_ranks + metadata_exchange = "etcd" if args.use_etcd else "TCPStore" + + logger.info("=" * 70) + logger.info("NIXL EP Data Plane Performance Test") + logger.info("=" * 70) + if world_size > 1: + logger.info("Multi-node setup:") + logger.info(" Nodes (WORLD_SIZE): %d", world_size) + logger.info( + " This node (RANK): %d %s", rank, "(master)" if rank == 0 else "(worker)" + ) + logger.info(" Processes per node: %d", args.num_processes) + logger.info(" Total ranks: %d", total_ranks) + logger.info(" Master address: %s", master_addr) + else: + logger.info("Single-node setup: %d processes", args.num_processes) + logger.info("Mode: %s", args.mode) + logger.info("Tokens: %d, Hidden: %d, TopK: %d", args.tokens, args.hidden, args.topk) + logger.info("Experts: %d/rank (%d total)", args.experts_per_rank, total_experts) + logger.info("NVLink backend: %s", args.nvlink_backend) + logger.info("Metadata exchange: %s", metadata_exchange) + logger.info("Warmup: %d, Measure: %d iterations", args.warmup, args.iters) + logger.info("=" * 70) + + results = run_multiprocess_test( + test_fn=_run_data_plane_test, + num_processes=args.num_processes, + timeout=args.timeout, + skip_nic_discovery=not args.discover_nics, + use_tcp_store=not args.use_etcd, + world_size=world_size, + rank=rank, + master_addr=master_addr, + mode=args.mode, + num_experts_per_rank=args.experts_per_rank, + num_tokens=args.tokens, + hidden=args.hidden, + topk=args.topk, + nvlink_backend=args.nvlink_backend, + warmup_iters=args.warmup, + measure_iters=args.iters, + ) + + log_results(f"Data Plane ({args.mode})", results) + + # Exit with error if any rank failed + if not all(r.passed for r in results): + sys.exit(1) + + +if __name__ == "__main__": + main()