From 4ae48ff15bfe61b581e59e7ee8725e71e914fd7d Mon Sep 17 00:00:00 2001 From: Lifan Wu Date: Thu, 7 Aug 2025 11:56:20 +0800 Subject: [PATCH 1/3] [Feature] Add GPT-OSS model support --- vllm_ascend/models/__init__.py | 8 +- vllm_ascend/models/gpt_oss.py | 521 +++++++++++++++++++++++++++++++++ 2 files changed, 528 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/models/gpt_oss.py diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index f47e821b34..ee499ab228 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -8,6 +8,7 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v3 import CustomDeepseekV3ForCausalLM # noqa: F401 + from .gpt_oss import GPTOSSForCausalLM # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 @@ -17,6 +18,11 @@ def register_model(): "DeepSeekMTPModel", "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") + # Register GPT-OSS model + ModelRegistry.register_model( + "GPTOSSForCausalLM", + "vllm_ascend.models.gpt_oss:GPTOSSForCausalLM") + ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") @@ -58,4 +64,4 @@ def register_model(): ModelRegistry.register_model( "PanguProMoEForCausalLM", - "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") + "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") \ No newline at end of file diff --git a/vllm_ascend/models/gpt_oss.py b/vllm_ascend/models/gpt_oss.py new file mode 100644 index 0000000000..85b1abbd21 --- /dev/null +++ b/vllm_ascend/models/gpt_oss.py @@ -0,0 +1,521 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 OpenAI and the vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
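+ +# Module overview: GPTOSSConfig (HF-style configuration), GPTOSSAttention +# (grouped-query attention with sliding-window attention on alternate layers and +# learned attention sinks), GPTOSSMoELayer (routed experts backed by +# AscendFusedMoE), GPTOSSDecoderLayer/GPTOSSModel, and the GPTOSSForCausalLM +# entry point registered with vLLM's ModelRegistry in vllm_ascend/models/__init__.py.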
+"""Inference-only GPT-OSS model on Ascend NPU.""" + +import math +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.utils import dispose_tensor + + +class GPTOSSConfig(PretrainedConfig): + """GPT-OSS model configuration.""" + + model_type = "gpt_oss" + + def __init__( + self, + vocab_size: int = 201088, + hidden_size: int = 2880, + intermediate_size: int = 2880, + num_hidden_layers: int = 36, + num_attention_heads: int = 64, + num_key_value_heads: int = 8, + head_dim: int = 64, + num_experts: int = 128, + experts_per_token: int = 4, + sliding_window: int = 128, + initial_context_length: int = 4096, + rope_theta: float = 150000.0, + rope_scaling_factor: float = 32.0, + rope_ntk_alpha: float = 1.0, + rope_ntk_beta: float = 32.0, + swiglu_limit: float = 7.0, + rms_norm_eps: float = 1e-5, + use_bias: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.num_experts = num_experts + self.experts_per_token = experts_per_token + self.sliding_window = sliding_window + self.initial_context_length = initial_context_length + self.rope_theta = rope_theta + self.rope_scaling_factor = rope_scaling_factor + self.rope_ntk_alpha = rope_ntk_alpha + self.rope_ntk_beta = rope_ntk_beta + self.swiglu_limit = swiglu_limit + self.rms_norm_eps = rms_norm_eps + self.use_bias = use_bias + + super().__init__(**kwargs) + + +class GPTOSSAttention(nn.Module): + """GPT-OSS attention layer with sliding window support.""" + + def __init__( + self, + config: GPTOSSConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // get_tensor_model_parallel_world_size() + self.total_num_kv_heads = 
config.num_key_value_heads + self.num_kv_heads = max(1, self.total_num_kv_heads // get_tensor_model_parallel_world_size()) + self.head_dim = config.head_dim + + # Sliding window attention (every other layer) + self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else None + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # QKV projection + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + # Output projection + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Attention + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=1.0 / math.sqrt(self.head_dim), + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + sliding_window=self.sliding_window, + ) + + # RoPE + self.rotary_emb = get_rope( + head_dim=self.head_dim, + rotary_dim=self.head_dim, + max_position=config.initial_context_length * config.rope_scaling_factor, + base=int(config.rope_theta), + rope_scaling={ + "type": "yarn", + "factor": config.rope_scaling_factor, + "original_max_position_embeddings": config.initial_context_length, + "alpha": config.rope_ntk_alpha, + "beta": config.rope_ntk_beta, + }, + ) + + # Sink attention weights for streaming attention + self.sinks = nn.Parameter( + torch.zeros(self.num_heads, dtype=torch.float32) + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + output, _ = self.o_proj(attn_output) + return output + + +class GPTOSSMoELayer(nn.Module): + """GPT-OSS MoE layer with swiglu activation.""" + + def __init__( + self, + config: GPTOSSConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_experts + self.top_k = config.experts_per_token + + # Expert gate + self.gate = RowParallelLinear( + input_size=self.hidden_size, + output_size=self.num_experts, + bias=config.use_bias, + quant_config=None, # Gate typically not quantized + prefix=f"{prefix}.gate", + ) + + # MoE experts using AscendFusedMoE + self.experts = AscendFusedMoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + # Custom swiglu activation with limit + self.swiglu_limit = config.swiglu_limit + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + # Get routing logits + router_logits, _ = self.gate(hidden_states) + + # Apply MoE with custom swiglu + output = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + ) + + return output + + +class 
GPTOSSDecoderLayer(nn.Module): + """GPT-OSS decoder layer.""" + + def __init__( + self, + config: GPTOSSConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + # Attention + self.self_attn = GPTOSSAttention( + config=config, + layer_idx=layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # MoE MLP + self.mlp = GPTOSSMoELayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + # Layer norms + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # MLP + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states, attn_metadata) + + return hidden_states, residual + + +class GPTOSSModel(nn.Module): + """GPT-OSS model.""" + + def __init__( + self, + config: GPTOSSConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.padding_idx = getattr(config, "pad_token_id", None) + + # Embeddings + self.embed_tokens = VocabParallelEmbedding( + vocab_size=self.vocab_size, + hidden_size=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + + # Transformer layers + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda layer_idx, prefix: GPTOSSDecoderLayer( + config=config, + layer_idx=layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + # Final layer norm + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states", "residual"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, 
self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GPTOSSForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + """GPT-OSS for causal language modeling.""" + + def __init__( + self, + config: GPTOSSConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = GPTOSSModel( + config, + cache_config, + quant_config, + prefix="model", + ) + + # Language model head + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + quant_config=quant_config, + prefix="lm_head", + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds, + ) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata, + ): + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias"): + name = name[:-5] + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias"): + name = name[:-5] + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 31f9af974c8b55e8f9b397470e1e14702ef44505 Mon Sep 17 00:00:00 2001 From: Lifan Wu Date: Mon, 18 Aug 2025 16:22:15 +0800 Subject: [PATCH 2/3] 1 --- vllm_ascend/models/gpt_oss.py | 47 +++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/models/gpt_oss.py b/vllm_ascend/models/gpt_oss.py index 85b1abbd21..b264faf990 100644 --- a/vllm_ascend/models/gpt_oss.py +++ 
b/vllm_ascend/models/gpt_oss.py @@ -51,6 +51,7 @@ class GPTOSSConfig(PretrainedConfig): """GPT-OSS model configuration.""" model_type = "gpt_oss" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, @@ -65,6 +66,7 @@ def __init__( experts_per_token: int = 4, sliding_window: int = 128, initial_context_length: int = 4096, + max_position_embeddings: int = 4096, rope_theta: float = 150000.0, rope_scaling_factor: float = 32.0, rope_ntk_alpha: float = 1.0, @@ -72,6 +74,10 @@ def __init__( swiglu_limit: float = 7.0, rms_norm_eps: float = 1e-5, use_bias: bool = True, + bos_token_id: int = 1, + eos_token_id: int = 2, + pad_token_id: Optional[int] = None, + tie_word_embeddings: bool = False, **kwargs, ): self.vocab_size = vocab_size @@ -85,6 +91,7 @@ def __init__( self.experts_per_token = experts_per_token self.sliding_window = sliding_window self.initial_context_length = initial_context_length + self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta self.rope_scaling_factor = rope_scaling_factor self.rope_ntk_alpha = rope_ntk_alpha @@ -93,7 +100,22 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_bias = use_bias - super().__init__(**kwargs) + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +# Register the config with transformers +try: + from transformers import AutoConfig + AutoConfig.register("gpt_oss", GPTOSSConfig) +except ImportError: + # If transformers is not available, skip registration + pass class GPTOSSAttention(nn.Module): @@ -414,15 +436,20 @@ def forward( class GPTOSSForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """GPT-OSS for causal language modeling.""" - def __init__( - self, - config: GPTOSSConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional = None, - ) -> None: + # List of modules that must not be split across devices + _no_split_modules = ["GPTOSSDecoderLayer", "GPTOSSAttention", "GPTOSSMoELayer"] + + # Add the supports_multimodal attribute + supports_multimodal = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.lora_config = lora_config @@ -430,7 +457,7 @@ def __init__( config, cache_config, quant_config, - prefix="model", + prefix=maybe_prefix(prefix, "model"), ) # Language model head @@ -439,7 +466,7 @@ def __init__( vocab_size=config.vocab_size, hidden_size=config.hidden_size, quant_config=quant_config, - prefix="lm_head", + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() From a95415fe9c462bb1f71b06b4d25712e8254aa24e Mon Sep 17 00:00:00 2001 From: Lifan Wu Date: Fri, 22 Aug 2025 18:33:59 +0800 Subject: [PATCH 3/3] 1 --- vllm_ascend/models/__init__.py | 13 +- vllm_ascend/models/gpt_oss.py | 699 +++++++++++---------------------- 2 files changed, 225 insertions(+), 487 deletions(-) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index bfde176d4b..48fe22cc0f 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -8,7 +8,7 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v3 import
CustomDeepseekV3ForCausalLM # noqa: F401 - from .gpt_oss import GPTOSSForCausalLM # noqa: F401 + from .gpt_oss import CustomGptOssForCausalLM # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 @@ -18,11 +18,6 @@ def register_model(): "DeepSeekMTPModel", "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") - # Register GPT-OSS model - ModelRegistry.register_model( - "GPTOSSForCausalLM", - "vllm_ascend.models.gpt_oss:GPTOSSForCausalLM") - ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") @@ -64,4 +59,8 @@ def register_model(): ModelRegistry.register_model( "PanguProMoEForCausalLM", - "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") \ No newline at end of file + "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") + + ModelRegistry.register_model( + "GptOssForCausalLM", + "vllm_ascend.models.gpt_oss:CustomGptOssForCausalLM") diff --git a/vllm_ascend/models/gpt_oss.py b/vllm_ascend/models/gpt_oss.py index b264faf990..6838d93a4b 100644 --- a/vllm_ascend/models/gpt_oss.py +++ b/vllm_ascend/models/gpt_oss.py @@ -1,548 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2024 OpenAI and the vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only GPT-OSS model on Ascend NPU.""" - -import math -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional import torch -import torch.nn as nn -from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +import torch.distributed as dist +import torch_npu +from torch import nn +from transformers import GptOssConfig + +from vllm.attention import Attention, AttentionType, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_pp_group) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.utils import ( - PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv -from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.utils import dispose_tensor - - -class GPTOSSConfig(PretrainedConfig): - """GPT-OSS model configuration.""" - - model_type = "gpt_oss" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size: int = 201088, - hidden_size: int = 2880, - intermediate_size: int = 2880, - num_hidden_layers: int = 36, - num_attention_heads: int = 64, - num_key_value_heads: int = 8, - head_dim: int = 64, - num_experts: int = 128, - experts_per_token: int = 4, - sliding_window: int = 128, - initial_context_length: int = 4096, - max_position_embeddings: int = 4096, - rope_theta: float = 150000.0, - rope_scaling_factor: float = 32.0, - rope_ntk_alpha: float = 1.0, - rope_ntk_beta: float = 32.0, - swiglu_limit: float = 7.0, - rms_norm_eps: float = 1e-5, - use_bias: bool = True, - bos_token_id: int = 1, - eos_token_id: int = 2, - pad_token_id: Optional[int] = None, - tie_word_embeddings: bool = False, - **kwargs, - ): - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.head_dim = head_dim - self.num_experts = num_experts - self.experts_per_token = experts_per_token - 
self.sliding_window = sliding_window - self.initial_context_length = initial_context_length - self.max_position_embeddings = max_position_embeddings - self.rope_theta = rope_theta - self.rope_scaling_factor = rope_scaling_factor - self.rope_ntk_alpha = rope_ntk_alpha - self.rope_ntk_beta = rope_ntk_beta - self.swiglu_limit = swiglu_limit - self.rms_norm_eps = rms_norm_eps - self.use_bias = use_bias - - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) +# Import the original GPT-OSS classes from vLLM +from vllm.model_executor.models.gpt_oss import ( + GptOssForCausalLM, GptOssModel, OAIAttention, MLPBlock, TransformerBlock +) +from vllm.model_executor.models.utils import ( + extract_layer_index, maybe_prefix, PPMissingLayer, is_pp_missing_parameter +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, WeightsMapper +) +from vllm_ascend.ops.fused_moe import AscendFusedMoE -# Register the config with transformers -try: - from transformers import AutoConfig - AutoConfig.register("gpt_oss", GPTOSSConfig) -except ImportError: - # If transformers is not available, skip registration - pass +class CustomOAIAttention(OAIAttention): + """Custom OAI Attention with Ascend optimizations.""" -class GPTOSSAttention(nn.Module): - """GPT-OSS attention layer with sliding window support.""" - def __init__( self, - config: GPTOSSConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, + config: GptOssConfig, quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.total_num_heads = config.num_attention_heads - self.num_heads = self.total_num_heads // get_tensor_model_parallel_world_size() - self.total_num_kv_heads = config.num_key_value_heads - self.num_kv_heads = max(1, self.total_num_kv_heads // get_tensor_model_parallel_world_size()) - self.head_dim = config.head_dim - - # Sliding window attention (every other layer) - self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else None - - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - # QKV projection - self.qkv_proj = QKVParallelLinear( - hidden_size=self.hidden_size, - head_size=self.head_dim, - total_num_heads=self.total_num_heads, - total_num_kv_heads=self.total_num_kv_heads, - bias=config.use_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - # Output projection - self.o_proj = RowParallelLinear( - input_size=self.total_num_heads * self.head_dim, - output_size=self.hidden_size, - bias=config.use_bias, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - # Attention - self.attn = Attention( - self.num_heads, - self.head_dim, - scale=1.0 / math.sqrt(self.head_dim), - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - sliding_window=self.sliding_window, - ) - - # RoPE - self.rotary_emb = get_rope( - head_dim=self.head_dim, - rotary_dim=self.head_dim, - max_position=config.initial_context_length * config.rope_scaling_factor, - base=int(config.rope_theta), - rope_scaling={ - "type": "yarn", - "factor": config.rope_scaling_factor, - "original_max_position_embeddings": config.initial_context_length, - "alpha": config.rope_ntk_alpha, - "beta": config.rope_ntk_beta, - }, - ) - - # 
Sink attention weights for streaming attention - self.sinks = nn.Parameter( - torch.zeros(self.num_heads, dtype=torch.float32) - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - + ): + super().__init__(config, quant_config, cache_config, prefix) + + def forward(self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + # Use original forward but with Ascend-optimized attention + t = self.norm(hidden_states) + + qkv, _ = self.qkv(t) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) + v = v.contiguous() attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) - return output + return output + hidden_states + + +class CustomMLPBlock(MLPBlock): + """Custom MLP Block using AscendFusedMoE.""" -class GPTOSSMoELayer(nn.Module): - """GPT-OSS MoE layer with swiglu activation.""" - def __init__( self, - config: GPTOSSConfig, + config: GptOssConfig, + layer_idx: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.num_experts = config.num_experts - self.top_k = config.experts_per_token - - # Expert gate - self.gate = RowParallelLinear( - input_size=self.hidden_size, - output_size=self.num_experts, - bias=config.use_bias, - quant_config=None, # Gate typically not quantized - prefix=f"{prefix}.gate", - ) - - # MoE experts using AscendFusedMoE + ): + nn.Module.__init__(self) # Skip MLPBlock.__init__ + self.layer_idx = layer_idx + self.num_experts = config.num_local_experts + self.experts_per_token = config.num_experts_per_tok + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.router = torch.nn.Linear(config.hidden_size, + config.num_local_experts, + dtype=torch.bfloat16) + assert config.intermediate_size % self.world_size == 0 + + # Use AscendFusedMoE instead of standard FusedMoE self.experts = AscendFusedMoE( - num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, reduce_results=True, renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", + apply_router_weight_on_input=False, + has_bias=True, + activation="swigluoai" ) - - # Custom swiglu activation with limit - self.swiglu_limit = config.swiglu_limit - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> torch.Tensor: - # Get routing logits - router_logits, _ = self.gate(hidden_states) - - # Apply MoE with custom swiglu - output = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - ) - - return output + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.router(t) + t = self.experts(hidden_states=t, router_logits=g) + return x + t + + +class CustomTransformerBlock(TransformerBlock): + """Custom Transformer 
Block with Ascend-optimized components.""" -class GPTOSSDecoderLayer(nn.Module): - """GPT-OSS decoder layer.""" - def __init__( self, - config: GPTOSSConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, + config: GptOssConfig, quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, prefix: str = "", - ) -> None: - super().__init__() - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - - # Attention - self.self_attn = GPTOSSAttention( - config=config, - layer_idx=layer_idx, - cache_config=cache_config, + ): + nn.Module.__init__(self) # Skip TransformerBlock.__init__ + self.layer_idx = extract_layer_index(prefix) + self.attn = CustomOAIAttention( + config, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + cache_config=cache_config, + prefix=f"{prefix}.attn" ) - - # MoE MLP - self.mlp = GPTOSSMoELayer( - config=config, + self.mlp = CustomMLPBlock( + config, + self.layer_idx, quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - - # Layer norms - self.input_layernorm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - ) - - self.post_attention_layernorm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, + prefix=f"{prefix}.mlp" ) - - # MLP - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states, attn_metadata) - - return hidden_states, residual + + def forward(self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + attn_output = self.attn(hidden_states, positions, kv_cache, attn_metadata) + output = self.mlp(attn_output) + return output -class GPTOSSModel(nn.Module): - """GPT-OSS model.""" - +class CustomGptOssModel(GptOssModel): + """Custom GPT-OSS Model with Ascend optimizations.""" + def __init__( self, - config: GPTOSSConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + *, + vllm_config: VllmConfig, prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - self.padding_idx = getattr(config, "pad_token_id", None) - - # Embeddings - self.embed_tokens = VocabParallelEmbedding( - vocab_size=self.vocab_size, - hidden_size=config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - - # Transformer layers - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda layer_idx, prefix: GPTOSSDecoderLayer( - config=config, - layer_idx=layer_idx, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - ), - prefix=f"{prefix}.layers", - ) - - # Final layer norm - self.norm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, + ): + nn.Module.__init__(self) # Skip GptOssModel.__init__ + self.config = 
vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + + self.config.hidden_size = self.config.hidden_size + self.embedding = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, ) - - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states", "residual"], - config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + self.layers = torch.nn.ModuleList([ + CustomTransformerBlock( + self.config, + quant_config=self.quant_config, + cache_config=self.cache_config, + prefix=maybe_prefix(prefix, f"block.{layer_idx}"), + ) for layer_idx in range(self.config.num_hidden_layers) + ]) + self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.embedding(input_ids) + for i, layer in enumerate(self.layers): + x = layer(x, positions, + kv_cache=kv_caches[i] if kv_caches else None, + attn_metadata=attn_metadata) + x = self.norm(x) + return x -class GPTOSSForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - """GPT-OSS for causal language modeling.""" - - # List of modules that must not be split across devices - _no_split_modules = ["GPTOSSDecoderLayer", "GPTOSSAttention", "GPTOSSMoELayer"] - - # Add the supports_multimodal attribute - supports_multimodal = False + +class CustomGptOssForCausalLM(GptOssForCausalLM): + """Custom GPT-OSS For Causal Language Modeling with Ascend optimizations.""" - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.model = GPTOSSModel( - config, - cache_config, - quant_config, + packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + + # Use the same weight mapper as the original + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".self_attn.": ".attn.", + ".post_attention_layernorm.": ".mlp.norm.", + }, + orig_to_new_suffix={
+ ".embed_tokens.weight": ".embedding.weight", + ".input_layernorm.weight": ".attn.norm.weight", + ".post_attention_layernorm.weight": ".mlp.norm.weight", + + # MoE MXFP4 weights + ".gate_up_proj_blocks": ".w13_weight", + ".down_proj_blocks": ".w2_weight", + ".gate_up_proj_scales": ".w13_weight_scale", + ".down_proj_scales": ".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + nn.Module.__init__(self) # Skip GptOssForCausalLM.__init__ + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + + # Use CustomGptOssModel instead of GptOssModel + self.model = CustomGptOssModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"), ) - # Language model head if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - vocab_size=config.vocab_size, - hidden_size=config.hidden_size, - quant_config=quant_config, + self.config.vocab_size, + self.config.hidden_size, + quant_config=vllm_config.quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - inputs_embeds, - ) - return model_output - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata, - ) -> Optional[torch.Tensor]: + + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.sampler = get_sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + assert intermediate_tensors is None + assert inputs_embeds is None + return self.model(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata, - ): - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if name.endswith(".bias"): - name = name[:-5] - if name not in 
params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if name.endswith(".bias"): - name = name[:-5] - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
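For reviewers who want to smoke-test the registration path end to end, a minimal offline-inference sketch follows. It assumes this vllm-ascend build is installed so that register_model() maps the GptOssForCausalLM architecture to CustomGptOssForCausalLM, and that a GPT-OSS checkpoint is available locally; the checkpoint path below is a placeholder, while LLM and SamplingParams are vLLM's public offline-inference API.

    from vllm import LLM, SamplingParams

    # Placeholder path: point this at a local GPT-OSS checkpoint directory.
    llm = LLM(model="/path/to/gpt-oss-20b", tensor_parallel_size=1)

    sampling = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["GPT-OSS on Ascend NPU:"], sampling)
    for out in outputs:
        print(out.outputs[0].text)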