2 changes: 2 additions & 0 deletions python/sglang/srt/configs/__init__.py
@@ -24,6 +24,7 @@
Step3VisionEncoderConfig,
Step3VLConfig,
)
from sglang.srt.configs.step3p5 import Step3p5Config

__all__ = [
"AfmoeConfig",
@@ -50,4 +51,5 @@
"NemotronH_Nano_VL_V2_Config",
"JetNemotronConfig",
"JetVLMConfig",
"Step3p5Config",
]
24 changes: 24 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -302,6 +302,8 @@ def _config_draft_model(self):
and self.hf_config.architectures[0] == "MiMoV2FlashForCausalLM"
):
self.hf_config.architectures[0] = "MiMoV2MTP"
if is_draft_model and self.hf_config.architectures[0] == "Step3p5ForCausalLM":
self.hf_config.architectures[0] = "Step3p5MTP"
if is_draft_model and self.hf_config.architectures[0] in [
"BailingMoeV2ForCausalLM",
"BailingMoeForCausalLM",
@@ -606,6 +608,11 @@ def get_swa_num_kv_heads(self, tensor_parallel_size) -> int:
if hasattr(self.hf_text_config, "swa_num_key_value_heads"):
total_num_kv_heads = self.hf_text_config.swa_num_key_value_heads
return max(1, total_num_kv_heads // tensor_parallel_size)
elif hasattr(self.hf_text_config, "attention_other_setting"): # For step3p5
total_num_kv_heads = self.hf_text_config.attention_other_setting.get(
"num_attention_groups"
)
return max(1, total_num_kv_heads // tensor_parallel_size)
else:
return self.get_num_kv_heads(tensor_parallel_size)
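
A minimal sketch of the new fallback's arithmetic (the dict contents and TP sizes below are illustrative, not values read from a released checkpoint): the sliding-window KV-head count comes from num_attention_groups in attention_other_setting, sharded across tensor-parallel ranks with a floor of one.

def swa_num_kv_heads(attention_other_setting: dict, tensor_parallel_size: int) -> int:
    total_num_kv_heads = attention_other_setting.get("num_attention_groups")
    return max(1, total_num_kv_heads // tensor_parallel_size)

# e.g. 8 attention groups on 4-way TP -> 2 KV heads per rank; on 16-way TP the
# floor keeps every rank with at least one head.
assert swa_num_kv_heads({"num_attention_groups": 8}, 4) == 2
assert swa_num_kv_heads({"num_attention_groups": 8}, 16) == 1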

@@ -1268,6 +1275,8 @@ def is_hybrid_swa_model(model_architectures: List[str]):
"GptOssForCausalLM",
"MiMoV2FlashForCausalLM",
"MiMoV2MTP",
"Step3p5ForCausalLM",
"Step3p5MTP",
}
return any(arch in hybrid_swa_archs for arch in model_architectures)

@@ -1303,6 +1312,21 @@ def get_hybrid_layer_ids(
elif "MiMoV2MTP" in model_architectures:
swa_attention_layer_ids = [0]
full_attention_layer_ids = []
elif "Step3p5ForCausalLM" in model_architectures:
layer_types = hf_text_config.layer_types
swa_attention_layer_ids = [
i
for i, x in enumerate(layer_types)
if x == "sliding_attention" and i < num_hidden_layers
]
full_attention_layer_ids = [
i
for i, x in enumerate(layer_types)
if x == "full_attention" and i < num_hidden_layers
]
elif "Step3p5MTP" in model_architectures:
swa_attention_layer_ids = [0]
full_attention_layer_ids = []
else:
swa_attention_layer_ids = None
full_attention_layer_ids = None
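
As a worked example of the Step3p5ForCausalLM branch above, with a made-up six-layer pattern (the real list comes from the checkpoint's layer_types field):

layer_types = [
    "sliding_attention", "sliding_attention", "full_attention",
    "sliding_attention", "sliding_attention", "full_attention",
]
num_hidden_layers = 6
swa_attention_layer_ids = [
    i for i, x in enumerate(layer_types)
    if x == "sliding_attention" and i < num_hidden_layers
]
full_attention_layer_ids = [
    i for i, x in enumerate(layer_types)
    if x == "full_attention" and i < num_hidden_layers
]
assert swa_attention_layer_ids == [0, 1, 3, 4]
assert full_attention_layer_ids == [2, 5]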
97 changes: 97 additions & 0 deletions python/sglang/srt/configs/step3p5.py
@@ -0,0 +1,97 @@
from typing import Any, Optional

from transformers.configuration_utils import PretrainedConfig


class Step3p5Config(PretrainedConfig):
model_type = "step3p5"
architectures = ["Step3p5ForCausalLM"]

def __init__(
self,
hidden_size: int = 4096,
intermediate_size: int = 11264,
num_attention_heads: int = 64,
num_attention_groups: int = 8,
num_hidden_layers: int = 45,
max_seq_len: int = 128000,
vocab_size: int = 128815,
rms_norm_eps: float = 1e-5,
moe_intermediate_size: int = 1280,
moe_num_experts: int = 288,
moe_top_k: int = 8,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 128000,
share_expert_dims: int = 1280,
head_dim: int = 128,
norm_expert_weight: bool = True,
layer_types: Optional[list[str]] = None,
sliding_window: Optional[int] = None,
moe_layers_enum: tuple[int, ...] = (
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
),
**kwargs,
) -> None:
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_attention_groups = num_attention_groups
self.num_hidden_layers = num_hidden_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.max_position_embeddings = max_position_embeddings
self.share_expert_dim = share_expert_dims
self.head_dim = head_dim
self.norm_expert_weight = norm_expert_weight
self.moe_layers_enum = moe_layers_enum
self.layer_types = layer_types
self.sliding_window = sliding_window
super().__init__(**kwargs)
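
A usage sketch for the new config class; the overrides mimic what a checkpoint's config.json would supply, and the concrete values here are illustrative rather than taken from a released Step3p5 model.

from sglang.srt.configs import Step3p5Config

cfg = Step3p5Config(
    num_hidden_layers=45,
    sliding_window=4096,                  # illustrative window size
    layer_types=["full_attention"] * 45,  # illustrative per-layer pattern
)
assert cfg.model_type == "step3p5"
assert cfg.moe_layers_enum[0] == 3        # defaults keep layers 0-2 dense, MoE from layer 3 on
assert cfg.moe_top_k == 8 and cfg.moe_num_experts == 288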
1 change: 1 addition & 0 deletions python/sglang/srt/function_call/function_call_parser.py
@@ -62,6 +62,7 @@ class FunctionCallParser:
"qwen25": Qwen25Detector,
"qwen3_coder": Qwen3CoderDetector,
"step3": Step3Detector,
"step3p5": Qwen3CoderDetector,
"minimax-m2": MinimaxM2Detector,
"trinity": TrinityDetector,
"interns1": InternlmDetector,
25 changes: 20 additions & 5 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -278,7 +278,18 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):


@torch.compile
def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
def _swiglu_silu_clamp_mul(x, gemm1_limit):
gate, up = x.chunk(2, dim=-1)
gate = F.silu(gate)
gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
return gate * up


@torch.compile
def _swiglu_gpt_oss_sigmoid_alpha(x, gemm1_alpha, gemm1_limit):
# NOTE: This variant uses gemm1_alpha, unlike _swiglu_silu_clamp_mul.
# At present, only GPT-OSS uses this variant.
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
@@ -471,12 +482,16 @@ def fused_experts_impl(

# Activation function with multiplication
if activation == "silu" and is_gated:
# - gemm1_alpha != None: GPT-OSS-style swiglu(alpha, limit)
# - gemm1_alpha == None and gemm1_limit != None: silu+clamp+mul(limit-only)
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
gemm1_alpha,
gemm1_limit,
intermediate_cache2 = _swiglu_gpt_oss_sigmoid_alpha(
intermediate_cache1.view(-1, N), gemm1_alpha, gemm1_limit
)
elif gemm1_limit is not None:
intermediate_cache2 = _swiglu_silu_clamp_mul(
intermediate_cache1.view(-1, N), gemm1_limit
)
elif _is_cuda or _is_hip:
if not filter_expert:
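
The split above leaves two activation paths: the GPT-OSS variant keeps its gemm1_alpha-scaled activation on interleaved gate/up columns, while the new limit-only path applies SiLU to the gate half, clamps both halves, and multiplies. A standalone restatement of that limit-only path, without the @torch.compile wrapper (shapes and the limit value are illustrative):

import torch
import torch.nn.functional as F

def swiglu_limit_only(x: torch.Tensor, gemm1_limit: float) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)                     # gate and up halves stored side by side
    gate = F.silu(gate).clamp(max=gemm1_limit)        # cap the gated activation from above
    up = up.clamp(min=-gemm1_limit, max=gemm1_limit)  # symmetric cap on the up projection
    return gate * up

x = torch.randn(16, 2 * 1280)   # (tokens, 2 * moe_intermediate_size)
out = swiglu_limit_only(x, gemm1_limit=7.0)
assert out.shape == (16, 1280)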
13 changes: 8 additions & 5 deletions python/sglang/srt/layers/moe/moe_runner/triton.py
@@ -117,10 +117,11 @@ def run(

# TODO: move these functions to the triton runner
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
_swiglu_gpt_oss_sigmoid_alpha,
_swiglu_silu_clamp_mul,
invoke_fused_moe_kernel,
moe_sum_reduce_torch_compile,
moe_sum_reduce_triton,
swiglu_with_alpha_and_limit,
)

hidden_states = runner_input.hidden_states
@@ -203,10 +204,12 @@ def run(
if activation == "silu":
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
gemm1_alpha,
gemm1_limit,
intermediate_cache2 = _swiglu_gpt_oss_sigmoid_alpha(
intermediate_cache1.view(-1, N), gemm1_alpha, gemm1_limit
)
elif gemm1_limit is not None:
intermediate_cache2 = _swiglu_silu_clamp_mul(
intermediate_cache1.view(-1, N), gemm1_limit
)
elif _is_cuda or _is_hip:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -493,6 +493,8 @@ def initialize(self, min_per_gpu_memory: float):
)
if self.model_config.hf_config.architectures[0] == "MiMoV2MTP":
model_num_layers = 1
elif self.model_config.hf_config.architectures[0] == "Step3p5MTP":
model_num_layers = 1
self.start_layer = getattr(self.model, "start_layer", 0)
self.end_layer = getattr(self.model, "end_layer", model_num_layers)
self.num_effective_layers = self.end_layer - self.start_layer
3 changes: 3 additions & 0 deletions python/sglang/srt/model_loader/loader.py
@@ -275,6 +275,9 @@ def _initialize_model(
kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.get()
kwargs["model_path"] = model_config.model_path

if load_config.draft_model_idx is not None:
kwargs["draft_model_idx"] = load_config.draft_model_idx

return model_class(**kwargs)


1 change: 1 addition & 0 deletions python/sglang/srt/models/mimo_v2_flash_nextn.py
@@ -229,6 +229,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
draft_model_idx: Optional[int] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
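
The loader change above only forwards draft_model_idx when load_config carries one, and the MTP/next-N model constructors grow a matching keyword. A sketch of that contract with toy names (the class and its use of the index are illustrative, not sglang's actual draft-model classes):

from typing import Optional

class ToyDraftModel:
    def __init__(self, config, quant_config=None, draft_model_idx: Optional[int] = None, prefix: str = ""):
        # When several draft models are stacked, the loader can tell each
        # instance which draft position it corresponds to.
        self.draft_model_idx = draft_model_idx

kwargs = {"config": object(), "quant_config": None}
draft_model_idx = 0  # illustrative; None would leave the kwarg out entirely
if draft_model_idx is not None:
    kwargs["draft_model_idx"] = draft_model_idx
model = ToyDraftModel(**kwargs)
assert model.draft_model_idx == 0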