Commit b5c35dd

nixl_ep: Migrate elastic.py to TCPStore
This commit migrates elastic.py to TCPStore-based metadata exchange instead of ETCD and replaces the custom TCP rank server with a Torch TCPStore group-based implementation, reusing the same store group for both metadata and rank management.

I considered defining an abstract base class to support both implementations (the custom TCP server and TCPStore), but since get_rank is off the data path and already very fast (see the performance table below), maintaining both options does not seem necessary. Reviewer feedback is welcome.

Performance comparison, 16 concurrent get_rank calls on 2 nodes x 8 GPUs:

| Scope  | TCPStore Avg (ms) | TCP Avg (ms) | TCPStore StdDev (ms) | TCP StdDev (ms) |
|--------|-------------------|--------------|----------------------|-----------------|
| Local  | 0.28              | 1.55         | 0.04                 | 0.20            |
| Remote | 2.95              | 1.48         | 2.05                 | 0.17            |

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
1 parent c3139b4 commit b5c35dd
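The store_group module referenced throughout this commit is not shown in the diff below. As a rough sketch of the coordination setup it provides, assuming its helpers are thin wrappers around torch.distributed.TCPStore with signatures inferred from how elastic.py calls them:

```python
# Hypothetical sketch of the store_group helpers (not shown in this commit view);
# assumes they are thin wrappers around torch.distributed.TCPStore.
from datetime import timedelta

import torch.distributed as dist


def create_master_store(port: int, timeout_sec: int = 300) -> dist.TCPStore:
    # Hosts the store; no fixed world size, so an elastic number of clients
    # can attach at any time, and we do not block waiting for them.
    return dist.TCPStore(
        "0.0.0.0",
        port,
        is_master=True,
        timeout=timedelta(seconds=timeout_sec),
        wait_for_workers=False,
    )


def create_client_store(master_addr: str, port: int) -> dist.TCPStore:
    # Connects to an already-running master store on master_addr:port.
    return dist.TCPStore(master_addr, port, is_master=False)
```

Because the master store does not wait for a fixed set of workers, ranks added in later test phases can join the same group, which is what the elastic test relies on.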

File tree

6 files changed: +241 -189 lines changed

examples/device/ep/README.md

Lines changed: 5 additions & 1 deletion
@@ -15,9 +15,13 @@ NIXL EP provides a flexible buffer initialization pattern that supports dynamic
 
 ```python
 import nixl_ep
+import store_group
+
+# Create TCPStore for coordination
+tcp_store = store_group.create_client_store(master_addr, port)
 
 # Initialize buffer with dynamic rank support
-buffer = nixl_ep.Buffer(rank, explicitly_destroy=True)
+buffer = nixl_ep.Buffer(rank, explicitly_destroy=True, tcp_store_group=tcp_store)
 buffer.update_memory_buffers(num_ranks, num_experts_per_rank, rdma_bytes)
 buffer.connect_ranks(initial_ranks)
 
examples/device/ep/tests/elastic/README.md

Lines changed: 3 additions & 5 deletions
@@ -4,8 +4,7 @@
 ```bash
 python3 tests/elastic/elastic.py \
     --plan tests/elastic/single_expansion.json \
-    --num-processes 8 \
-    --etcd-server http://127.0.0.1:2379
+    --num-processes 8
 ```
 
 #### Multi-Node Setup:

@@ -14,16 +13,15 @@ python3 tests/elastic/elastic.py \
 ```bash
 python3 tests/elastic/elastic.py \
     --plan tests/elastic/single_expansion.json \
-    --num-processes 4 \
+    --num-processes 4
 ```
 
 **Node 2** (will join the second phase with additional 4 ranks):
 ```bash
 python3 tests/elastic/elastic.py \
     --plan tests/elastic/single_expansion.json \
     --num-processes 4 \
-    --rank-server $MASTER_IP \
-    --etcd-server http://$MASTER_IP:2379
+    --tcp-store $MASTER_IP
 ```
 
 ### Available Test Plans

examples/device/ep/tests/elastic/elastic.py

Lines changed: 29 additions & 25 deletions
@@ -29,7 +29,8 @@
 from typing import cast
 
 import nixl_ep
-import rank_server
+import rank_manager
+import store_group
 import torch
 from plan import Plan
 

@@ -50,7 +51,7 @@ def handle_sigterm(
     frame,
     buffer: nixl_ep.Buffer,
     plan: Plan,
-    rank_client: rank_server.RankClient,
+    rank_client: rank_manager.RankManager,
 ):
     print(
         f"SIGTERM ({signum}) received for process {os.getpid()}! releasing rank and exiting...",

@@ -438,14 +439,22 @@ def test_barrier():
 
 
 def worker(torch_rank: int, args: argparse.Namespace):
-    rank_client = rank_server.RankClient(
-        args.rank_server if args.rank_server else "127.0.0.1"
+    tcp_store = store_group.create_client_store(
+        master_addr=args.tcp_store if args.tcp_store else "127.0.0.1",
+        port=args.tcp_store_port,
     )
+
+    rank_client = rank_manager.RankManager(tcp_store)
     local_rank, global_rank, last_active_phase = rank_client.get_rank()
+    print(
+        f"Process {torch_rank} -> global_rank={global_rank}, local_rank={local_rank}",
+        flush=True,
+    )
+
     plan = Plan(
         args.plan,
         global_rank,
-        start_phase=last_active_phase if last_active_phase is not None else 0,
+        start_phase=int(last_active_phase) if last_active_phase is not None else 0,
     )
     if plan.current_phase == -1:
         print(

@@ -455,10 +464,6 @@ def worker(torch_rank: int, args: argparse.Namespace):
         return
 
     max_num_ranks = plan.get_max_rank() + 1
-    print(
-        f"Process {torch_rank} -> global_rank={global_rank}, local_rank={local_rank}",
-        flush=True,
-    )
 
     # Initialize torch
     os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank % 8)

@@ -480,9 +485,6 @@
     tcp_nics = ",ibp154s0,ibp192s0,ibp206s0,ibp220s0,ibp94s0"
     os.environ["UCX_NET_DEVICES"] = f"cuda0-{pxb_nics[local_rank]}:1" + tcp_nics
 
-    # Initialize NIXL
-    os.environ["NIXL_ETCD_ENDPOINTS"] = args.etcd_server
-
     # Initialize nixl_ep buffer
     num_rdma_bytes = nixl_ep.Buffer.get_rdma_size_hint(
         args.num_tokens,

@@ -498,6 +500,7 @@
         nvlink_backend=args.nvlink_backend,
         explicitly_destroy=True,
         enable_shrink=True,
+        tcp_store_group=tcp_store,
     )
     buffer.update_memory_buffers(
         num_ranks=max_num_ranks,

@@ -616,15 +619,15 @@
     parser.add_argument("--hidden-dim", type=int, default=7168, help="Hidden dimension")
     parser.add_argument("--num-topk", type=int, default=8, help="Number of topk")
     parser.add_argument(
-        "--etcd-server",
+        "--tcp-store",
         type=str,
-        default="http://127.0.0.1:2379",
-        help="ETCD server address for NIXL (default: http://127.0.0.1:2379)",
+        help="External TCPStore address. If not set, a local TCPStore master will be created.",
     )
     parser.add_argument(
-        "--rank-server",
-        type=str,
-        help="Rank server address. If not set, a rank server will be started locally and will be killed after all the workers launched in this run are finished.",
+        "--tcp-store-port",
+        type=int,
+        default=9999,
+        help="TCPStore port (default: 9999)",
     )
     parser.add_argument("--kineto", action="store_true", help="Enable kineto profiling")
     parser.add_argument(

@@ -636,14 +639,15 @@
 
     args = parser.parse_args()
 
-    rank_server_process = None
-    if not args.rank_server:
-        print("Starting rank server locally", flush=True)
-        rank_server_process = torch.multiprocessing.Process(
-            target=rank_server.start_server, daemon=True
+    # Create TCPStore master if no external TCPStore server is specified
+    master_store = None
+    if not args.tcp_store:
+        master_store = store_group.create_master_store(
+            port=args.tcp_store_port,
+            timeout_sec=365 * 24 * 3600,  # 1 year timeout
         )
-        rank_server_process.start()
-        time.sleep(0.5)
+        rank_manager.init_keys(master_store)
+
     if args.num_processes == 1:
         worker(0, args)
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import json
17+
import os
18+
import time
19+
from contextlib import contextmanager
20+
from typing import Iterator, Optional, Tuple
21+
22+
import torch.distributed as dist
23+
24+
_KEY_NEXT_GLOBAL_RANK = "rank_manager/next_global_rank"
25+
_KEY_RELEASED_RANKS = "rank_manager/released_ranks"
26+
_KEY_LOCK = "rank_manager/lock"
27+
28+
29+
def _host_local_ranks_key(hostname: str) -> str:
30+
return f"rank_manager/host/{hostname}/local_ranks"
31+
32+
33+
def _rank_context_key(global_rank: int) -> str:
34+
return f"rank_manager/rank/{global_rank}/context"
35+
36+
37+
def _rank_hostname_key(global_rank: int) -> str:
38+
return f"rank_manager/rank/{global_rank}/hostname"
39+
40+
41+
def _rank_local_key(global_rank: int) -> str:
42+
return f"rank_manager/rank/{global_rank}/local_rank"
43+
44+
45+
def init_keys(store: dist.TCPStore) -> None:
46+
store.add(_KEY_NEXT_GLOBAL_RANK, 0)
47+
store.set(_KEY_RELEASED_RANKS, "[]")
48+
store.set(_KEY_LOCK, "0")
49+
50+
51+
class RankManager:
52+
53+
def __init__(self, store: dist.TCPStore):
54+
self._global_rank: Optional[int] = None
55+
self._hostname = os.uname().nodename
56+
self._store = store
57+
58+
@contextmanager
59+
def _lock(self) -> Iterator[None]:
60+
my_id = f"{self._hostname}_{os.getpid()}"
61+
while True:
62+
result = self._store.compare_set(_KEY_LOCK, "0", my_id)
63+
if result.decode() == my_id:
64+
break
65+
if result == b"0":
66+
continue
67+
time.sleep(0.0001)
68+
try:
69+
yield
70+
finally:
71+
self._store.set(_KEY_LOCK, "0")
72+
73+
def _get_json_list(self, key: str) -> list:
74+
return (
75+
json.loads(self._store.get(key).decode())
76+
if self._store.check([key])
77+
else []
78+
)
79+
80+
def _set_json_list(self, key: str, value: list):
81+
self._store.set(key, json.dumps(value))
82+
83+
def get_rank(self) -> Tuple[int, int, Optional[str]]:
84+
"""Returns (local_rank, global_rank, user_context)."""
85+
if self._global_rank is not None:
86+
print(
87+
f"WARNING: rank already assigned - returning existing rank {self._global_rank}",
88+
flush=True,
89+
)
90+
return 0, self._global_rank, None
91+
92+
start = time.perf_counter()
93+
user_context: Optional[str] = None
94+
95+
with self._lock():
96+
released = self._get_json_list(_KEY_RELEASED_RANKS)
97+
98+
if released:
99+
global_rank = min(released)
100+
released.remove(global_rank)
101+
self._set_json_list(_KEY_RELEASED_RANKS, released)
102+
ctx_data = self._store.get(_rank_context_key(global_rank)).decode()
103+
if ctx_data and ctx_data != "None":
104+
user_context = ctx_data
105+
self._store.delete_key(_rank_context_key(global_rank))
106+
else:
107+
global_rank = self._store.add(_KEY_NEXT_GLOBAL_RANK, 1) - 1
108+
109+
local_ranks_key = _host_local_ranks_key(self._hostname)
110+
used_local_ranks = set(self._get_json_list(local_ranks_key))
111+
local_rank = 0
112+
while local_rank in used_local_ranks:
113+
local_rank += 1
114+
used_local_ranks.add(local_rank)
115+
self._store.multi_set(
116+
[
117+
local_ranks_key,
118+
_rank_hostname_key(global_rank),
119+
_rank_local_key(global_rank),
120+
],
121+
[json.dumps(list(used_local_ranks)), self._hostname, str(local_rank)],
122+
)
123+
124+
self._global_rank = global_rank
125+
elapsed_ms = (time.perf_counter() - start) * 1000
126+
print(f"[rank_manager] get_rank took {elapsed_ms:.2f} ms", flush=True)
127+
return local_rank, global_rank, user_context
128+
129+
def release_rank(self, user_context: Optional[str] = None) -> bool:
130+
if self._global_rank is None:
131+
return False
132+
133+
global_rank = self._global_rank
134+
135+
with self._lock():
136+
hostname_key = _rank_hostname_key(global_rank)
137+
local_key = _rank_local_key(global_rank)
138+
values = self._store.multi_get([hostname_key, local_key])
139+
hostname = values[0].decode()
140+
local_rank = int(values[1].decode())
141+
142+
local_ranks_key = _host_local_ranks_key(hostname)
143+
used_local_ranks = self._get_json_list(local_ranks_key)
144+
if local_rank in used_local_ranks:
145+
used_local_ranks.remove(local_rank)
146+
self._set_json_list(local_ranks_key, used_local_ranks)
147+
148+
if user_context is not None:
149+
self._store.set(_rank_context_key(global_rank), str(user_context))
150+
151+
self._store.delete_key(hostname_key)
152+
self._store.delete_key(local_key)
153+
154+
released = self._get_json_list(_KEY_RELEASED_RANKS)
155+
released.append(global_rank)
156+
self._set_json_list(_KEY_RELEASED_RANKS, released)
157+
158+
self._global_rank = None
159+
return True
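For reference, a rough end-to-end usage sketch of this module: the store_group helper signatures are inferred from how elastic.py calls them, and the plan/buffer setup is elided.

```python
import rank_manager
import store_group

# Launcher side: host the TCPStore master and seed the rank-manager keys,
# mirroring what main() in elastic.py does when --tcp-store is not given.
master_store = store_group.create_master_store(port=9999, timeout_sec=3600)
rank_manager.init_keys(master_store)

# Worker side (any node): connect as a client and claim the lowest free rank.
tcp_store = store_group.create_client_store(master_addr="127.0.0.1", port=9999)
manager = rank_manager.RankManager(tcp_store)
local_rank, global_rank, last_active_phase = manager.get_rank()

# ... build nixl_ep.Buffer(..., tcp_store_group=tcp_store) and run the plan ...

# On shutdown (e.g. from the SIGTERM handler), return the rank to the released
# pool; the user_context stored here is handed back to whichever worker later
# reuses this global rank.
manager.release_rank(user_context="2")  # "2" is an illustrative phase index
```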
