Merged · 31 commits
16 changes: 15 additions & 1 deletion areal/api/cli_args.py
@@ -1291,6 +1291,9 @@ class SGLangConfig:
# and passed as `model_loader_extra_config` to SGLang.
enable_multithread_load: bool = False

# Internal field, not exposed to users.
enable_return_routed_experts: bool = False

# Use staticmethod to make OmegaConf happy.
@staticmethod
def build_cmd(
@@ -1555,6 +1558,12 @@ class InferenceEngineConfig:
"help": "OpenAI proxy configuration (used when workflow is an agent workflow)."
},
)
return_routed_experts: bool = field(
default=False,
metadata={
"help": "Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models."
},
)

def __post_init__(self):
"""Validate scheduling_spec length."""
@@ -1675,7 +1684,12 @@ class SwanlabConfig:
config: dict | None = None
logdir: str | None = None
mode: str | None = "disabled"
api_key: str | None = os.getenv("SWANLAB_API_KEY", None)
# set None to prevent info-leak in docs
api_key: str | None = None

def __post_init__(self):
if self.api_key is None:
self.api_key = os.getenv("SWANLAB_API_KEY")


@dataclass
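A standalone sketch (hypothetical SwanlabConfigSketch, not the real class) of why the api_key default above moved into __post_init__: a default computed with os.getenv at class-definition time is what documentation generators print, whereas a None default defers the environment lookup to instance construction.

import os
from dataclasses import dataclass, fields


@dataclass
class SwanlabConfigSketch:
    # Default stays None so tools that print field defaults never see a real key.
    api_key: str | None = None

    def __post_init__(self):
        # The environment is consulted only when an instance is built.
        if self.api_key is None:
            self.api_key = os.getenv("SWANLAB_API_KEY")


# A docs generator inspecting field defaults now reports None instead of the key.
print({f.name: f.default for f in fields(SwanlabConfigSketch)})  # {'api_key': None}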
4 changes: 4 additions & 0 deletions areal/api/io_struct.py
@@ -77,6 +77,9 @@ class ModelResponse:
ttft: float = float("inf") # Time to first token
itl: list[float] = field(default_factory=list) # List of inter-token latencies

# MoE routing (only populated when return_routed_experts=True)
routed_experts: np.ndarray | None = None

@property
def input_len(self) -> int:
return len(self.input_tokens)
@@ -271,6 +274,7 @@ class HttpGenerationResult:
output_tokens: list[int]
output_logprobs: list[float]
stop_reason: str
routed_experts: np.ndarray | None = None


@dataclass
20 changes: 19 additions & 1 deletion areal/engine/sglang_remote.py
@@ -6,6 +8,8 @@
from concurrent.futures import Future
from typing import Any

import numpy as np
import pybase64
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.cli_args import InferenceEngineConfig, PerfTracerConfig, SGLangConfig
@@ -66,6 +68,9 @@ def build_generation_request(
"stream": False,
}

# Add return_routed_experts to payload if set
if req.metadata.get("return_routed_experts", False):
payload["return_routed_experts"] = True
# Add LoRA if initialized
if with_lora:
lora_name = gconfig.lora_name
@@ -85,11 +90,24 @@ def parse_generation_response(
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
stop_message = finish_reason.get("message", "")

# Extract routed_experts information if available
routed_experts = meta_info.get("routed_experts", None)
if routed_experts is not None:
num_sgl_token = (
meta_info["prompt_tokens"] + meta_info["completion_tokens"] - 1
)
Comment on lines +97 to +99 (Contributor, medium; marked as resolved):
For better maintainability, could you add a brief comment explaining why 1 is subtracted in the num_sgl_token calculation? This would clarify the logic for future developers who might not be familiar with the specifics of the sglang API's token counting.
# Extract expert_id and reshape to (num_sgl_token, num_layers*expert_top_k)
routed_experts = np.frombuffer(
pybase64.b64decode(routed_experts.encode("utf-8")), dtype=np.int32
).reshape(num_sgl_token, -1)

if stop_reason == "abort" and stop_message.startswith("Abort before prefill"):
return HttpGenerationResult(
output_tokens=[],
output_logprobs=[],
stop_reason=stop_reason,
routed_experts=routed_experts,
)

output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
@@ -99,6 +117,7 @@
output_tokens=output_tokens,
output_logprobs=output_logprobs,
stop_reason=stop_reason,
routed_experts=routed_experts,
)

def build_disk_weight_update_requests(
@@ -211,7 +230,6 @@ def get_onload_request(self, tags: list[str] | None = None) -> HttpRequest:
def launch_server(self, server_args: dict[str, Any]) -> subprocess.Popen:
"""Launch SGLang server subprocess."""
cmd = SGLangConfig.build_cmd_from_args(server_args)

_env = os.environ.copy()
triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH)
_env["TRITON_CACHE_PATH"] = os.path.join(triton_cache_path, str(uuid.uuid4()))
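To make the payload format above concrete, here is a standalone round-trip of the decoding done in parse_generation_response: the server returns the routed expert indices as a base64-encoded int32 buffer covering prompt_tokens + completion_tokens - 1 tokens, reshaped to (num_sgl_token, num_layers * expert_top_k). The layer and top-k sizes below are illustrative only, not taken from any particular model.

import numpy as np
import pybase64

# Illustrative sizes; real values depend on the MoE model being served.
prompt_tokens, completion_tokens = 5, 3
num_layers, expert_top_k = 2, 4
num_sgl_token = prompt_tokens + completion_tokens - 1

# Stand-in for what the server would place in meta_info["routed_experts"].
raw = np.arange(num_sgl_token * num_layers * expert_top_k, dtype=np.int32)
encoded = pybase64.b64encode(raw.tobytes()).decode("utf-8")

# Decoding path mirroring parse_generation_response above.
routed_experts = np.frombuffer(
    pybase64.b64decode(encoded.encode("utf-8")), dtype=np.int32
).reshape(num_sgl_token, -1)
assert routed_experts.shape == (num_sgl_token, num_layers * expert_top_k)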
3 changes: 3 additions & 0 deletions areal/infra/launcher/sglang_server.py
@@ -223,6 +223,9 @@ def launch_sglang_server(argv):
# Get CPU per GPU from rollout scheduling spec
rollout_spec = get_scheduling_spec(config.rollout)

if config.rollout.return_routed_experts:
config.sglang.enable_return_routed_experts = True

sglang_server = SGLangServerWrapper(
config.experiment_name,
config.trial_name,
27 changes: 27 additions & 0 deletions areal/infra/remote_inf_engine.py
@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, Protocol

import aiohttp
import numpy as np
import ray
import requests
import torch.distributed as dist
@@ -716,6 +717,10 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
# we are going to modify it in-place
req = req.copy()

# Populate return_routed_experts from config to metadata
if self.config.return_routed_experts:
req.metadata["return_routed_experts"] = True

# Validate n_samples
gconfig = req.gconfig
if gconfig.n_samples != 1:
@@ -743,6 +748,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
accumulated_output_tokens = []
accumulated_output_logprobs = []
accumulated_versions = []
accumulated_routed_experts: list[np.ndarray] = []

# A single "rid" shares the same server to allow KV cache reuse
if req.rid in self.rid_to_address:
@@ -799,12 +805,26 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
gen_result = self.backend.parse_generation_response(result)
stop_reason = gen_result.stop_reason

if (
req.metadata.get("return_routed_experts", False)
and gen_result.routed_experts is None
):
if stop_reason != "abort": # Only validate for successful generations
raise RuntimeError(
"Requested return_routed_experts=True but received None from SGLang. "
"This usually means the model is not a MoE (Mixture of Experts) model. "
"Please use a MoE model to get routed_experts information."
)

# Update accumulated outputs
accumulated_output_tokens.extend(gen_result.output_tokens)
accumulated_output_logprobs.extend(gen_result.output_logprobs)
accumulated_versions.extend(
[self.get_version()] * len(gen_result.output_tokens)
)
# Accumulate routed_experts for MoE models
if gen_result.routed_experts is not None:
accumulated_routed_experts.append(gen_result.routed_experts)

# Update request for next iteration
req.input_ids += gen_result.output_tokens
@@ -824,6 +844,12 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:

latency = time.perf_counter() - start_time

accumulated_routed_experts = (
np.concatenate(accumulated_routed_experts)
if accumulated_routed_experts
else None
)

response = ModelResponse(
input_tokens=req.input_ids[
: len(req.input_ids) - len(accumulated_output_tokens)
@@ -837,6 +863,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
ttft=latency, # Simplified for non-streaming
tokenizer=req.tokenizer,
processor=req.processor,
routed_experts=accumulated_routed_experts,
)
return response

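A small standalone sketch (made-up shapes) of the accumulation step in agenerate above: each generation round contributes one (tokens, num_layers * expert_top_k) chunk, the chunks are concatenated along the token axis, and the result (or None for non-MoE models and aborted requests) is attached to ModelResponse.routed_experts.

import numpy as np

# Illustrative per-round chunks mirroring accumulated_routed_experts.
chunks = [
    np.zeros((7, 8), dtype=np.int32),  # first round: 5 prompt + 3 output - 1 tokens
    np.ones((4, 8), dtype=np.int32),   # a later round after the request was extended
]

routed_experts = np.concatenate(chunks) if chunks else None
assert routed_experts is not None and routed_experts.shape == (11, 8)

# Downstream consumers read this from ModelResponse.routed_experts; it stays None
# when the model is not MoE or the request was aborted before prefill.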
20 changes: 20 additions & 0 deletions areal/trainer/rl_trainer.py
@@ -109,6 +109,9 @@ def __init__(
# Parse allocation mode.
self.allocation_mode = AllocationMode.from_str(config.allocation_mode)

# Validate config before proceeding with weight initialization
self._validate_cfg()

self._amend_xccl_weight_update_envvar()

# Create models: actor, critic, etc.
@@ -652,6 +655,8 @@ def _init_rollout(

# Determine engine class and server args based on backend
if self.allocation_mode.gen_backend == "sglang":
if self.config.rollout.return_routed_experts:
self.config.sglang.enable_return_routed_experts = True
if lora_path is not None and self.config.actor.use_lora:
self.config.sglang.lora_paths = [
f"{self.config.gconfig.lora_name}-v0={lora_path}"
Expand All @@ -663,6 +668,10 @@ def _init_rollout(
base_gpu_id=0,
)
elif self.allocation_mode.gen_backend == "vllm":
if self.config.rollout.return_routed_experts:
raise ValueError(
"return_routed_experts is not supported with vLLM backend. Please disable return_routed_experts or switch to SGLang backend."
)
Comment on lines +671 to +674 (Contributor, medium):
This validation check is redundant because the _validate_cfg method, which is called earlier in __init__, already performs the same check. To avoid code duplication, this if block can be removed.

if lora_path is not None and self.config.actor.use_lora:
self.config.vllm.lora_modules = [
f"{self.config.gconfig.lora_name}-v0={lora_path}"
@@ -842,6 +851,17 @@ def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int
dist.barrier(group=self.actor.cpu_group)
current_platform.synchronize()

def _validate_cfg(self):
"""validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models."""
if (
self.allocation_mode.gen_backend == "vllm"
and self.config.rollout.return_routed_experts
):
raise ValueError(
"return_routed_experts is only supported with SGLang backend. "
"Please disable return_routed_experts or switch to SGLang backend."
)

def _requires_proxy_workflow(self, workflow: WorkflowLike | None) -> bool:
"""Check if workflow requires proxy workers (i.e., not a RolloutWorkflow).

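To summarize the wiring added across the files above, here is a condensed, hypothetical sketch (shortened names, not the real classes) of how the user-facing flag propagates and where the fail-fast check sits: rollout.return_routed_experts is copied onto the SGLang server config and into per-request metadata, while a vLLM backend is rejected before any workers are spawned.

from dataclasses import dataclass


@dataclass
class RolloutCfg:
    return_routed_experts: bool = False          # user-facing (InferenceEngineConfig)


@dataclass
class SGLangCfg:
    enable_return_routed_experts: bool = False   # internal, passed to the server


def validate(gen_backend: str, rollout: RolloutCfg) -> None:
    # Mirrors RLTrainer._validate_cfg: fail before spawning workers or loading weights.
    if gen_backend == "vllm" and rollout.return_routed_experts:
        raise ValueError("return_routed_experts is only supported with SGLang backend.")


def wire(rollout: RolloutCfg, sglang: SGLangCfg) -> dict:
    # Trainer/launcher copy the flag onto the SGLang config; agenerate copies it
    # into request metadata, which becomes the HTTP payload field.
    metadata = {}
    if rollout.return_routed_experts:
        sglang.enable_return_routed_experts = True
        metadata["return_routed_experts"] = True
    return metadata


rollout, sglang = RolloutCfg(return_routed_experts=True), SGLangCfg()
validate("sglang", rollout)
metadata = wire(rollout, sglang)
print(sglang.enable_return_routed_experts, metadata)  # True {'return_routed_experts': True}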
2 changes: 2 additions & 0 deletions docs/cli_reference.md
@@ -508,6 +508,7 @@ Configuration for inference servers, including offpolicyness control.
| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the RolloutController. |
| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. |
| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). |
| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. |

(section-sg-lang)=

@@ -572,6 +573,7 @@ https://github.com/sgl-project/sglang for detailed documentation.
| `enable_metrics` | boolean | `True` | - |
| `decode_log_interval` | integer | `1` | - |
| `enable_multithread_load` | boolean | `False` | - |
| `enable_return_routed_experts` | boolean | `False` | - |

(section-v-llm)=
