4 changes: 2 additions & 2 deletions examples/device/ep/csrc/nixl_ep.cpp
@@ -1,6 +1,6 @@
 /*
  * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * This file incorporates material from the DeepSeek project, licensed under the MIT License.
  * The modifications made by NVIDIA are licensed under the Apache License, Version 2.0.
@@ -233,7 +233,7 @@ void Buffer::destroy() {
     cudaFree(counters_buffer_ptr);
     cudaFree(rdma_buffer_ptr);

-    if (nixl_agent_info and nixl_agent_info->agent != nullptr) {
+    if (nixl_agent_info and nixl_agent_info->agent != nullptr and getenv("NIXL_ETCD_ENDPOINTS")) {
         nixl_agent_info->agent->invalidateLocalMD();
     }

20 changes: 6 additions & 14 deletions examples/device/ep/nixl_ep/buffer.py
@@ -1,5 +1,5 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # This file incorporates material from the DeepSeek project, licensed under the MIT License.
 # The modifications made by NVIDIA are licensed under the Apache License, Version 2.0.
@@ -478,20 +478,12 @@ def _fetch_remote_metadata_from_tcp_store(self, remote_ranks: List[int]):
         nixl_metadata_bytes = self.runtime.get_local_metadata()
         self.tcp_store_group.set(md_key, nixl_metadata_bytes)

-        remote_md_keys = [
-            f"NIXL_EP/{rank}" for rank in remote_ranks if rank != self.rank
-        ]
+        remote_md_keys = [f"NIXL_EP/{rank}" for rank in remote_ranks]
         if remote_md_keys:
-            self.tcp_store_group.wait(remote_md_keys, timedelta(seconds=30))
-
-        remote_mds = []
-        for rank in remote_ranks:
-            if rank != self.rank:
-                remote_md_key = f"NIXL_EP/{rank}"
-                remote_md_bytes = self.tcp_store_group.get(remote_md_key)
-                remote_mds.append(remote_md_bytes)
-            else:
-                remote_mds.append(b"")
+            self.tcp_store_group.wait(remote_md_keys, timedelta(seconds=300))
+            remote_mds = self.tcp_store_group.multi_get(remote_md_keys)
+        else:
+            remote_mds = []

         try:
             yield remote_mds
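Note on the hunk above: `_fetch_remote_metadata_from_tcp_store` switches from per-rank `get()` calls (with the local rank special-cased to an empty placeholder) to a publish/wait/`multi_get` rendezvous that fetches every rank's key, including its own, in one round trip, and raises the wait timeout from 30 s to 300 s. A minimal sketch of that rendezvous pattern, assuming a plain `torch.distributed` TCPStore and the same `NIXL_EP/<rank>` key scheme; the helper name below is illustrative, not part of the change:

```python
from datetime import timedelta
from typing import List

import torch.distributed as dist


def exchange_metadata(
    store: dist.TCPStore, rank: int, ranks: List[int], payload: bytes
) -> List[bytes]:
    # Publish this rank's metadata under its own key.
    store.set(f"NIXL_EP/{rank}", payload)
    # Block until every participating rank (including this one) has published ...
    keys = [f"NIXL_EP/{r}" for r in ranks]
    store.wait(keys, timedelta(seconds=300))
    # ... then fetch all payloads in a single call.
    return store.multi_get(keys)
```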
8 changes: 3 additions & 5 deletions examples/device/ep/tests/elastic/README.md
@@ -4,8 +4,7 @@
 ```bash
 python3 tests/elastic/elastic.py \
     --plan tests/elastic/single_expansion.json \
-    --num-processes 8 \
-    --etcd-server http://127.0.0.1:2379
+    --num-processes 8
 ```

 #### Multi-Node Setup:
@@ -14,16 +13,15 @@ python3 tests/elastic/elastic.py \
 ```bash
 python3 tests/elastic/elastic.py \
     --plan tests/elastic/single_expansion.json \
-    --num-processes 4 \
+    --num-processes 4
 ```

 **Node 2** (will join the second phase with additional 4 ranks):
 ```bash
 python3 tests/elastic/elastic.py \
     --plan tests/elastic/single_expansion.json \
     --num-processes 4 \
-    --rank-server $MASTER_IP \
-    --etcd-server http://$MASTER_IP:2379
+    --tcp-server $MASTER_IP
 ```

 ### Available Test Plans
43 changes: 23 additions & 20 deletions examples/device/ep/tests/elastic/elastic.py
@@ -30,6 +30,7 @@

 import nixl_ep
 import rank_server
+import store_group
 import torch
 from plan import Plan

@@ -44,6 +45,9 @@
     per_token_cast_back,
 )

+TCP_STORE_PORT = 9999
+RANK_SERVER_PORT = 10000
+

 def handle_sigterm(
     signum,
@@ -438,9 +442,8 @@ def test_barrier():


 def worker(torch_rank: int, args: argparse.Namespace):
-    rank_client = rank_server.RankClient(
-        args.rank_server if args.rank_server else "127.0.0.1"
-    )
+    server_addr = args.tcp_server if args.tcp_server else "127.0.0.1"
+    rank_client = rank_server.RankClient(server_addr, RANK_SERVER_PORT)
     local_rank, global_rank, last_active_phase = rank_client.get_rank()
     plan = Plan(
         args.plan,
@@ -466,8 +469,10 @@ def worker(torch_rank: int, args: argparse.Namespace):
     torch.set_default_device("cuda")
     torch.cuda.set_device(0)

-    # Initialize NIXL
-    os.environ["NIXL_ETCD_ENDPOINTS"] = args.etcd_server
+    tcp_store = store_group.create_client_store(
+        master_addr=server_addr,
+        port=TCP_STORE_PORT,
+    )

     # Initialize nixl_ep buffer
     num_rdma_bytes = nixl_ep.Buffer.get_rdma_size_hint(
@@ -484,6 +489,7 @@
         nvlink_backend=args.nvlink_backend,
         explicitly_destroy=True,
         enable_shrink=True,
+        tcp_store_group=tcp_store,
     )
     buffer.update_memory_buffers(
         num_ranks=max_num_ranks,
@@ -584,6 +590,11 @@ def worker(torch_rank: int, args: argparse.Namespace):
     print(f"global_rank={global_rank}, local_rank={local_rank} -> done", flush=True)


+def run_server():
+    _store = store_group.create_master_store(port=TCP_STORE_PORT)  # noqa: F841
+    rank_server.start_server(port=RANK_SERVER_PORT)
+
+
 def main():
     parser = argparse.ArgumentParser(description="Elastic EP Test")
     parser.add_argument(
@@ -602,15 +613,9 @@ def main():
     parser.add_argument("--hidden-dim", type=int, default=7168, help="Hidden dimension")
     parser.add_argument("--num-topk", type=int, default=8, help="Number of topk")
     parser.add_argument(
-        "--etcd-server",
+        "--tcp-server",
         type=str,
-        default="http://127.0.0.1:2379",
-        help="ETCD server address for NIXL (default: http://127.0.0.1:2379)",
-    )
-    parser.add_argument(
-        "--rank-server",
-        type=str,
-        help="Rank server address. If not set, a rank server will be started locally and will be killed after all the workers launched in this run are finished.",
+        help="TCP server address (for both TCPStore and rank server). If not set, both will be started locally.",
     )
     parser.add_argument("--kineto", action="store_true", help="Enable kineto profiling")
     parser.add_argument(
@@ -622,14 +627,12 @@

     args = parser.parse_args()

-    rank_server_process = None
-    if not args.rank_server:
-        print("Starting rank server locally", flush=True)
-        rank_server_process = torch.multiprocessing.Process(
-            target=rank_server.start_server, daemon=True
-        )
-        rank_server_process.start()
+    if not args.tcp_server:
+        print("Starting TCPStore and rank server locally", flush=True)
+        server_process = torch.multiprocessing.Process(target=run_server, daemon=True)
+        server_process.start()
+        time.sleep(0.5)

     if args.num_processes == 1:
         worker(0, args)
         return
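A side note on the bootstrap path introduced above: when `--tcp-server` is not given, a daemon process runs `run_server`, which creates the master `TCPStore` (kept referenced via `_store`, hence the `# noqa: F841`) and then serves ranks, while the parent sleeps briefly so the store is bound before workers connect. A rough sketch of that keep-alive pattern, assuming `rank_server.start_server` blocks (its internals are not part of this diff; the blocking loop below is a stand-in):

```python
import time

import torch.multiprocessing as mp

import store_group

TCP_STORE_PORT = 9999  # mirrors the constant added in elastic.py


def serve_forever():
    # Keep a reference to the master store; dropping it would tear down the listener.
    _store = store_group.create_master_store(port=TCP_STORE_PORT)
    while True:  # stand-in for rank_server.start_server(port=RANK_SERVER_PORT)
        time.sleep(1)


if __name__ == "__main__":
    server = mp.Process(target=serve_forever, daemon=True)  # exits with the parent
    server.start()
    time.sleep(0.5)  # let the store bind before any client connects
```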
45 changes: 45 additions & 0 deletions examples/device/ep/tests/elastic/store_group.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import timedelta
+
+import torch.distributed as dist
+
+
+def create_master_store(
+    port: int = 9999,
+    timeout_sec: float = 300.0,
+) -> dist.TCPStore:
+    return dist.TCPStore(
+        host_name="0.0.0.0",
+        port=port,
+        is_master=True,
+        wait_for_workers=False,
+        timeout=timedelta(seconds=timeout_sec),
+    )
+
+
+def create_client_store(
+    master_addr: str = "127.0.0.1",
+    port: int = 9999,
+    timeout_sec: float = 300.0,
+) -> dist.TCPStore:
+    return dist.TCPStore(
+        host_name=master_addr,
+        port=port,
+        is_master=False,
+        wait_for_workers=False,
+        timeout=timedelta(seconds=timeout_sec),
+    )
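A brief usage sketch for these helpers: both sides omit `world_size` and pass `wait_for_workers=False`, so clients can attach at any point after the master is up, which is what lets ranks join in later phases of an elastic run. The address, port, and key below are illustrative only:

```python
from datetime import timedelta

import store_group

# In the server process (elastic.py starts this in a daemon process).
master = store_group.create_master_store(port=9999)

# In each worker process, whenever it happens to start.
client = store_group.create_client_store(master_addr="127.0.0.1", port=9999)
client.set("NIXL_EP/0", b"rank-0 metadata")
client.wait(["NIXL_EP/0"], timedelta(seconds=5))  # keys are visible to all participants
print(client.get("NIXL_EP/0"))  # b'rank-0 metadata'
```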