
[Issue]: IntraNode EpDispatchCombine async repeated dispatch+combine can fault after local stream sync #342

@zx3xyy


Problem Description

Summary

I am seeing an async-only GPU memory fault with mori.ops.EpDispatchCombineOp
using the IntraNode kernel on 4 local ROCm GPUs. A minimal MoRI-only script
that repeats:

dispatch -> torch.cuda.synchronize() -> dist.barrier()
combine  -> torch.cuda.synchronize() -> dist.barrier()

can complete the iterations and then abort with a high-address GPU memory
fault. The same script passes with HIP_LAUNCH_BLOCKING=1.

This looks like either an IntraNode combine completion issue or a state-reuse
issue across repeated dispatch/combine calls: local stream synchronization
appears to return before all cross-GPU work and state are actually quiescent.
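To make the hypothesis concrete, here is a toy Python model. It is entirely hypothetical: `ToyDevice` and its methods are illustrative names, not MoRI or HIP APIs. It assumes that some peer-targeted writes are tracked only on the receiving side. Under that assumption, a local synchronize on every rank can return while the receiver's buffer still has pending writes.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyDevice:
    """Toy model of one GPU (hypothetical, not MoRI's implementation)."""

    name: str
    local_stream: list[str] = field(default_factory=list)    # work this rank launched
    inbound_writes: list[str] = field(default_factory=list)  # peer writes not yet landed

    def launch_remote_write(self, peer: ToyDevice, payload: str) -> None:
        # The kernel is recorded on this rank's stream, but in this model the
        # write it issues is tracked only by the peer's inbound queue.
        self.local_stream.append(f"kernel->{peer.name}")
        peer.inbound_writes.append(payload)

    def synchronize(self) -> None:
        # Like a local stream sync: waits only for work this rank launched.
        self.local_stream.clear()

    def wait_for_peer_completion(self) -> None:
        # Stands in for an explicit cross-rank completion signal (for example a
        # flag or counter the receiver polls); assumed, not a MoRI API.
        self.inbound_writes.clear()

    def quiescent(self) -> bool:
        return not self.local_stream and not self.inbound_writes


gpu0, gpu1 = ToyDevice("gpu0"), ToyDevice("gpu1")
gpu0.launch_remote_write(gpu1, "combine output")
for dev in (gpu0, gpu1):
    dev.synchronize()  # every rank syncs locally (the repro then also barriers)
print(gpu0.quiescent(), gpu1.quiescent())  # -> True False
gpu1.wait_for_peer_completion()
print(gpu1.quiescent())  # -> True
```

If the real system behaved like this model, reusing or freeing the receive buffer after only a local sync plus `dist.barrier()` would be unsafe. The receiver would need an explicit completion signal first.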

Repro command

HIP_VISIBLE_DEVICES=0,1,2,3 \
MORI_SHMEM_MODE=ISOLATION \
MORI_SHMEM_HEAP_SIZE=8589934592 \
MORI_SOCKET_IFNAME=lo \
torchrun --standalone --nproc_per_node=4 repro_mori_ep_x4_repeat.py \
  --iters 2 --tokens 128
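When running the controls below, it helps to build the repro command programmatically so that only the flags change between runs. The following is a minimal sketch. `build_repro_command` is a hypothetical helper, and the environment values are copied from the repro command.

```python
import shlex

# Environment copied from the repro command.
BASE_ENV = {
    "HIP_VISIBLE_DEVICES": "0,1,2,3",
    "MORI_SHMEM_MODE": "ISOLATION",
    "MORI_SHMEM_HEAP_SIZE": str(8 * 1024**3),  # 8589934592 bytes = 8 GiB
    "MORI_SOCKET_IFNAME": "lo",
}


def build_repro_command(iters=2, tokens=128, extra_args=(), launch_blocking=False):
    """Hypothetical helper: return (env_overrides, argv) for one repro run."""
    env = dict(BASE_ENV)
    if launch_blocking:
        env["HIP_LAUNCH_BLOCKING"] = "1"  # serializes kernel launches
    argv = [
        "torchrun", "--standalone", "--nproc_per_node=4",
        "repro_mori_ep_x4_repeat.py",
        "--iters", str(iters), "--tokens", str(tokens),
        *extra_args,
    ]
    return env, argv


env, argv = build_repro_command(extra_args=["--call-reset"])
print(" ".join(f"{k}={v}" for k, v in env.items()), shlex.join(argv))
```

To launch, pass the result to `subprocess.run(argv, env={**os.environ, **env})`.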

Repro script

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import time

import mori
import torch
import torch.distributed as dist


def make_indices(num_tokens, *, rank, iteration, world_size, num_experts_per_rank, topk, device):
    # Deterministic, rank- and iteration-dependent top-k expert ids in [0, total_experts).
    total_experts = world_size * num_experts_per_rank
    token_offsets = torch.arange(num_tokens, device=device, dtype=torch.int32)[:, None]
    expert_offsets = torch.arange(topk, device=device, dtype=torch.int32)[None, :]
    indices = (
        token_offsets * (topk + 3) + expert_offsets * 17 + rank * 29 + iteration * 31
    ) % total_experts
    return indices.contiguous()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--iters", type=int, default=2)
    parser.add_argument("--tokens", type=int, default=128)
    parser.add_argument("--hidden-size", type=int, default=4096)
    parser.add_argument("--max-tokens", type=int, default=16384)
    parser.add_argument("--num-local-experts", type=int, default=64)
    parser.add_argument("--topk", type=int, default=8)
    parser.add_argument("--block-num", type=int, default=80)
    parser.add_argument("--warp-num-per-block", type=int, default=16)
    parser.add_argument("--skip-combine", action="store_true")
    parser.add_argument("--call-reset", action="store_true")
    parser.add_argument("--post-sleep", type=float, default=0.0)
    return parser.parse_args()


def main():
    args = parse_args()
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size == 4

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group("cpu:gloo,cuda:nccl", device_id=device)
    # Register WORLD under the name that is passed to MoRI's shmem initialization.
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
    mori.shmem.shmem_torch_process_group_init("default")

    config = mori.ops.EpDispatchCombineConfig(
        data_type=torch.bfloat16,
        rank=rank,
        world_size=world_size,
        hidden_dim=args.hidden_size,
        scale_dim=0,
        scale_type_size=0,
        max_token_type_size=torch.bfloat16.itemsize,
        max_num_inp_token_per_rank=args.max_tokens,
        num_experts_per_rank=args.num_local_experts,
        num_experts_per_token=args.topk,
        warp_num_per_block=args.warp_num_per_block,
        block_num=args.block_num,
        use_external_inp_buf=True,
        kernel_type=mori.ops.EpDispatchCombineKernelType.IntraNode,
        gpu_per_node=world_size,
        rdma_block_num=0,
        num_qp_per_pe=2,
        quant_type="none",
    )
    op = mori.ops.EpDispatchCombineOp(config)

    generator = torch.Generator(device=device)
    generator.manual_seed(1000 + rank)

    torch.cuda.synchronize()
    dist.barrier()
    if rank == 0:
        print(
            f"start iters={args.iters} tokens={args.tokens} "
            f"hidden={args.hidden_size} max_tokens={args.max_tokens} topk={args.topk}"
        )

    for iteration in range(args.iters):
        hidden_states = torch.randn(
            args.tokens,
            args.hidden_size,
            device=device,
            dtype=torch.bfloat16,
            generator=generator,
        )
        topk_weights = torch.rand(
            args.tokens,
            args.topk,
            device=device,
            dtype=torch.float32,
            generator=generator,
        )
        topk_ids = make_indices(
            args.tokens,
            rank=rank,
            iteration=iteration,
            world_size=world_size,
            num_experts_per_rank=args.num_local_experts,
            topk=args.topk,
            device=device,
        )

        # Dispatch, then wait locally and across all ranks before reading the count.
        dispatch_output, dispatch_weights, _, dispatch_indices, dispatch_recv_count = (
            op.dispatch(hidden_states, topk_weights, None, topk_ids)
        )
        torch.cuda.synchronize()
        dist.barrier()

        received = int(dispatch_recv_count[0].item())
        if rank == 0:
            print(f"iter={iteration} rank0_received={received}")

        if not args.skip_combine:
            op.combine(
                dispatch_output,
                None,
                dispatch_indices,
                call_reset=args.call_reset,
            )
            torch.cuda.synchronize()
            dist.barrier()

        assert dispatch_weights.shape[1] == args.topk

    if rank == 0:
        print("completed")

    # Final local sync plus barrier before (optional) sleep and teardown.
    torch.cuda.synchronize()
    dist.barrier()
    if args.post_sleep > 0:
        if rank == 0:
            print(f"post_sleep={args.post_sleep}")
        time.sleep(args.post_sleep)
        torch.cuda.synchronize()
        dist.barrier()

    mori.shmem.shmem_finalize()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
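Because `make_indices` is deterministic, the number of tokens each rank should receive can be estimated on the CPU and compared against the printed `rank0_received`. The pure-Python sketch below relies on two assumptions that the issue does not confirm. First, expert `e` lives on rank `e // num_experts_per_rank`. Second, dispatch sends a token to a given rank only once, even when several of its top-k experts live there. `make_indices_py` and `tokens_per_dest_rank` are hypothetical names.

```python
from collections import Counter


def make_indices_py(num_tokens, *, rank, iteration, world_size, num_experts_per_rank, topk):
    """Pure-Python mirror of make_indices() from the repro script."""
    total_experts = world_size * num_experts_per_rank
    return [
        [(t * (topk + 3) + k * 17 + rank * 29 + iteration * 31) % total_experts
         for k in range(topk)]
        for t in range(num_tokens)
    ]


def tokens_per_dest_rank(world_size=4, num_experts_per_rank=64, topk=8, tokens=128, iteration=0):
    """Count (source token, destination rank) pairs under the two assumptions
    stated above: expert e lives on rank e // num_experts_per_rank, and a token
    is sent to each destination rank at most once."""
    counts = Counter()
    for src in range(world_size):
        ids = make_indices_py(tokens, rank=src, iteration=iteration, world_size=world_size,
                              num_experts_per_rank=num_experts_per_rank, topk=topk)
        for row in ids:
            assert len(set(row)) == topk, "duplicate expert within one token"
            for dest in {e // num_experts_per_rank for e in row}:
                counts[dest] += 1
    return counts


for it in (0, 1):
    print(f"iter={it}", dict(sorted(tokens_per_dest_rank(iteration=it).items())))
```

If the estimate for rank 0 disagrees with `rank0_received`, then either one of the assumptions is wrong or the receive count is not what it appears to be. Either outcome is useful when narrowing down the state-reuse question.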

Actual result

With --iters 2 --tokens 128, both iterations complete and print plausible
receive counts. On some runs, one or more ranks then abort immediately with a
GPU memory fault. On a fresh re-run, the script printed "process group
destroyed" but the child ranks stayed alive; after the hung launcher was
terminated, the same GPU memory fault surfaced. Example output:

start iters=2 tokens=128 hidden=4096 max_tokens=16384 topk=8
iter=0 rank0_received=356
iter=1 rank0_received=356
completed
Memory access fault by GPU node-2 ... on address 0x7f6379005000. Reason: Unknown.
Memory access fault by GPU node-2 ... on address 0x7edc80805000. Reason: Unknown.
...
Root Cause:
rank: 0
exitcode: -6
traceback: Signal 6 (SIGABRT)

The faulting address changes between runs and is usually a high address. The
fault often appears after the final explicit torch.cuda.synchronize() /
dist.barrier() pair: during process teardown, or when a run that hung after
teardown is terminated.
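Since the faulting address varies from run to run, collecting the addresses across many runs makes them easier to compare. The following minimal sketch uses a regex written against the two fault lines in the example output. It assumes that the part elided as `...` stays on one line, and `parse_faults` is a hypothetical helper name.

```python
import re

# Matches lines like those in the example output; the elided "..." part is
# skipped by the non-greedy ".*?" (which does not cross newlines).
FAULT_RE = re.compile(r"Memory access fault by GPU node-(\d+) .*?on address (0x[0-9a-fA-F]+)")


def parse_faults(log_text):
    """Return (gpu_node, address) pairs for every memory-fault line in a log."""
    return [(int(node), int(addr, 16)) for node, addr in FAULT_RE.findall(log_text)]


sample = (
    "completed\n"
    "Memory access fault by GPU node-2 ... on address 0x7f6379005000. Reason: Unknown.\n"
    "Memory access fault by GPU node-2 ... on address 0x7edc80805000. Reason: Unknown.\n"
)
for node, addr in parse_faults(sample):
    print(f"node={node} address={addr:#x}")
```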

Expected result

After each dispatch and combine has completed and all ranks have passed both
torch.cuda.synchronize() and dist.barrier(), the next iteration and process
teardown should not hit a GPU memory fault.

Controls tried

  • HIP_LAUNCH_BLOCKING=1 ... --iters 10 --tokens 128: passes cleanly.
  • --iters 1 --tokens 128: passes cleanly.
  • --iters 3 --tokens 128 --skip-combine: passes cleanly. The receive count
    accumulates across dispatches, which is expected because combine is normally
    the path that clears/resets the internal receive count.
  • --iters 2 --tokens 128 --call-reset: still faults.
  • --iters 2 --tokens 128 --post-sleep 5: passes in my runs. This suggests
    local stream sync may be returning before some cross-GPU work/state is fully
    quiescent.
  • Reducing --max-tokens from 16384 to 128: still faults, so the fault is not
    caused solely by the large preallocated receive buffer.
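Because the fault shows up only on some runs, each control is more informative when it is repeated and the outcomes are counted. The minimal harness below encodes the controls listed above. `run_once`, `tally`, and `CONFIGS` are hypothetical names. It also assumes that the `HIP_VISIBLE_DEVICES` and `MORI_*` variables from the repro command are already exported in the calling environment.

```python
import os
import subprocess
from collections import Counter


def run_once(env_overrides, argv, timeout=600):
    """Run one repro attempt and classify it as "pass", "fail", or "hang".

    A nonzero exit (e.g. the SIGABRT from the memory fault) counts as "fail";
    exceeding `timeout` seconds counts as "hang". subprocess.run kills the
    launcher on timeout, but torchrun's child ranks may need separate cleanup.
    """
    env = {**os.environ, **env_overrides}
    try:
        proc = subprocess.run(argv, env=env, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return "hang"
    return "pass" if proc.returncode == 0 else "fail"


def tally(configs, runs, runner=run_once):
    """Run each named (env_overrides, argv) config `runs` times and count outcomes."""
    return {
        name: Counter(runner(env, argv) for _ in range(runs))
        for name, (env, argv) in configs.items()
    }


BASE = ["torchrun", "--standalone", "--nproc_per_node=4",
        "repro_mori_ep_x4_repeat.py", "--tokens", "128"]
CONFIGS = {
    "baseline": ({}, BASE + ["--iters", "2"]),
    "launch_blocking": ({"HIP_LAUNCH_BLOCKING": "1"}, BASE + ["--iters", "10"]),
    "single_iter": ({}, BASE + ["--iters", "1"]),
    "skip_combine": ({}, BASE + ["--iters", "3", "--skip-combine"]),
    "call_reset": ({}, BASE + ["--iters", "2", "--call-reset"]),
    "post_sleep_5": ({}, BASE + ["--iters", "2", "--post-sleep", "5"]),
    "small_max_tokens": ({}, BASE + ["--iters", "2", "--max-tokens", "128"]),
}
```

Calling `tally(CONFIGS, runs=10)` on the 4-GPU machine gives per-control pass/fail/hang counts. The `runner` parameter lets the counting logic be exercised with a stub instead of real GPU runs.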

Notes

This was originally found through an SGLang MoE EP integration, but the script
above does not import SGLang, does not use AITER, and does not load any model.
It only exercises MoRI EpDispatchCombineOp IntraNode dispatch/combine.

From the outside, the most suspicious area is IntraNode combine completion and
state reset/reuse, especially the clearing of totalRecvTokenNum and whether
remote P2P/SHMEM writes are guaranteed to be complete when the local stream
sync returns.
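One way a reset-ordering problem could look, sketched as a toy model: a rank clears its receive counter before every peer's contribution for the current round has landed, so a late add leaks into the next round. This illustrates the suspicion above and does not describe MoRI internals. `ToyRecvCounter`, `peer_add`, and `reset` are hypothetical names, and the token counts are arbitrary illustrative values.

```python
class ToyRecvCounter:
    """Toy stand-in for a per-rank receive counter such as totalRecvTokenNum.
    Hypothetical: not MoRI's actual data structure."""

    def __init__(self, world_size):
        self.world_size = world_size
        self.total = 0    # tokens counted for the current round
        self.arrived = 0  # peers whose add for the current round has landed

    def peer_add(self, num_tokens):
        # Models one peer's remote atomic add during dispatch.
        self.total += num_tokens
        self.arrived += 1

    def reset(self, require_all_arrived=True):
        # Clearing is safe only after every peer's add for this round has landed;
        # otherwise a late add is carried into the next round's count.
        if require_all_arrived and self.arrived != self.world_size:
            raise RuntimeError(f"reset with {self.arrived}/{self.world_size} peers arrived")
        self.total = 0
        self.arrived = 0


counter = ToyRecvCounter(world_size=4)
for n in (10, 20, 30):                    # three peers' adds land in time
    counter.peer_add(n)
counter.reset(require_all_arrived=False)  # unguarded clear
counter.peer_add(40)                      # fourth peer's add from the old round lands late
print(counter.total)  # -> 40: a stale count carried into the next round
```

The guarded `reset()` refuses to clear until all peers have arrived. In a real kernel, that guard would correspond to some cross-rank completion signal, and whether MoRI's IntraNode path has such a guard at the right point is exactly the open question.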

Operating System

see above

CPU

see above

GPU

see above

ROCm Version

see above

ROCm Component

No response

Steps to Reproduce

see above

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

see above

Additional Information

see above
