2 changes: 1 addition & 1 deletion README.md
@@ -117,7 +117,7 @@ import sglang as sgl

sgl_engine = sgl.Engine(model_path="xxx", tp_size=2, random_seed=42)
awex_config = InferenceConfig.from_sgl_engine(sgl_engine, comm_backend="nccl")
# for sglang support, you must ensure https://github.com/sgl-project/sglang/pull/13595
# for sglang support, you must ensure https://github.com/sgl-project/sglang/pull/13595
# is included in your sglang version
inference_engine = SGLangEngine(awex_config, sgl_engine)
reader = WeightsReader(inference_engine)
3 changes: 0 additions & 3 deletions awex/config.py
@@ -36,12 +36,9 @@ class InferenceConfig:
ep_size: Optional[int] = None
enable_dp_attention: Optional[bool] = None
enable_dp_lm_head: Optional[bool] = None
enable_ep_moe: Optional[bool] = None
enable_deepep_moe: Optional[bool] = None
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = None
ep_num_redundant_experts: Optional[int] = None
enable_eplb: Optional[bool] = None
eplb_rebalance_num_iterations: Optional[int] = None
enable_memory_saver: Optional[bool] = None
moe_dense_tp_size: Optional[int] = None
n_share_experts_fusion: Optional[int] = None
2 changes: 0 additions & 2 deletions awex/converter/sglang_converter.py
@@ -35,7 +35,6 @@ def __init__(
self.total_kv_heads = model_config.num_key_value_heads
self.infer_engine_config = infer_engine_config
self.rank_info = rank_info
self.enable_ep_moe = infer_engine_config.enable_ep_moe
self.tp_size = infer_engine_config.tp_size
self.tp_rank = self.rank_info.tp_rank
self.ep_size = infer_engine_config.ep_size
@@ -205,7 +204,6 @@ def _convert_expert_moe_param(
self, name: str, parameter: torch.Tensor, layer_number: str
) -> List[Tuple[str, torch.Tensor]]:
"""Convert expert parameters from SGlang to HuggingFace format."""
assert self.enable_ep_moe, "EP mode must be enabled"
# w13_weight shape: num_experts_per_partition, 2 * intermediate_size, hidden_size
# w2_weight shape: num_experts_per_partition, hidden_size, intermediate_size
if "expert_bias" in name:
20 changes: 7 additions & 13 deletions awex/sharding/sglang_sharding.py
@@ -34,7 +34,7 @@ def get_sglang_sharding_strategy(
moe_dense_tp_size=infer_engine_config.moe_dense_tp_size,
tp_size=rank_info.tp_size,
ep_size=infer_engine_config.ep_size,
ep_tp_size=1,
ep_tp_size=rank_info.ep_tp_size,
rank_info=rank_info,
**kwargs,
)
@@ -50,21 +50,15 @@ def get_sglang_rank_info(model_context, engine_rank) -> RankInfo:
tp_size = model_context["tp_size"]
tp_rank = model_context["tp_rank"]
ep_size = infer_engine_config.ep_size
if (
infer_engine_config.enable_ep_moe
or infer_engine_config.enable_deepep_moe
or (
hasattr(infer_engine_config, "enable_pplx_moe")
and infer_engine_config.enable_pplx_moe
)
):
assert ep_size == tp_size, "ep_size must be equal to tp_size"
ep_rank = tp_rank
if ep_size > 1:
ep_tp_size = tp_size // ep_size
ep_tp_rank = tp_rank % ep_tp_size
ep_rank = tp_rank // ep_tp_size
else:
assert ep_size == 1, "ep_size must be 1"
ep_rank = 0
ep_tp_size = 1
ep_tp_rank = 0
ep_tp_size = 1
ep_tp_rank = 0
return RankInfo(
tp_rank=tp_rank,
tp_size=tp_size,
3 changes: 0 additions & 3 deletions awex/tests/test_meta_resolver.py
@@ -108,9 +108,6 @@ def __init__(self):
server_args.enable_dp_lm_head = False
server_args.moe_dense_tp_size = 1
server_args.ep_size = 1
server_args.enable_ep_moe = False
server_args.enable_deepep_moe = False
server_args.enable_pplx_moe = False

self.engine = MagicMock()
self.engine.server_args = server_args
45 changes: 25 additions & 20 deletions docs/README.md
@@ -23,40 +23,44 @@ The core functional modules of weight exchange consist mainly of 5 parts:
- **RDMA weight transmission**: Uses NUMA affinity and RDMA communication to build a globally load-balanced transfer plan for weight updates;

### (1) Unified Training-Inference Weight Conversion

Due to their different computational workloads, training and inference engines generally adopt different parallelism strategies: the Megatron training engine uses a 5D parallelism strategy, DeepSpeed/FSDP use ZeRO + DP data parallelism, while SGLang and vLLM inference engines mostly use DP + TP + EP. Additionally, different engines perform **fusion, transposition, and quantization** optimizations on weights after loading to adapt them to high-performance operators.

To eliminate these differences between engines for the subsequent weight exchange, Awex constructs a **unified weight conversion layer** that performs the following conversions:

- **Weight splitting**: Splits merged weights (such as FFN's gate/up) into independent weights, supporting cross-TP Resharding;
- **Weight name unification**: Converts all internal weights from all engines to the same namespace, establishing weight mapping relationships between training and inference engines;
- **Attention weight ReGroup**: On the training engine side, regroups and aligns QKV weights along the inference engine's TP/DPAttention parallelism strategy, avoiding shard explosion from fine-grained splitting;
- **Quantization, precision, and format conversion**: Automatically converts weights on the training side **according to the precision and format of the inference side**, and this low-precision conversion can also reduce the amount of transmitted data.

The entire weight conversion adaptation layer is implemented as a **pluggable structure**: the engine layer, the model weight conversion layer, and the sharding layer can each be fully customized to meet the needs of complex scenarios.
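
As a rough illustration of the splitting and name-unification steps, here is a minimal sketch; the `UNIFIED_NAME_MAP` layout and function names are illustrative assumptions, not Awex's actual API:

```python
import torch
from typing import Dict, Tuple

# Hypothetical mapping from an engine-internal fused weight name to unified names.
UNIFIED_NAME_MAP = {
    "model.layers.{i}.mlp.gate_up_proj.weight": (
        "layers.{i}.mlp.gate_proj.weight",
        "layers.{i}.mlp.up_proj.weight",
    ),
}

def split_gate_up(fused: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a fused FFN gate/up weight of shape [2 * intermediate, hidden] into
    independent gate and up weights so each can be resharded across TP."""
    assert fused.shape[0] % 2 == 0, "fused gate/up weight must have an even first dim"
    gate, up = torch.chunk(fused, 2, dim=0)
    return gate.contiguous(), up.contiguous()

def convert_layer(layer_idx: int, fused_weight: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Return unified-namespace weights for one layer's fused gate/up projection."""
    gate_name, up_name = (
        n.format(i=layer_idx)
        for n in UNIFIED_NAME_MAP["model.layers.{i}.mlp.gate_up_proj.weight"]
    )
    gate, up = split_gate_up(fused_weight)
    return {gate_name: gate, up_name: up}
```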

### (2) Global Weight Metadata Management

Each training and inference process needs to **be aware of the weight metadata of all training and inference processes globally** in order to construct the subsequent weight transfer plan. Awex also performs **consistency validation of weight metadata between training and inference** at this step. The main workflow is as follows:

- Each process in the training engine performs weight conversion and obtains metadata for the converted shards
- Through all_gather_object, each rank obtains global training shard metadata
- Rank0 on the training side serializes global metadata and reports it to Meta Server
- Inference instance 0 on the inference side performs similar work; other inference instances have identical metadata and don't need additional computation
- All training and inference processes obtain global metadata from MetaServer
- Training and inference engines each perform shard-level metadata consistency and compatibility validation
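
A minimal sketch of the gather-and-report steps using `torch.distributed`; the `meta_server_client` interface and the shape of the shard metadata are assumptions for illustration only:

```python
import pickle
import torch.distributed as dist

def publish_training_metadata(local_shard_meta, meta_server_client):
    """Gather per-rank shard metadata globally, then have rank 0 report it."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # Every rank contributes its converted-shard metadata (names, shapes, dtypes,
    # owning rank, shard offsets) and receives the metadata of all other ranks.
    dist.all_gather_object(gathered, local_shard_meta)

    if dist.get_rank() == 0:
        # Rank 0 serializes the global view and reports it to the Meta Server;
        # the inference side does the same from instance 0 only.
        meta_server_client.put("training/global_shard_meta", pickle.dumps(gathered))  # hypothetical client API
    dist.barrier()
    return gathered
```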

### (3) P2P Weight Transmission Execution Plan

After obtaining global weight metadata, Awex constructs a **deterministic point-to-point transmission plan** within each training and inference process.

**Core Strategy** (NCCL mode):

- For each replica of the same tensor shard, assign training shards to inference shards through Round Robin to ensure uniform pulling;
- For overlapping shard intervals, if perfectly aligned, directly map; otherwise, use two sends to different shards;
- Pre-filter shards related to the current process to avoid constructing a global plan (shards can reach tens of millions for trillion-parameter models);
- Ensure strict order consistency of NCCL send/recv;
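
The round-robin assignment idea can be sketched as follows (interval splitting and pre-filtering omitted; the data layout is illustrative):

```python
from itertools import cycle

def build_p2p_plan(train_shards, infer_shards):
    """Assign each inference shard a source training replica in round-robin order.

    train_shards: dict mapping tensor name -> list of training ranks holding a replica
    infer_shards: list of (tensor name, inference rank) pairs that need the tensor
    Returns a deterministic list of (src_train_rank, dst_infer_rank, tensor_name),
    sorted so every process derives the same send/recv order.
    """
    sources = {name: cycle(sorted(ranks)) for name, ranks in train_shards.items()}
    plan = []
    for name, dst_rank in sorted(infer_shards):
        src_rank = next(sources[name])  # round robin over replicas -> uniform pulling
        plan.append((src_rank, dst_rank, name))
    # A globally consistent sort keeps NCCL send/recv strictly ordered on both sides.
    return sorted(plan)
```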

RDMA is more flexible than NCCL and uses a separate transmission plan, which we will expand on in subsequent articles.

### (4) NCCL Weight Transmission

Awex supports two transmission modes: NCCL (NVIDIA Collective Communications Library) and RDMA (Remote Direct Memory Access). NCCL mode is easier to use, while RDMA mode is more flexible and offers higher performance.

NCCL transmission mode primarily uses NCCL's send/recv interface for weight transmission. There are some implementation differences in Awex for separated and co-located modes, which we will detail here.
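
For the separated deployment case, executing such a plan with plain `torch.distributed` P2P ops could look roughly like the sketch below; it only illustrates that both sides iterate the same globally sorted plan so sends and receives pair up deterministically:

```python
import torch.distributed as dist

def execute_plan(plan, my_rank, send_tensors, recv_buffers):
    """Issue the P2P operations this rank participates in, strictly in plan order."""
    ops = []
    for src, dst, name in plan:  # every rank walks the identical, globally sorted plan
        if my_rank == src:
            ops.append(dist.P2POp(dist.isend, send_tensors[name], dst))
        elif my_rank == dst:
            ops.append(dist.P2POp(dist.irecv, recv_buffers[name], src))
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()
```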
@@ -81,12 +85,13 @@ In this case, Awex uses **CUDA IPC to zero-copy map the training process's GPU m

In implementation, we have also made some **performance optimizations**:

- **Problem**: Each CUDA IPC Handle's Open/Close has significant overhead; MOE and other models may have thousands to tens of thousands of weight tensors per card requiring IPC serialization;
- **Solution**: Before IPC serialization, merge tensors by shape and dtype, reducing the count to dozens, greatly reducing CUDA IPC overhead;

(**Note**: CUDA IPC does not support CUDA virtual memory. Future plans include allocating additional physical memory space for weight merging and transmission when enabling virtual GPU memory in the training engine)
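
The merging step can be sketched as follows; the grouping and bookkeeping here are illustrative, and the actual IPC handle export path in Awex is not shown:

```python
from collections import defaultdict
import torch

def merge_for_ipc(named_tensors):
    """Pack many small CUDA tensors into a few contiguous buffers so that only
    one IPC handle per (dtype, shape) group needs to be opened on the peer side.

    named_tensors: dict mapping name -> CUDA tensor
    Returns (buffers, index) where index maps name -> (group key, offset, shape).
    """
    groups = defaultdict(list)
    for name, t in named_tensors.items():
        groups[(t.dtype, tuple(t.shape))].append((name, t))

    buffers, index = {}, {}
    for key, items in groups.items():
        # One large contiguous buffer per group instead of one handle per tensor.
        buffers[key] = torch.cat([t.reshape(-1) for _, t in items])
        offset = 0
        for name, t in items:
            index[name] = (key, offset, tuple(t.shape))
            offset += t.numel()
    return buffers, index
```

The receiving process then opens one handle per merged buffer and reconstructs per-tensor views from the recorded offsets, instead of paying one Open/Close per tensor.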

### (5) RDMA Weight Transmission

Although NCCL transmission mode can already significantly improve weight exchange performance, NCCL mode has two main limitations:

1. **NCCL versions on training and inference sides need to remain compatible**, otherwise NCCL transmission may hang, preventing independent updates and iterations of training and inference engines;
@@ -100,9 +105,9 @@ Considering these two reasons, we also developed an RDMA-based transmission impl

**RDMA Mode Advantages**:

- Removes NCCL version binding, supports independent iteration of training and inference engines
- More flexible transmission plan optimization space
- Supports dynamic scaling of inference instances
- Further performance improvement (1T model from 20 seconds to 6 seconds)

RDMA mode implementation will be open-sourced soon. Stay tuned.