EpDispatchCombineOp crashes with SIGSEGV / OOM on MI355X (sglang-0.5.9-rocm720-mi35x-mori-0227-2 container) #210

@sunway513

Description

Running the Mori EP dispatch/combine microbenchmark on MI355X fails with either an out-of-heap-memory error or a SIGSEGV, depending on the MORI_SHMEM_HEAP_SIZE setting.

Environment

  • Container: rocm/sgl-dev:sglang-0.5.9-rocm720-mi35x-mori-0227-2
  • GPU: 8x AMD Instinct MI355X
  • Mori: installed at /sgl-workspace/mori/python/mori/

Reproduction

#!/usr/bin/env python3
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import mori

# Set in the parent process so the spawned workers inherit it.
os.environ["MORI_SHMEM_HEAP_SIZE"] = "8G"

def worker(rank, world_size):
    # Rendezvous settings for the default env:// init method.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "29850"})
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl", rank=rank,
        world_size=world_size, device_id=device,
    )
    # Register the world group under the name "default" so that
    # shmem_torch_process_group_init can refer to it by name.
    world_group = torch.distributed.group.WORLD
    torch._C._distributed_c10d._register_process_group("default", world_group)
    mori.shmem.shmem_torch_process_group_init("default")

    # DeepSeek-V3 style config: 256 experts, hidden size 7168, top-8 routing.
    config = mori.ops.EpDispatchCombineConfig(
        data_type=torch.bfloat16, rank=rank, world_size=world_size,
        hidden_dim=7168, scale_dim=0,
        scale_type_size=torch.tensor([], dtype=torch.float8_e4m3fnuz).element_size(),
        max_token_type_size=2,
        max_num_inp_token_per_rank=4096,
        num_experts_per_rank=32,  # 256 experts / 8 GPUs
        num_experts_per_token=8,
    )
    op = mori.ops.EpDispatchCombineOp(config)

    # 128 tokens on this rank, each with 8 random expert ids in [0, 256).
    x = torch.randn(128, 7168, dtype=torch.bfloat16, device=device)
    topk_ids = torch.randint(0, 256, (128, 8), device=device, dtype=torch.int32)
    topk_weights = torch.randn(128, 8, dtype=torch.float32, device=device).softmax(dim=-1)
    scales = torch.ones(128, 1, dtype=torch.float32, device=device)

    # Argument order: input, weights, scales, indices.
    out = op.dispatch(x, topk_weights, scales, topk_ids)
    print(f"rank {rank}: dispatch ok")

    mori.shmem.shmem_finalize()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(8,), nprocs=8, join=True)
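
One detail of the reproduction inputs may be worth ruling out. torch.randint draws each expert id independently, so a token's row of 8 ids can contain the same expert twice, which a real top-k router would not produce. Whether repeated ids have anything to do with the crash is not established here. The sketch below uses only the Python standard library as a stand-in for the tensor code, and the helper name has_duplicates is made up for illustration. It estimates how often repeats occur and shows sampling without replacement as an alternative.

```python
import random

NUM_EXPERTS, TOP_K, NUM_TOKENS = 256, 8, 128

def has_duplicates(row):
    """True if any expert id appears more than once in the row."""
    return len(set(row)) != len(row)

rng = random.Random(0)

# Independent draws, analogous to torch.randint(0, 256, (128, 8)):
# the same expert id can appear twice in one token's row.
independent = [
    [rng.randrange(NUM_EXPERTS) for _ in range(TOP_K)]
    for _ in range(NUM_TOKENS)
]

# Sampling without replacement: every token gets TOP_K distinct expert ids.
distinct = [rng.sample(range(NUM_EXPERTS), TOP_K) for _ in range(NUM_TOKENS)]

# Probability that a single independently drawn row has no repeated id.
p_unique = 1.0
for i in range(TOP_K):
    p_unique *= (NUM_EXPERTS - i) / NUM_EXPERTS

print(f"P(row has a repeat) = {1 - p_unique:.3f}")
print("independent rows with repeats:", sum(has_duplicates(r) for r in independent))
print("distinct rows with repeats:   ", sum(has_duplicates(r) for r in distinct))
```

In the tensor code, one way to get distinct ids per token would be torch.rand(128, 256, device=device).argsort(dim=-1)[:, :8].to(torch.int32).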

Error

With MORI_SHMEM_HEAP_SIZE=4G:

[application] [error] Out of heap memory! Requested: 1879048192 bytes (aligned), Current heap size: 4294967296 bytes.
[shmem] [error] Out of static heap memory! Requested: 1879048192 bytes.
[dispatch_combine.cpp:79] hip failed with invalid argument

With MORI_SHMEM_HEAP_SIZE=8G:

process 5 terminated with signal SIGSEGV
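
The request size in the 4G log can be checked against the reproduction config with plain arithmetic. The snippet below is only an observation: the product of the config values happens to equal the logged request, but whether Mori actually sizes a buffer this way is an assumption, not something the log confirms. The factor 32 is also ambiguous, since it equals both num_experts_per_rank and world_size * 4.

```python
# Sizes taken from the 4G error log and the reproduction config in this issue.
GIB = 1024 ** 3
requested = 1_879_048_192   # "Requested: ... bytes (aligned)"
heap_4g = 4_294_967_296     # "Current heap size" with MORI_SHMEM_HEAP_SIZE=4G
heap_8g = 8 * GIB           # MORI_SHMEM_HEAP_SIZE=8G, same unit as the 4G case

max_tokens = 4096           # max_num_inp_token_per_rank
hidden_dim = 7168
elem_size = 2               # bytes per bfloat16 element
factor = 32                 # num_experts_per_rank, or equally world_size * 4

# Assumption: this product is a guess at how the request size arises.
candidate = max_tokens * hidden_dim * elem_size * factor

print(f"requested = {requested / GIB:.2f} GiB")
print(f"candidate = {candidate} bytes, matches request: {candidate == requested}")
print(f"whole requests that fit in 4 GiB: {heap_4g // requested}")
print(f"whole requests that fit in 8 GiB: {heap_8g // requested}")
```

A 4 GiB heap has room for at most two allocations of this size and an 8 GiB heap for at most four, before counting anything else placed on the heap.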

Notes

  • The examples/ops/dispatch_combine/test_dispatch_combine.py test script passes its correctness checks (token routing) but produces no performance benchmarking output.
  • The dispatch/combine API called above (op.dispatch(input, weights, scales, indices)) differs from the one exercised by the Python-level test, which goes through a different wrapper class. Documentation on the correct C++ binding API for benchmarking would be helpful.
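
Because the example test reports no timings, a small timing harness along the following lines could wrap the dispatch call once it runs. This is a minimal sketch and not part of Mori: time_op and its sync parameter are hypothetical names, and the demo times a CPU stand-in so the snippet runs without a GPU.

```python
import statistics
import time

def time_op(fn, sync=lambda: None, warmup=5, iters=20):
    """Return per-call latencies of fn() in microseconds.

    sync is called after each timed call so that asynchronously queued
    work (for example GPU kernels) is included in the measured interval.
    """
    for _ in range(warmup):
        fn()
    sync()  # drain warmup work before timing starts
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - start) * 1e6)
    return samples

if __name__ == "__main__":
    # CPU stand-in for the real operation.
    samples = time_op(lambda: sum(range(10_000)))
    print(f"median {statistics.median(samples):.1f} us, min {min(samples):.1f} us")
```

On a GPU, the op.dispatch(...) call from the reproduction would be wrapped in a lambda and torch.cuda.synchronize passed as sync.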

Expected

Mori EP dispatch/combine should run without a SIGSEGV at MORI_SHMEM_HEAP_SIZE=8G for the DeepSeek-V3 config (256 experts, hidden size 7168, top-8, 4096 tokens).
