Problem Description
Summary
I am seeing an async-only GPU memory fault with mori.ops.EpDispatchCombineOp
using the IntraNode kernel on 4 local ROCm GPUs. A minimal MoRI-only script
that repeats:
dispatch -> torch.cuda.synchronize() -> dist.barrier()
combine -> torch.cuda.synchronize() -> dist.barrier()
can complete the iterations and then abort with a high-address GPU memory
fault. The same script passes with HIP_LAUNCH_BLOCKING=1.
This looks like either an IntraNode combine completion issue or a state-reuse
issue across repeated dispatch/combine calls: local stream synchronization
appears to return before all cross-GPU work/state is actually quiescent.
Repro command
HIP_VISIBLE_DEVICES=0,1,2,3 \
MORI_SHMEM_MODE=ISOLATION \
MORI_SHMEM_HEAP_SIZE=8589934592 \
MORI_SOCKET_IFNAME=lo \
torchrun --standalone --nproc_per_node=4 repro_mori_ep_x4_repeat.py \
    --iters 2 --tokens 128
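For scale, here is a rough sizing sketch for the values in the command and script defaults. It only does arithmetic on the numbers shown in this report; `token_buffer_bytes` is a hypothetical helper, and how MoRI actually lays out its symmetric heap is not something I have verified.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2

# Value passed as MORI_SHMEM_HEAP_SIZE in the repro command.
heap_bytes = 8589934592

# bfloat16 is 2 bytes per element (matches torch.bfloat16.itemsize).
BF16_BYTES = 2


def token_buffer_bytes(max_tokens, hidden_size=4096, itemsize=BF16_BYTES):
    """Size of ONE max_tokens x hidden_size activation buffer (hypothetical helper).

    Assumption: this is a lower bound for a single buffer only; the real
    number and layout of MoRI's internal buffers is unknown to me.
    """
    return max_tokens * hidden_size * itemsize


print(heap_bytes / GIB)                  # 8.0   (8 GiB heap)
print(token_buffer_bytes(16384) / MIB)   # 128.0 (default --max-tokens)
print(token_buffer_bytes(128) / MIB)     # 1.0   (reduced --max-tokens control)
```

Either way a single buffer is far below the 8 GiB heap, so plain heap exhaustion seems unlikely from these numbers alone.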
Repro script
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import time

import mori
import torch
import torch.distributed as dist

def make_indices(num_tokens, *, rank, iteration, world_size, num_experts_per_rank, topk, device):
    # Deterministic routing: every (token, slot) pair maps to an expert id
    # in [0, total_experts), varying with rank and iteration.
    total_experts = world_size * num_experts_per_rank
    token_offsets = torch.arange(num_tokens, device=device, dtype=torch.int32)[:, None]
    expert_offsets = torch.arange(topk, device=device, dtype=torch.int32)[None, :]
    indices = (
        token_offsets * (topk + 3) + expert_offsets * 17 + rank * 29 + iteration * 31
    ) % total_experts
    return indices.contiguous()

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--iters", type=int, default=2)
    parser.add_argument("--tokens", type=int, default=128)
    parser.add_argument("--hidden-size", type=int, default=4096)
    parser.add_argument("--max-tokens", type=int, default=16384)
    parser.add_argument("--num-local-experts", type=int, default=64)
    parser.add_argument("--topk", type=int, default=8)
    parser.add_argument("--block-num", type=int, default=80)
    parser.add_argument("--warp-num-per-block", type=int, default=16)
    parser.add_argument("--skip-combine", action="store_true")
    parser.add_argument("--call-reset", action="store_true")
    parser.add_argument("--post-sleep", type=float, default=0.0)
    return parser.parse_args()

def main():
    args = parse_args()
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size == 4

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group("cpu:gloo,cuda:nccl", device_id=device)
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
    mori.shmem.shmem_torch_process_group_init("default")

    config = mori.ops.EpDispatchCombineConfig(
        data_type=torch.bfloat16,
        rank=rank,
        world_size=world_size,
        hidden_dim=args.hidden_size,
        scale_dim=0,
        scale_type_size=0,
        max_token_type_size=torch.bfloat16.itemsize,
        max_num_inp_token_per_rank=args.max_tokens,
        num_experts_per_rank=args.num_local_experts,
        num_experts_per_token=args.topk,
        warp_num_per_block=args.warp_num_per_block,
        block_num=args.block_num,
        use_external_inp_buf=True,
        kernel_type=mori.ops.EpDispatchCombineKernelType.IntraNode,
        gpu_per_node=world_size,
        rdma_block_num=0,
        num_qp_per_pe=2,
        quant_type="none",
    )
    op = mori.ops.EpDispatchCombineOp(config)

    generator = torch.Generator(device=device)
    generator.manual_seed(1000 + rank)

    torch.cuda.synchronize()
    dist.barrier()
    if rank == 0:
        print(
            f"start iters={args.iters} tokens={args.tokens} "
            f"hidden={args.hidden_size} max_tokens={args.max_tokens} topk={args.topk}"
        )

    for iteration in range(args.iters):
        hidden_states = torch.randn(
            args.tokens,
            args.hidden_size,
            device=device,
            dtype=torch.bfloat16,
            generator=generator,
        )
        topk_weights = torch.rand(
            args.tokens,
            args.topk,
            device=device,
            dtype=torch.float32,
            generator=generator,
        )
        topk_ids = make_indices(
            args.tokens,
            rank=rank,
            iteration=iteration,
            world_size=world_size,
            num_experts_per_rank=args.num_local_experts,
            topk=args.topk,
            device=device,
        )

        # dispatch -> local stream sync -> cross-rank barrier
        dispatch_output, dispatch_weights, _, dispatch_indices, dispatch_recv_count = (
            op.dispatch(hidden_states, topk_weights, None, topk_ids)
        )
        torch.cuda.synchronize()
        dist.barrier()

        received = int(dispatch_recv_count[0].item())
        if rank == 0:
            print(f"iter={iteration} rank0_received={received}")

        # combine -> local stream sync -> cross-rank barrier
        if not args.skip_combine:
            op.combine(
                dispatch_output,
                None,
                dispatch_indices,
                call_reset=args.call_reset,
            )
            torch.cuda.synchronize()
            dist.barrier()

        assert dispatch_weights.shape[1] == args.topk

    if rank == 0:
        print("completed")
    torch.cuda.synchronize()
    dist.barrier()

    # Optional settle time before teardown (see "Controls tried").
    if args.post_sleep > 0:
        if rank == 0:
            print(f"post_sleep={args.post_sleep}")
        time.sleep(args.post_sleep)
        torch.cuda.synchronize()
        dist.barrier()

    mori.shmem.shmem_finalize()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Actual result
With --iters 2 --tokens 128, the iterations complete and print sane counts.
On some runs one or more ranks abort immediately with a GPU memory fault. On a
fresh double-check, the script printed process group destroyed but the child
ranks stayed alive; after terminating the hung launcher, the same GPU memory
fault surfaced. Example output:
start iters=2 tokens=128 hidden=4096 max_tokens=16384 topk=8
iter=0 rank0_received=356
iter=1 rank0_received=356
completed
Memory access fault by GPU node-2 ... on address 0x7f6379005000. Reason: Unknown.
Memory access fault by GPU node-2 ... on address 0x7edc80805000. Reason: Unknown.
...
Root Cause:
rank: 0
exitcode: -6
traceback: Signal 6 (SIGABRT)
The failing address changes between runs. It is usually a high address. The
fault often appears after the final explicit torch.cuda.synchronize() /
dist.barrier() pair, during process teardown, or when terminating a hung
post-teardown run.
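To compare failing addresses across runs, a small log-scraping sketch like the following may help. It assumes only the fault line shape shown in the example output above; `fault_addresses` is a hypothetical helper name, and the elided `...` portion of the line is matched loosely rather than guessed.

```python
import re

# Matches lines shaped like:
#   Memory access fault by GPU node-2 ... on address 0x7f6379005000. Reason: Unknown.
# The middle of the line is matched non-greedily because it is elided in the report.
FAULT_RE = re.compile(
    r"Memory access fault by GPU (\S+).*?on address (0x[0-9a-fA-F]+)"
)


def fault_addresses(log_text):
    """Return a list of (gpu_node, address) pairs, one per fault line."""
    return [
        (match.group(1), int(match.group(2), 16))
        for match in FAULT_RE.finditer(log_text)
    ]


sample = """\
completed
Memory access fault by GPU node-2 ... on address 0x7f6379005000. Reason: Unknown.
Memory access fault by GPU node-2 ... on address 0x7edc80805000. Reason: Unknown.
"""

for node, address in fault_addresses(sample):
    # 4096 is used here as a generic page size for a quick alignment check.
    alignment = "page-aligned" if address % 4096 == 0 else "unaligned"
    print(node, hex(address), alignment)
```

Running this over the combined torchrun output of several runs gives a quick view of which GPU node reports the fault and whether the addresses share any pattern.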
Expected result
After each dispatch and combine has completed and all ranks have passed both
torch.cuda.synchronize() and dist.barrier(), the next iteration and process
teardown should not hit a GPU memory fault.
Controls tried
- HIP_LAUNCH_BLOCKING=1 ... --iters 10 --tokens 128: passes cleanly.
- --iters 1 --tokens 128: passes cleanly.
- --iters 3 --tokens 128 --skip-combine: passes cleanly. The receive count
accumulates across dispatches, which is expected because combine is normally
the path that clears/resets the internal receive count.
- --iters 2 --tokens 128 --call-reset: still faults.
- --iters 2 --tokens 128 --post-sleep 5: passes in my runs. This suggests
local stream sync may be returning before some cross-GPU work/state is fully
quiescent.
- Reducing --max-tokens from 16384 to 128: still faults, so this is not
only caused by the large preallocated receive buffer.
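The ordering that the --post-sleep control exercises can be written down as a small helper. This is a sketch of a diagnostic, not a fix: `quiesce` is a hypothetical name, and the sync and barrier steps are passed in as callables so the ordering can be checked without GPUs.

```python
import time


def quiesce(sync, barrier, settle_seconds=0.0, sleep=time.sleep):
    """Drain local work, align ranks, then optionally wait and repeat.

    sync:    callable that blocks until local device work is done
             (torch.cuda.synchronize in the repro script).
    barrier: callable that blocks until all ranks arrive
             (dist.barrier in the repro script).
    """
    sync()
    barrier()
    if settle_seconds > 0:
        # Extra wall-clock time for any cross-GPU traffic that the local
        # stream sync did not cover (this is the hypothesis being tested).
        sleep(settle_seconds)
        sync()
        barrier()
```

In the repro script this would be called as `quiesce(torch.cuda.synchronize, dist.barrier, settle_seconds=args.post_sleep)` just before `mori.shmem.shmem_finalize()`. If the fault disappears only when `settle_seconds` is nonzero, that is consistent with work still being in flight after the local sync returns.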
Notes
This was originally found through an SGLang MoE EP integration, but the script
above does not import SGLang, does not use AITER, and does not load any model.
It only exercises MoRI EpDispatchCombineOp IntraNode dispatch/combine.
The most suspicious area from the outside is IntraNode combine completion and
state reset/reuse, especially around totalRecvTokenNum clearing and whether
remote P2P/SHMEM writes are guaranteed complete when the local stream sync
returns.
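To sanity-check the printed receive counts without GPUs, the routing formula from make_indices can be replayed on the CPU. The sketch below rests on two assumptions that I have not confirmed against the MoRI source: experts are assigned to ranks contiguously (expert id // num_experts_per_rank), and a token is counted once per destination rank even if several of its top-k experts live there. Both helper names are hypothetical.

```python
def make_indices_cpu(num_tokens, rank, iteration, world_size, num_experts_per_rank, topk):
    """Pure-Python replay of the make_indices formula from the repro script."""
    total_experts = world_size * num_experts_per_rank
    return [
        [
            (t * (topk + 3) + e * 17 + rank * 29 + iteration * 31) % total_experts
            for e in range(topk)
        ]
        for t in range(num_tokens)
    ]


def expected_recv_count(dst_rank, iteration, *, num_tokens=128, world_size=4,
                        num_experts_per_rank=64, topk=8):
    """Count tokens that route at least one expert to dst_rank.

    ASSUMPTIONS (not verified against MoRI): contiguous expert-to-rank
    mapping, and one received copy per (token, destination rank).
    """
    count = 0
    for src_rank in range(world_size):
        rows = make_indices_cpu(
            num_tokens, src_rank, iteration, world_size, num_experts_per_rank, topk
        )
        for experts in rows:
            if any(e // num_experts_per_rank == dst_rank for e in experts):
                count += 1
    return count


for iteration in range(2):
    print(f"iter={iteration} expected_rank0_received={expected_recv_count(0, iteration)}")
```

If these assumptions hold, the printed values should line up with the rank0_received lines from the script, and a count that grows across iterations would point at the receive counter not being cleared.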
Operating System
see above
CPU
see above
GPU
see above
ROCm Version
see above
ROCm Component
No response
Steps to Reproduce
see above
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
see above
Additional Information
see above