Merged
12 changes: 8 additions & 4 deletions README.md
@@ -5,7 +5,6 @@
[![Python Versions](https://img.shields.io/pypi/pyversions/awex.svg?style=for-the-badge&logo=python)](https://pypi.org/project/awex/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=for-the-badge)](https://opensource.org/licenses/Apache-2.0)


**Awex** is a high-performance RL training-inference **weight synchronization** framework,
designed to enable **second-level parameter updates** from training to inference in RL workflows.
It minimizes iteration latency, ensuring rollout phases consistently use the latest model.
@@ -139,6 +138,7 @@ These scripts compare weight formats across Megatron, vLLM, and SGLang by
converting all parameters into HF-style names and then diffing tensors.

**Intended use** (for new model bring‑up):

- These scripts primarily validate Awex converter coverage. They help answer:
“Does the current converter support this new model, or do we need mapping fixes?”
- If your target stack is **Megatron → vLLM**, usually running
@@ -147,22 +147,25 @@ converting all parameters into HF-style names and then diffing tensors.
parity (or you’re adding SGLang support for a new model).

**GPU/NPU notes**

- All compare/verify scripts accept `--device-backend` (auto/cuda/npu/cpu), but
they are **CUDA-only today** because vLLM/SGLang backends require CUDA.
Use `--device-backend cuda` explicitly if auto-detection picks the wrong device.
- For NPU, use these scripts on CUDA to validate **converter coverage**, then
validate the **runtime weight update** path on NPU with the integration tests.

### Naming normalization (why `self_attn.qkv_proj` becomes `attention.query_key_value_proj`)

Awex normalizes parameter names from different backends into a single canonical
HF-style naming scheme so Megatron, vLLM, and SGLang can be compared directly.
There are three “namespaces” involved:

1) **Megatron (mcore) names** – e.g. `decoder.layers.0.self_attention.linear_qkv.weight`
2) **vLLM/SGLang names** – e.g. `model.layers.0.self_attn.qkv_proj.weight`
3) **Awex canonical HF-style names** – e.g. `model.layers.0.attention.query_key_value_proj.weight`
1. **Megatron (mcore) names** – e.g. `decoder.layers.0.self_attention.linear_qkv.weight`
2. **vLLM/SGLang names** – e.g. `model.layers.0.self_attn.qkv_proj.weight`
3. **Awex canonical HF-style names** – e.g. `model.layers.0.attention.query_key_value_proj.weight`

Example for **QKV** conversion:

- Megatron `self_attention.linear_qkv.weight`
→ (mcore converter) `self_attn.qkv_proj.weight`
→ (normalize) `attention.query_key_value_proj.weight`
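
Below is a minimal sketch of this normalization idea, assuming a toy regex rename table; the helper and patterns are illustrative only and are not the actual Awex converter code:

```python
import re

# Hypothetical rename table: backend-specific patterns -> canonical HF-style names.
# The real converters also handle weight fusion, sharding, and per-model special cases.
_RENAMES = [
    (r"^decoder\.layers\.(\d+)\.self_attention\.linear_qkv\.weight$",   # Megatron (mcore)
     r"model.layers.\1.attention.query_key_value_proj.weight"),
    (r"^model\.layers\.(\d+)\.self_attn\.qkv_proj\.weight$",            # vLLM / SGLang
     r"model.layers.\1.attention.query_key_value_proj.weight"),
]


def to_canonical(name: str) -> str:
    """Map a backend-specific parameter name to the canonical HF-style name."""
    for pattern, replacement in _RENAMES:
        if re.match(pattern, name):
            return re.sub(pattern, replacement, name)
    return name  # already canonical, or outside this toy table


assert to_canonical("decoder.layers.0.self_attention.linear_qkv.weight") == (
    "model.layers.0.attention.query_key_value_proj.weight"
)
assert to_canonical("model.layers.0.self_attn.qkv_proj.weight") == (
    "model.layers.0.attention.query_key_value_proj.weight"
)
```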
@@ -283,6 +286,7 @@ Ascend and **vllm-ascend** on the inference side.
should remain disabled in normal runs.

**What is NOT NPU-ready yet**

- `compare_megatron_vllm_weights.py`, `verify_weight_conversion.py`, and
`compare_vllm_sglang_weights.py` are **CUDA-only** (they rely on vLLM CUDA
kernels and torch.cuda).
20 changes: 16 additions & 4 deletions awex/config.py
@@ -21,6 +21,7 @@

class InferenceConfigValidationError(ValueError):
"""Raised when InferenceConfig contains invalid or inconsistent settings."""

pass


@@ -153,7 +154,10 @@ def validate(self) -> None:
)

# dump_weights_dir_for_validation is required when dump_weights_list_for_validation is non-empty
if self.dump_weights_list_for_validation and not self.dump_weights_dir_for_validation:
if (
self.dump_weights_list_for_validation
and not self.dump_weights_dir_for_validation
):
errors.append(
"dump_weights_dir_for_validation must be set when "
"dump_weights_list_for_validation is non-empty"
@@ -168,14 +172,20 @@
errors.append("enable_eplb requires ep_size to be set")

# non-file comm_backend requires meta_server_addr for multi-engine setups
if self.num_engines > 1 and self.comm_backend != "file" and not self.meta_server_addr:
if (
self.num_engines > 1
and self.comm_backend != "file"
and not self.meta_server_addr
):
errors.append(
f"meta_server_addr must be set when num_engines > 1 "
f"and comm_backend is {self.comm_backend!r}"
)

if errors:
msg = "InferenceConfig validation failed:\n" + "\n".join(f" - {e}" for e in errors)
msg = "InferenceConfig validation failed:\n" + "\n".join(
f" - {e}" for e in errors
)
raise InferenceConfigValidationError(msg)

def validated(self) -> "InferenceConfig":
@@ -189,7 +199,9 @@ def validated(self) -> "InferenceConfig":
return self

@staticmethod
def from_dict(config_dict: Dict[str, Any], validate: bool = True) -> "InferenceConfig":
def from_dict(
config_dict: Dict[str, Any], validate: bool = True
) -> "InferenceConfig":
# remove all keys that are not fields of InferenceConfig
config_dict = {
k: v
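For context, a hedged usage sketch of the validation path changed above; the field values are illustrative, the non-`file` backend name is a guess, and `InferenceConfig` may require other fields not visible in this diff:

```python
from awex.config import InferenceConfig, InferenceConfigValidationError

# Illustrative settings chosen to trip the new multi-engine check:
# num_engines > 1 with a non-"file" comm_backend but no meta_server_addr.
cfg_dict = {
    "num_engines": 2,
    "comm_backend": "nccl",  # hypothetical value; only "file" is confirmed by this diff
}

try:
    cfg = InferenceConfig.from_dict(cfg_dict, validate=True)
except InferenceConfigValidationError as exc:
    # The aggregated message lists every failed rule, one "- ..." line per error.
    print(exc)
```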
2 changes: 1 addition & 1 deletion awex/models/registry.py
@@ -104,7 +104,7 @@ def import_model_configs():

def get_sharding_strategy(model_name: str):
config = ModelRegistry.get_model_config(model_name)
return config.sharding_strategy
return _get_config_value(config, "sharding_strategy", ShardingStrategy)


def _resolve_converter(
8 changes: 7 additions & 1 deletion awex/tests/test_weights_writer.py
@@ -306,7 +306,9 @@ def cleanup_and_exit(signum, frame):
# Build recv operations in round-robin order to match the sender's round-robin pattern
# The sender uses nccl_build_send_ops which interleaves operations across ranks
all_ranks = list(transfer_plan.operations.keys()) # Preserve plan's order
p2p_ops = nccl_build_recv_ops(parameters, transfer_plan, weights_update_group)
p2p_ops, non_contiguous_tensor_pairs, _recv_traj = nccl_build_recv_ops(
parameters, transfer_plan, weights_update_group
)
logger.info(
f"Reader (rank 0): Building recv operations from sender ranks: {all_ranks}"
)
@@ -329,6 +331,10 @@
# scheduling and stream assignment logic as the production reader.
logger.info(f"Test reader: Executing {len(p2p_ops)} recv ops via batch_send_recv")
batch_send_recv(send_ops=[], recv_ops=p2p_ops, blocking=True, use_group=True)
if non_contiguous_tensor_pairs:
with torch.no_grad():
for original_tensor, recv_tensor in non_contiguous_tensor_pairs:
original_tensor.copy_(recv_tensor)
logger.info("All recv operations completed, synchronizing CUDA")
device_util.synchronize(device_util.current_device())
logger.info("Finished receiving weights")
5 changes: 4 additions & 1 deletion awex/util/common.py
@@ -246,7 +246,10 @@ def check_train_infer_params_meta(
logger.error(error_msg)
infer_tp_size = len(infer_param_meta.replicas[0].shards)
train_tp_size = len(train_param_meta.replicas[0].shards)
max_tp, min_tp = max(infer_tp_size, train_tp_size), min(infer_tp_size, train_tp_size)
max_tp, min_tp = (
max(infer_tp_size, train_tp_size),
min(infer_tp_size, train_tp_size),
)
if max_tp % min_tp != 0:
error_msg = (
f"Inference for parameter {param_name} has wrong tp_size: "
8 changes: 6 additions & 2 deletions awex/util/process_group.py
@@ -16,9 +16,11 @@
# under the License.

import os
from importlib.metadata import version

import torch
import torch.distributed as dist
from packaging.version import Version

from awex import logging
from awex.util import device as device_util
@@ -78,7 +80,9 @@ def init_custom_process_group(
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
# We need to determine the appropriate parameter name based on PyTorch version
pg_options_param_name = (
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
"backend_options"
if Version(version("torch")) >= Version("2.6")
Review comment (medium): While using packaging.version.Version correctly fixes the string comparison bug, using importlib.metadata.version("torch") is less direct and potentially less robust than using the __version__ attribute already available on the imported torch module. The metadata query can fail in certain environments (e.g., non-standard installations) even if the module is successfully loaded.

Suggested change:
-if Version(version("torch")) >= Version("2.6")
+if Version(torch.__version__) >= Version("2.6")

else "pg_options"
)
pg, _ = _new_process_group_helper(
world_size,
@@ -157,7 +161,7 @@ def create_pair_subgroups_from_parent(parent_group, world_size):
# Determine the appropriate parameter name based on PyTorch version
pg_options_param_name = (
"backend_options"
if str(torch.__version__) >= "2.6"
if Version(version("torch")) >= Version("2.6")
else "pg_options"
)

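For context on this change and the review comment above, a quick standalone check of why plain string comparison of version numbers is wrong while `packaging.version` is not:

```python
from packaging.version import Version

# Lexicographic string comparison mis-orders versions: '1' < '6', so "2.10.0" < "2.6".
assert not ("2.10.0" >= "2.6")

# packaging.version compares release components numerically, which is what we want.
assert Version("2.10.0") >= Version("2.6")
assert Version("2.5.1") < Version("2.6")
```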
58 changes: 55 additions & 3 deletions awex/vllm_plugin.py
@@ -20,16 +20,31 @@
import os
from typing import Any

from fastapi import Request
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from vllm.entrypoints.openai.api_server import router
from vllm.entrypoints.openai.protocol import OpenAIBaseModel

from awex.config import InferenceConfig
from awex.vllm_awex_adapter import AwexVLLMServerAdapter

logger = logging.getLogger(__name__)

# Newer vLLM moved OpenAIBaseModel and removed the shared module-level router.
# Try new paths first, fall back to legacy.
try:
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
except ImportError:
from vllm.entrypoints.openai.protocol import OpenAIBaseModel

try:
from vllm.entrypoints.openai.api_server import router # type: ignore[attr-defined]

_USING_LEGACY_VLLM_ROUTER = True
except ImportError:
router = APIRouter()
_USING_LEGACY_VLLM_ROUTER = False

_awex_build_app_patched = False

_awex_plugin_registered = False
_AWEX_WORKER_METHODS = {
"_get_model_param_info": (
@@ -395,6 +410,42 @@ def flush_cache(self):
return True


def _ensure_router_attached() -> None:
"""Attach ``router`` to vLLM's FastAPI app on newer vLLM releases.

Legacy vLLM picked up our routes automatically because we registered them
on the shared ``vllm.entrypoints.openai.api_server.router``. Newer vLLM
removed that shared router, so we patch ``build_app`` to include our local
router on every FastAPI app it constructs.
"""
global _awex_build_app_patched
if _USING_LEGACY_VLLM_ROUTER or _awex_build_app_patched:
return
try:
from vllm.entrypoints.openai import api_server as _api_server_module
except ImportError as exc:
logger.warning("Cannot patch vLLM build_app for Awex routes: %s", exc)
return
original_build_app = getattr(_api_server_module, "build_app", None)
if original_build_app is None:
logger.warning(
"vLLM api_server has no build_app; Awex routes will not be attached."
)
return

def _awex_build_app(*args, **kwargs):
app = original_build_app(*args, **kwargs)
try:
app.include_router(router)
logger.info("Attached Awex router to vLLM FastAPI app.")
except Exception as exc:
logger.exception("Failed to attach Awex router to FastAPI app: %s", exc)
return app

_api_server_module.build_app = _awex_build_app
_awex_build_app_patched = True


def register_awex_plugin() -> None:
"""Register Awex endpoints and worker patches for vLLM."""
global _awex_plugin_registered
@@ -403,6 +454,7 @@ def register_awex_plugin() -> None:
_awex_plugin_registered = True

_patch_awex_worker()
_ensure_router_attached()

@router.post("/areal_awex_init")
async def awex_init(request: AwexInitRequest, raw_request: Request):
Review comment on lines 459 to 460 (security, high): The new endpoints /areal_awex_init and /areal_awex_update are registered without any explicit authentication or authorization dependencies. Since these endpoints can trigger significant state changes (like re-initializing the NCCL group or updating model weights), they could be exploited if the vLLM server is exposed. Consider ensuring these routes are protected by the same security mechanisms (e.g., API key checks) used for the standard OpenAI-compatible endpoints.
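
A possible hardening sketch along those lines, assuming a shared-secret header checked by a FastAPI dependency; the environment variable and header names are hypothetical, and vLLM's own API-key machinery may be the better hook:

```python
import os
from typing import Optional

from fastapi import Depends, Header, HTTPException


def _require_awex_token(x_awex_token: Optional[str] = Header(default=None)) -> None:
    """Reject requests that do not carry the expected shared secret."""
    expected = os.environ.get("AWEX_API_TOKEN")  # hypothetical env var
    if expected and x_awex_token != expected:
        raise HTTPException(status_code=401, detail="invalid or missing Awex token")


# Illustrative registration; the real routes are declared inside register_awex_plugin():
# @router.post("/areal_awex_init", dependencies=[Depends(_require_awex_token)])
```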

1 change: 0 additions & 1 deletion awex/writer/nccl_writer.py
@@ -122,7 +122,6 @@ def _initialize(self):
)

def _shake_hands_with_reader(self):

if self.transfer_rank == self.transfer_world_size - 1:
logger.info(
f"Start to test NCCL ready for rank {self.transfer_rank}, world size {self.transfer_world_size}"