2 changes: 2 additions & 0 deletions python/sglang/srt/configs/__init__.py
@@ -24,6 +24,7 @@
Step3VisionEncoderConfig,
Step3VLConfig,
)
from sglang.srt.configs.step3p5 import Step3p5Config

__all__ = [
"AfmoeConfig",
@@ -50,4 +51,5 @@
"NemotronH_Nano_VL_V2_Config",
"JetNemotronConfig",
"JetVLMConfig",
"Step3p5Config",
]
24 changes: 24 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -302,6 +302,8 @@ def _config_draft_model(self):
and self.hf_config.architectures[0] == "MiMoV2FlashForCausalLM"
):
self.hf_config.architectures[0] = "MiMoV2MTP"
if is_draft_model and self.hf_config.architectures[0] == "Step3p5ForCausalLM":
self.hf_config.architectures[0] = "Step3p5MTP"
if is_draft_model and self.hf_config.architectures[0] in [
"BailingMoeV2ForCausalLM",
"BailingMoeForCausalLM",
@@ -606,6 +608,11 @@ def get_swa_num_kv_heads(self, tensor_parallel_size) -> int:
if hasattr(self.hf_text_config, "swa_num_key_value_heads"):
total_num_kv_heads = self.hf_text_config.swa_num_key_value_heads
return max(1, total_num_kv_heads // tensor_parallel_size)
elif hasattr(self.hf_text_config, "attention_other_setting"): # For step3p5
total_num_kv_heads = self.hf_text_config.attention_other_setting.get(
"num_attention_groups"
)
return max(1, total_num_kv_heads // tensor_parallel_size)
else:
return self.get_num_kv_heads(tensor_parallel_size)

@@ -1268,6 +1275,8 @@ def is_hybrid_swa_model(model_architectures: List[str]):
"GptOssForCausalLM",
"MiMoV2FlashForCausalLM",
"MiMoV2MTP",
"Step3p5ForCausalLM",
"Step3p5MTP",
}
return any(arch in hybrid_swa_archs for arch in model_architectures)

@@ -1303,6 +1312,21 @@ def get_hybrid_layer_ids(
elif "MiMoV2MTP" in model_architectures:
swa_attention_layer_ids = [0]
full_attention_layer_ids = []
elif "Step3p5ForCausalLM" in model_architectures:
layer_types = hf_text_config.layer_types
swa_attention_layer_ids = [
i
for i, x in enumerate(layer_types)
if x == "sliding_attention" and i < num_hidden_layers
]
full_attention_layer_ids = [
i
for i, x in enumerate(layer_types)
if x == "full_attention" and i < num_hidden_layers
]
elif "Step3p5MTP" in model_architectures:
swa_attention_layer_ids = [0]
full_attention_layer_ids = []
else:
swa_attention_layer_ids = None
full_attention_layer_ids = None
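For Step3p5, the hybrid split is driven entirely by the config's `layer_types` list: every index tagged `sliding_attention` goes to the SWA group and every index tagged `full_attention` to the full-attention group. A minimal sketch of what the new branch computes, using a hypothetical 5-layer `layer_types` for illustration:

layer_types = [
    "full_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention",
]  # illustrative only; the real list comes from the checkpoint config
num_hidden_layers = len(layer_types)

swa_attention_layer_ids = [
    i for i, x in enumerate(layer_types)
    if x == "sliding_attention" and i < num_hidden_layers
]
full_attention_layer_ids = [
    i for i, x in enumerate(layer_types)
    if x == "full_attention" and i < num_hidden_layers
]
# swa_attention_layer_ids  == [1, 2, 4]
# full_attention_layer_ids == [0, 3]

The `get_swa_num_kv_heads` change follows the same pattern as the existing `swa_num_key_value_heads` branch: with the config default of 8 attention groups and `tensor_parallel_size=4`, it returns `max(1, 8 // 4) == 2` KV heads per rank.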
97 changes: 97 additions & 0 deletions python/sglang/srt/configs/step3p5.py
@@ -0,0 +1,97 @@
from typing import Any, Optional

from transformers.configuration_utils import PretrainedConfig


class Step3p5Config(PretrainedConfig):
model_type = "step3p5"
architectures = ["Step3p5ForCausalLM"]

def __init__(
self,
hidden_size: int = 4096,
intermediate_size: int = 11264,
num_attention_heads: int = 64,
num_attention_groups: int = 8,
num_hidden_layers: int = 45,
max_seq_len: int = 128000,
vocab_size: int = 128815,
rms_norm_eps: float = 1e-5,
moe_intermediate_size: int = 1280,
moe_num_experts: int = 288,
moe_top_k: int = 8,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 128000,
share_expert_dims: int = 1280,
head_dim: int = 128,
norm_expert_weight: bool = True,
layer_types: list[str] = None,
sliding_window: Optional[int] = None,
moe_layers_enum: tuple[int] = (
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
),
**kwargs,
) -> None:
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_attention_groups = num_attention_groups
self.num_hidden_layers = num_hidden_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.max_position_embeddings = max_position_embeddings
self.share_expert_dim = share_expert_dims
self.head_dim = head_dim
self.norm_expert_weight = norm_expert_weight
self.moe_layers_enum = moe_layers_enum
self.layer_types = layer_types
self.sliding_window = sliding_window
super().__init__(**kwargs)
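A short usage sketch of the config defined above; the `layer_types` and `sliding_window` values are illustrative placeholders, not defaults taken from any released checkpoint:

from sglang.srt.configs import Step3p5Config

cfg = Step3p5Config(
    layer_types=["full_attention"] * 3 + ["sliding_attention"] * 42,  # illustrative
    sliding_window=4096,  # illustrative
)
assert cfg.model_type == "step3p5"
assert cfg.num_hidden_layers == 45    # default
assert cfg.moe_layers_enum[0] == 3    # MoE layers start at index 3 by default
assert cfg.share_expert_dim == 1280   # the `share_expert_dims` argument is stored as `share_expert_dim`

Note that `layer_types` defaults to None, so a hybrid-attention checkpoint presumably supplies it explicitly in its config.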
1 change: 1 addition & 0 deletions python/sglang/srt/function_call/function_call_parser.py
@@ -62,6 +62,7 @@ class FunctionCallParser:
"qwen25": Qwen25Detector,
"qwen3_coder": Qwen3CoderDetector,
"step3": Step3Detector,
"step3p5": Qwen3CoderDetector,
"minimax-m2": MinimaxM2Detector,
"trinity": TrinityDetector,
"interns1": InternlmDetector,
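Because `step3p5` maps onto the existing `Qwen3CoderDetector`, no new detector class is introduced; selecting the parser by name reuses the Qwen3-Coder tool-call format. A hedged sketch of how the mapping is consumed (the constructor keywords are assumed from the class shown above, and `tools` stands in for the request's tool definitions):

from sglang.srt.function_call.function_call_parser import FunctionCallParser

tools = []  # placeholder for the request's tool/function definitions
parser = FunctionCallParser(tools=tools, tool_call_parser="step3p5")
# "step3p5" now resolves to Qwen3CoderDetector through the mapping extended above.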
25 changes: 20 additions & 5 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -278,7 +278,18 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):


@torch.compile
def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
def _swiglu_silu_clamp_mul(x, gemm1_limit):
gate, up = x.chunk(2, dim=-1)
gate = F.silu(gate)
gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
return gate * up


@torch.compile
def _swiglu_gpt_oss_sigmoid_alpha(x, gemm1_alpha, gemm1_limit):
# NOTE: This variant uses gemm1_alpha, unlike _swiglu_silu_clamp_mul.
# At present, only GPT-OSS uses this variant.
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
@@ -471,12 +482,16 @@ def fused_experts_impl(

# Activation function with multiplication
if activation == "silu" and is_gated:
# - gemm1_alpha != None: GPT-OSS-style swiglu(alpha, limit)
# - gemm1_alpha == None and gemm1_limit != None: silu+clamp+mul(limit-only)
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
gemm1_alpha,
gemm1_limit,
intermediate_cache2 = _swiglu_gpt_oss_sigmoid_alpha(
intermediate_cache1.view(-1, N), gemm1_alpha, gemm1_limit
)
elif gemm1_limit is not None:
intermediate_cache2 = _swiglu_silu_clamp_mul(
intermediate_cache1.view(-1, N), gemm1_limit
)
elif _is_cuda or _is_hip:
if not filter_expert:
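Splitting the helper in two makes the dispatch above explicit: the GPT-OSS variant needs both `gemm1_alpha` and `gemm1_limit` and operates on interleaved gate/up columns, while the new limit-only variant chunks the last dimension in half, applies SiLU to the gate, clamps both halves to the limit, and multiplies. A self-contained reference for the limit-only semantics (tensor shapes are illustrative; this mirrors `_swiglu_silu_clamp_mul` without `torch.compile`):

import torch
import torch.nn.functional as F

def swiglu_limit_only_reference(x: torch.Tensor, limit: float) -> torch.Tensor:
    # Chunked (not interleaved) gate/up halves, as in _swiglu_silu_clamp_mul.
    gate, up = x.chunk(2, dim=-1)
    gate = F.silu(gate).clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return gate * up

x = torch.randn(4, 16)  # illustrative: 4 tokens, 2 * 8 intermediate dims
out = swiglu_limit_only_reference(x, limit=7.0)
assert out.shape == (4, 8)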
13 changes: 8 additions & 5 deletions python/sglang/srt/layers/moe/moe_runner/triton.py
@@ -117,10 +117,11 @@ def run(

# TODO: move these functions to the triton runner
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
_swiglu_gpt_oss_sigmoid_alpha,
_swiglu_silu_clamp_mul,
invoke_fused_moe_kernel,
moe_sum_reduce_torch_compile,
moe_sum_reduce_triton,
swiglu_with_alpha_and_limit,
)

hidden_states = runner_input.hidden_states
@@ -203,10 +204,12 @@ def run(
if activation == "silu":
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
gemm1_alpha,
gemm1_limit,
intermediate_cache2 = _swiglu_gpt_oss_sigmoid_alpha(
intermediate_cache1.view(-1, N), gemm1_alpha, gemm1_limit
)
elif gemm1_limit is not None:
intermediate_cache2 = _swiglu_silu_clamp_mul(
intermediate_cache1.view(-1, N), gemm1_limit
)
elif _is_cuda or _is_hip:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -493,6 +493,8 @@ def initialize(self, min_per_gpu_memory: float):
)
if self.model_config.hf_config.architectures[0] == "MiMoV2MTP":
model_num_layers = 1
elif self.model_config.hf_config.architectures[0] == "Step3p5MTP":
model_num_layers = 1
self.start_layer = getattr(self.model, "start_layer", 0)
self.end_layer = getattr(self.model, "end_layer", model_num_layers)
self.num_effective_layers = self.end_layer - self.start_layer
3 changes: 3 additions & 0 deletions python/sglang/srt/model_loader/loader.py
@@ -275,6 +275,9 @@ def _initialize_model(
kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.get()
kwargs["model_path"] = model_config.model_path

if load_config.draft_model_idx is not None:
kwargs["draft_model_idx"] = load_config.draft_model_idx

return model_class(**kwargs)


1 change: 1 addition & 0 deletions python/sglang/srt/models/mimo_v2_flash_nextn.py
@@ -229,6 +229,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
draft_model_idx: Optional[int] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
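Taken together with the loader change above, the new `draft_model_idx` keyword flows from the load config into draft-model constructors that accept it, such as the MiMo MTP constructor in this diff. A rough sketch of the handoff (the surrounding kwargs are simplified; only the `draft_model_idx` plumbing matches the diff):

# Inside model construction (simplified; other kwargs omitted).
kwargs = {"config": config, "quant_config": quant_config}
if load_config.draft_model_idx is not None:
    kwargs["draft_model_idx"] = load_config.draft_model_idx
model = model_class(**kwargs)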