From fc4aebdc77cb62181b5a72ba1c53c7940f70cb98 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Thu, 14 Aug 2025 10:00:08 +0800 Subject: [PATCH 01/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/attention_v1.py | 50 +++-- .../attention/attention_v1_torchair.py | 96 ++++----- vllm_ascend/attention/mla_v1.py | 199 ++++++++---------- vllm_ascend/attention/utils.py | 117 ++++++++++ vllm_ascend/torchair/torchair_model_runner.py | 13 +- vllm_ascend/worker/eagle_proposer_v1.py | 19 +- vllm_ascend/worker/model_runner_v1.py | 58 +++-- vllm_ascend/worker/mtp_proposer_v1.py | 32 ++- 8 files changed, 367 insertions(+), 217 deletions(-) create mode 100644 vllm_ascend/attention/utils.py diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 15a7759330..77a8e86f67 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -21,6 +21,8 @@ import torch import torch_npu +import torch.nn as nn +from vllm.config import VllmConfig from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState @@ -32,6 +34,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) from vllm_ascend.worker.npu_input_batch import InputBatch +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class AscendAttentionBackend(AttentionBackend): @@ -157,33 +160,36 @@ class AscendMetadata: class AscendAttentionMetadataBuilder: - def __init__(self, runner): - self.runner = runner + def __init__(self, + vllm_config: VllmConfig, + device: torch.device,): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False def build(self, - num_reqs, - num_actual_tokens, - max_query_len, - enable_dbo_across_dp: bool = False, - is_only_prefill: bool = False): - - block_table = self.runner.input_batch.block_table[0].get_device_tensor( - ) - block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module,): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] + + block_table = common_attn_metadata.block_table_tensor + block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = ( block_table[:num_reqs]) - query_lens = self.runner.query_lens - seq_lens = self.runner.seq_lens_cpu[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True) - attn_mask = self.runner.attn_mask - attn_state = self.runner.attn_state - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to( + self.device, non_blocking=True) + attn_mask = common_attn_metadata.attn_mask + attn_state = common_attn_metadata.attn_state + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] + query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) if is_310p(): @@ -202,12 +208,12 
@@ def build(self, query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, - max_query_len=max_query_len, + max_query_len=common_attn_metadata.max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=enable_dbo_across_dp, - is_only_prefill=is_only_prefill) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + is_only_prefill=common_attn_metadata.is_only_prefill) return attn_metadata diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 4d84bac976..8517865bb2 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -21,6 +21,8 @@ import numpy as np import torch import torch_npu +import torch.nn as nn +from vllm.config import VllmConfig from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState @@ -30,6 +32,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d) from vllm_ascend.worker.npu_input_batch import InputBatch +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class AscendAttentionTorchairBackend(AttentionBackend): @@ -145,8 +148,12 @@ class AscendTorchairMetadata: class AscendAttentionTorchairMetadataBuilder: - def __init__(self, runner): - self.runner = runner + def __init__(self, + vllm_config: VllmConfig, + device: torch.device,): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -154,34 +161,16 @@ def reorder_batch(self, input_batch: "InputBatch", def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: - - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}" - - if isinstance(self.runner.graph_block_tables, np.ndarray): - graph_block_tables = torch.zeros((max_batch_size, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - else: - graph_block_tables = self.runner.graph_block_tables.to( - device=block_tables.device, dtype=block_tables.dtype) - num_blocks = block_tables.size(1) - if num_blocks <= max_blocks: - graph_block_tables[:num_seqs, : - num_blocks] = block_tables[:num_seqs, : - num_blocks] + if num_blocks <= self.max_blocks: + return block_tables[:num_seqs, :num_blocks] else: - graph_block_tables[:num_seqs, : - max_blocks] = block_tables[:num_seqs, : - max_blocks] - - return graph_block_tables[:num_seqs, :max_blocks] + return block_tables[:num_seqs, :self.max_blocks] def build_torchair_graph_dummy( - self, num_reqs: int, - num_actual_tokens: int) -> AscendTorchairMetadata: - device = self.runner.device + self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs _, max_blocks = self.runner.graph_block_tables.shape block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32, @@ -208,7 +197,7 @@ def build_torchair_graph_dummy( max_seq_lens=1) attn_metadata = AscendTorchairMetadata( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, block_tables=block_table, query_lens=0, 
query_start_loc=query_start_loc, @@ -219,46 +208,43 @@ def build_torchair_graph_dummy( return attn_metadata def build(self, - num_reqs, - num_actual_tokens, - max_query_len, - graph_pad_size: int = -1, - enable_dbo_across_dp: bool = False, - *args, - **kwargs): - - device = self.runner.device - - block_table = self.runner.input_batch.block_table[0].get_device_tensor( - ) - block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module,): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + + block_table = common_attn_metadata.block_table_tensor + block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = ( block_table[:num_reqs]) - query_lens = self.runner.query_lens - seq_lens = self.runner.seq_lens_cpu[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True) - attn_mask = self.runner.attn_mask + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to( + self.device, non_blocking=True) + attn_mask = common_attn_metadata.attn_mask - attn_state = self.runner.attn_state + attn_state = common_attn_metadata.attn_state if is_310p() and attn_state == AscendAttentionState.PrefillNoCache: mask_nz = nd_to_nz_2d(attn_mask) attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] + query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to( + # device, non_blocking=True).long() + + input_positions = common_attn_metadata.positions[:num_actual_tokens].long() decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size > -1 - if self.runner.attn_state in [ + if common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, ]: max_seq_lens = seq_lens.max().item() num_seqs = len(seq_lens) - if use_torchair_graph and self.runner.attn_state in [ + if use_torchair_graph and common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, ]: num_reqs_pad_size = 0 @@ -267,7 +253,7 @@ def build(self, pad_value = 0 num_token_pad_size = graph_pad_size - num_actual_tokens num_reqs_pad_size = ( - graph_pad_size // self.runner.decode_token_per_req - + graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs) pad_value = 1 padded_seq_lens = seq_lens.tolist() + [pad_value @@ -308,11 +294,11 @@ def build(self, query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, - max_query_len=max_query_len, + max_query_len=common_attn_metadata.max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=enable_dbo_across_dp) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) return attn_metadata diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a52d117b43..80345e187e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -4,11 +4,12 @@ import numpy as 
np import torch import torch_npu +import torch.nn as nn from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import get_current_vllm_config +from vllm.config import get_current_vllm_config, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -24,6 +25,8 @@ from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills) + if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -172,20 +175,23 @@ class AscendMLAMetadataBuilder: # _attn_mask_builder = None def __init__(self, - runner, + vllm_config: VllmConfig, + device: torch.device, metadata_cls: Optional[AscendMLAMetadata] = None): self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \ if metadata_cls is not None else AscendMLAMetadata # type: ignore - self.runner = runner - scheduler_config = runner.scheduler_config - model_config = runner.model_config - self.block_size = runner.block_size - self.chunked_prefill_enabled = runner.chunked_prefill_enabled + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( # Max sure there is enough for 8 full length request or at least # 4 pages of cache per request - max(8 * model_config.max_model_len, + max(8 * self.model_config.max_model_len, 4 * scheduler_config.max_num_seqs * self.block_size), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, @@ -200,13 +206,13 @@ def __init__(self, scheduler_config.max_num_seqs * self.block_size self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, - model_config.get_head_size()), - dtype=model_config.dtype, - device=runner.device, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, ) ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.cos_cache = None self.sin_cache = None @@ -220,8 +226,6 @@ def reorder_batch(self, input_batch: "InputBatch", # better naming here) decodes = [] prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] @@ -231,18 +235,14 @@ def reorder_batch(self, input_batch: "InputBatch", if self.torchair_graph_enabled: if num_tokens - num_spec_tokens == 1: decodes.append(i) - num_decode_tokens += num_tokens else: prefills.append(i) - num_prefill_tokens += num_tokens # For eager mode we treat spec decoding as chunked prefill. 
else: if num_tokens == 1: decodes.append(i) - num_decode_tokens += num_tokens else: prefills.append(i) - num_prefill_tokens += num_tokens # We hope that this is fairly minimal since decodes # should be around for a number of iterations so hopefully they are @@ -273,49 +273,27 @@ def reorder_batch(self, input_batch: "InputBatch", # Save for next `build` call # TODO(lucas): this is a bit of a hack, we should probably have a # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - return modified_batch def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: - - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}" - - if isinstance(self.runner.graph_block_tables, np.ndarray): - graph_block_tables = torch.zeros((max_batch_size, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - else: - graph_block_tables = self.runner.graph_block_tables.to( - device=block_tables.device, dtype=block_tables.dtype) - num_blocks = block_tables.size(1) - if num_blocks <= max_blocks: - graph_block_tables[:num_seqs, : - num_blocks] = block_tables[:num_seqs, : - num_blocks] + if num_blocks <= self.max_blocks: + return block_tables[:num_seqs, :num_blocks] else: - graph_block_tables[:num_seqs, : - max_blocks] = block_tables[:num_seqs, : - max_blocks] - - return graph_block_tables[:num_seqs, :max_blocks] + return block_tables[:num_seqs, :self.max_blocks] def build_torchair_graph_dummy( - self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata: - device = self.runner.device + self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs _, max_blocks = self.runner.graph_block_tables.shape block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32, device=device) block_table = self._get_graph_runner_block_tables( num_reqs, block_table) - num_tokens = num_reqs * self.runner.decode_token_per_req + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) seq_lens_list = [0] * num_reqs input_positions = torch.zeros(num_tokens, @@ -333,16 +311,16 @@ def build_torchair_graph_dummy( 1, 1, self.rope_dim, - dtype=self.runner.dtype, + dtype=self.model_config.dtype, device=device) cos = torch.ones(num_tokens, 1, 1, self.rope_dim, - dtype=self.runner.dtype, + dtype=self.model_config.dtype, device=device) - if self.runner.speculative_config is not None and\ - self.runner.speculative_config.method == 'deepseek_mtp': + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': attn_state = AscendAttentionState.SpecDecoding num_decode_tokens = 2 else: @@ -354,20 +332,20 @@ def build_torchair_graph_dummy( seq_lens=seq_lens, seq_lens_list=seq_lens_list, max_seq_lens=1, - attn_mask=self.runner.spec_attn_mask, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs], + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=common_attn_metadata.actual_seq_lengths_q[:num_reqs], sin=sin, cos=cos, ) return self.metadata_cls( # type: ignore - num_input_tokens=num_actual_tokens, - num_actual_tokens=num_actual_tokens, + 
num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, slot_mapping=slot_mapping, - head_dim=self.runner.model_config.get_head_size(), + head_dim=self.model_config.get_head_size(), num_decodes=1, num_decode_tokens=num_decode_tokens, num_prefills=0, - attn_mask=self.runner.attn_mask, + attn_mask=common_attn_metadata.attn_mask, attn_state=attn_state, prefill=None, decode=decode_metadata, @@ -378,58 +356,56 @@ def build_torchair_graph_dummy( def build( self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, - graph_pad_size: int = -1, - query_start_loc: torch.Tensor = None, - enable_dbo_across_dp: bool = False, - *args, - **kwargs, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, ) -> AscendMLAMetadata: - assert self._num_decodes + self._num_prefills == num_reqs + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata) + assert num_decodes + num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. - device = self.runner.device + device = self.device - block_table = (self.runner.input_batch.block_table[0]. - get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True) - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() - - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[: - num_reqs] - seq_lens = seq_lens_cpu - max_query_len = query_lens.max().item() - max_seq_lens = seq_lens.max().item() + # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to( + # device, non_blocking=True).long() + + input_positions = common_attn_metadata.positions[:num_actual_tokens].long() + if self.cos_cache is None: - self.cos_cache = self.runner.get_model( - ).model.layers[0].self_attn.rotary_emb.cos_cached - self.sin_cache = self.runner.get_model( - ).model.layers[0].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.runner.dtype: # type: ignore + self.cos_cache = model.layers[0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.layers[0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore self.cos_cache = self.cos_cache.to( # type: ignore - self.runner.dtype) # type: ignore + self.model_config.dtype) # type: ignore self.sin_cache = self.sin_cache.to( # type: ignore - self.runner.dtype) # type: ignore - + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - + query_seq_lens_cpu) + + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] prefill_metadata = None chunked_context_metadata = None - if self._num_prefills > 0: - reqs_start = self._num_decodes # prefill_start - 
tokens_start = self._num_decode_tokens + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens max_query_len = query_lens[tokens_start:].max().item() max_seq_lens = seq_lens[tokens_start:].max().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] - context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[ - reqs_start:num_reqs] + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() if self.chunked_prefill_enabled and max_context_len_cpu > 0: @@ -470,7 +446,7 @@ def build( prefill_input_positions].unsqueeze( # type: ignore 1).unsqueeze(2) prefill_metadata = AscendMLAPrefillMetadata( - attn_mask=self.runner.attn_mask, + attn_mask=common_attn_metadata.attn_mask, query_lens=query_lens[tokens_start:], seq_lens=seq_lens, context_lens=seq_lens[tokens_start:], @@ -485,14 +461,15 @@ def build( ) decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size != -1 - if self._num_decodes > 0: + if num_decodes > 0: actual_seq_lengths_q = query_start_loc[1:].tolist() - max_seq_lens = seq_lens[:self._num_decodes].max().item() - seq_lens = seq_lens[:self._num_decode_tokens] - input_positions = input_positions[:self._num_decode_tokens] - block_table = block_table[:self._num_decode_tokens, ...] - if use_torchair_graph and self.runner.attn_state in [ + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decode_tokens] + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decode_tokens, ...] + if use_torchair_graph and common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ]: @@ -500,9 +477,9 @@ def build( num_token_pad_size = 0 if graph_pad_size != 0: pad_value = 0 - num_token_pad_size = graph_pad_size - self._num_decode_tokens + num_token_pad_size = graph_pad_size - num_decode_tokens num_reqs_pad_size = ( - graph_pad_size // self.runner.decode_token_per_req - + graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs) padded_seq_lens = seq_lens.tolist( ) + [pad_value] * num_reqs_pad_size @@ -531,14 +508,14 @@ def build( input_positions = torch.cat( [input_positions, position_padding]) actual_seq_lengths_q = query_start_loc[1:].tolist( - ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs + + ) + common_attn_metadata.actual_seq_lengths_q[num_reqs:num_reqs + num_reqs_pad_size] else: seq_lens_list = seq_lens.tolist() # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) batch_size = slot_mapping.size(0) if actual_seq_lengths_q[-1] != batch_size \ - and self.runner.attn_state == AscendAttentionState.SpecDecoding: + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: actual_seq_lengths_q[-1] = batch_size cos = self.cos_cache[input_positions].unsqueeze( # type: ignore @@ -552,7 +529,7 @@ def build( seq_lens=seq_lens, seq_lens_list=seq_lens_list, max_seq_lens=max_seq_lens, - attn_mask=self.runner.spec_attn_mask, + attn_mask=common_attn_metadata.spec_attn_mask, actual_seq_lengths_q=actual_seq_lengths_q, sin=sin, cos=cos) @@ -561,18 +538,18 @@ def build( num_actual_tokens=num_actual_tokens, query_lens=query_lens.tolist(), slot_mapping=slot_mapping, - head_dim=self.runner.model_config.get_head_size(), - num_decodes=self._num_decodes, - 
num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, - attn_mask=self.runner.attn_mask, - attn_state=self.runner.attn_state, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, prefill=prefill_metadata, decode=decode_metadata, query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - enable_dbo_across_dp=enable_dbo_across_dp, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py new file mode 100644 index 0000000000..769280a396 --- /dev/null +++ b/vllm_ascend/attention/utils.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass + +from vllm_ascend.attention.attention_v1 import AscendAttentionState + +import torch + + +@dataclass +class AscendCommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. + """ + + query_start_loc: torch.Tensor = None + query_start_loc_cpu: torch.Tensor = None + """(batch_size + 1,), the start location of each request in query Tensor""" + + seq_lens: torch.Tensor = None + seq_lens_cpu: torch.Tensor = None + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" + + num_reqs: int + """Number of requests""" + num_actual_tokens: int + """Total number of tokens in batch""" + max_query_len: int + """Longest query in batch""" + + actual_seq_lengths_q: list[int] = None + + block_table_tensor: torch.Tensor = None + slot_mapping_cpu: torch.Tensor = None + + positions: torch.Tensor = None + + attn_mask: torch.Tensor = None + spec_attn_mask: torch.Tensor = None + attn_state: AscendAttentionState = None + + decode_token_per_req: int + + max_num_blocks_per_req: int + + enable_dbo_across_dp: bool = False + + is_only_prefill: bool + + graph_pad_size: int = -1 + +@dataclass +class TorchairCommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. + """ + + num_reqs: int + """Number of requests""" + num_actual_tokens: int + """Total number of tokens in batch""" + + actual_seq_lengths_q: list[int] = None + + attn_mask: torch.Tensor = None + spec_attn_mask: torch.Tensor = None + + decode_token_per_req: int + + graph_pad_size: int = -1 + + +def split_decodes_and_prefills( + common_attn_metadata: AscendCommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: AscendCommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. 
+ """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, num_tokens, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, num_tokens, 0 + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > decode_threshold) + assert torch.all(query_lens[:first_prefill] <= decode_threshold) + num_decodes = first_prefill + num_prefills = num_reqs - num_decodes + num_decode_tokens = query_start_loc[first_prefill].item() + num_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index f42f83d158..c57b935c9d 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -31,6 +31,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, maybe_converting_weight_acl_format) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata class NPUTorchairModelRunner(NPUModelRunner): @@ -69,8 +70,16 @@ def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): # NOTE: If torchair graph mode and not with_prefill, # we can't skip_attn, it will cause graph recompile. if not with_prefill: - attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( - num_reqs=num_reqs, num_actual_tokens=1) + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.actual_seq_lengths_q, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + decode_token_per_req=self.decode_token_per_req, + ) + attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) else: attn_metadata = super()._build_attention_metadata( with_prefill, num_reqs, skip_attn) diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index 18fb9fda8d..ea910d846b 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -16,6 +16,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata PADDING_SLOT_ID = -1 @@ -125,12 +126,26 @@ def propose( query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() - # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
- attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:batch_size + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1], + seq_lens=self.runner.seq_lens, + seq_lens_cpu=self.runner.seq_lens_cpu, num_reqs=batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), + slot_mapping_cpu=self.runner.slot_mapping_cpu, + positions=self.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + max_num_blocks_per_req=self.runner.max_num_blocks_per_req, ) + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata) if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ebf76ebff9..efd8ec3afd 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -92,6 +92,7 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] @@ -798,11 +799,25 @@ def get_eagle_atten_dict( # in the same group share the same metadata. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - attn_metadata_i = self.attn_metadata_builder.build( + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens, + seq_lens_cpu=self.seq_lens_cpu, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + block_table_tensor=self.input_batch.block_table[0].get_device_tensor(), + slot_mapping_cpu=self.slot_mapping_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + decode_token_per_req=self.decode_token_per_req, + max_num_blocks_per_req=self.max_num_blocks_per_req, ) + attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1192,9 +1207,6 @@ def _process_reqs( attn_state, total_num_scheduled_tokens) - enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), - attn_state, - total_num_scheduled_tokens) (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, enable_dbo) = self._get_forward_metadata_across_dp_and_pad( total_num_scheduled_tokens, with_prefill, enable_dbo) @@ -1207,24 +1219,30 @@ def _process_reqs( 'graph_pad_size'] = self.graph_pad_size # type: ignore else: self.graph_pad_size = -1 - + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens, + seq_lens_cpu=self.seq_lens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + block_table_tensor=self.input_batch.block_table[0].get_device_tensor(), + slot_mapping_cpu=self.slot_mapping_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + decode_token_per_req=self.decode_token_per_req, + max_num_blocks_per_req=self.max_num_blocks_per_req, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=is_only_prefill, + graph_pad_size=self.graph_pad_size + ) + attn_metadata = self.attn_metadata_builder.build(common_attn_metadata) if self.vllm_config.model_config.use_mla: - extra_builder_kwargs[ - "query_start_loc"] = self.query_start_loc[:num_reqs + 1] - attn_metadata = self.attn_metadata_builder.build( # type: ignore - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - **extra_builder_kwargs, - ) attn_metadata.num_input_tokens = num_input_tokens - else: - attn_metadata = self.attn_metadata_builder.build( # type: ignore - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - **extra_builder_kwargs, - ) # Prepare input_ids token_indices = (positions_np + diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index f4597de23f..00eab740c4 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -18,6 +18,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.utils import ProfileExecuteDuration 
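A side note on the calling convention this series moves to: the attention metadata builders no longer hold a reference to the model runner; instead the runner snapshots per-batch state into a common metadata object and passes it to build(). The mock below only pictures that flow — DummyCommonMetadata and DummyBuilder are made-up stand-ins, not classes from this patch:

    from dataclasses import dataclass

    import torch

    @dataclass
    class DummyCommonMetadata:
        # Stand-in for AscendCommonAttentionMetadata: per-batch state the
        # runner already owns, captured once per scheduler step.
        num_reqs: int
        num_actual_tokens: int
        max_query_len: int
        query_start_loc_cpu: torch.Tensor

    class DummyBuilder:
        # Stand-in for the refactored builders: constructed from static
        # config plus a device, with no back-reference to the runner.
        def __init__(self, device: torch.device):
            self.device = device

        def build(self, common: DummyCommonMetadata) -> dict:
            query_lens = (common.query_start_loc_cpu[1:] -
                          common.query_start_loc_cpu[:-1])
            return {"num_reqs": common.num_reqs, "query_lens": query_lens}

    common = DummyCommonMetadata(
        num_reqs=2,
        num_actual_tokens=6,
        max_query_len=4,
        query_start_loc_cpu=torch.tensor([0, 2, 6], dtype=torch.int32),
    )
    print(DummyBuilder(torch.device("cpu")).build(common)["query_lens"])
    # tensor([2, 4], dtype=torch.int32)

Decoupling the builders from the runner this way is what lets the eagle/MTP proposers and the torchair dummy-graph path reuse the same build() and build_torchair_graph_dummy() entry points later in the series.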
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, TorchairCommonAttentionMetadata class MtpProposer: @@ -166,12 +167,26 @@ def propose( else: num_input_tokens = num_tokens - attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:batch_size + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1], + seq_lens=target_positions[last_token_indices] + 1, + seq_lens_cpu=target_positions.cpu()[last_token_indices] + 1, num_reqs=batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, - query_start_loc=cu_num_tokens, - **extra_builder_kwargs) + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), + slot_mapping_cpu=target_slot_mapping, + positions=target_positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + max_num_blocks_per_req=self.runner.max_num_blocks_per_req, + graph_pad_size=extra_builder_kwargs['graph_pad_size'] + ) + attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata) self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -281,8 +296,15 @@ def dummy_run(self, if skip_attn: attn_metadata = None else: - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - num_reqs=num_reqs, num_actual_tokens=1) + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + decode_token_per_req=self.runner.decode_token_per_req, + ) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) input_ids = self.input_ids[:num_tokens] positions = self.positions[:num_tokens] From 22fa7d94c5914f775bd4cf36b577a23414ccd89b Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Thu, 14 Aug 2025 19:37:55 +0800 Subject: [PATCH 02/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/attention_v1.py | 6 +++-- .../attention/attention_v1_torchair.py | 10 ++++--- vllm_ascend/attention/mla_v1.py | 5 ++-- vllm_ascend/attention/utils.py | 27 +++++++++++-------- vllm_ascend/worker/model_runner_v1.py | 10 +++---- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 77a8e86f67..458be09e46 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -27,7 +27,7 @@ AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, cdiv from vllm.v1.core.sched.output import SchedulerOutput from vllm_ascend.ops.attention import vanilla_chunked_prefill @@ -166,6 +166,8 @@ def __init__(self, self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device + self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, + vllm_config.cache_config.block_size) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -179,7 +181,7 @@ 
def build(self, query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] block_table = common_attn_metadata.block_table_tensor - block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = ( + block_table[:num_reqs, :self.max_num_blocks_per_req] = ( block_table[:num_reqs]) query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 8517865bb2..9f6d6b2cbf 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -23,6 +23,7 @@ import torch_npu import torch.nn as nn from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState @@ -32,7 +33,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d) from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req class AscendAttentionTorchairBackend(AttentionBackend): @@ -154,6 +155,9 @@ def __init__(self, self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device + self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, + vllm_config.cache_config.block_size) + self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -214,7 +218,7 @@ def build(self, num_actual_tokens = common_attn_metadata.num_actual_tokens block_table = common_attn_metadata.block_table_tensor - block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = ( + block_table[:num_reqs, :self.max_num_blocks_per_req] = ( block_table[:num_reqs]) seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] @@ -253,7 +257,7 @@ def build(self, pad_value = 0 num_token_pad_size = graph_pad_size - num_actual_tokens num_reqs_pad_size = ( - graph_pad_size // common_attn_metadata.decode_token_per_req - + graph_pad_size // self.decode_token_per_req - num_reqs) pad_value = 1 padded_seq_lens = seq_lens.tolist() + [pad_value diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 80345e187e..2a0fe96315 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -25,7 +25,7 @@ from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills) +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills, get_decode_token_per_req) if TYPE_CHECKING: @@ -186,6 +186,7 @@ def __init__(self, scheduler_config = vllm_config.scheduler_config self.block_size = vllm_config.cache_config.block_size self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size + self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config) self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -293,7 +294,7 @@ def 
build_torchair_graph_dummy( device=device) block_table = self._get_graph_runner_block_tables( num_reqs, block_table) - num_tokens = num_reqs * common_attn_metadata.decode_token_per_req + num_tokens = num_reqs * self.decode_token_per_req seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) seq_lens_list = [0] * num_reqs input_positions = torch.zeros(num_tokens, diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 769280a396..47b472709c 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,5 +1,6 @@ from dataclasses import dataclass +from vllm.config import SpeculativeConfig from vllm_ascend.attention.attention_v1 import AscendAttentionState import torch @@ -14,12 +15,11 @@ class AscendCommonAttentionMetadata: For many of the tensors we keep both GPU and CPU versions. """ - query_start_loc: torch.Tensor = None - query_start_loc_cpu: torch.Tensor = None + query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" - seq_lens: torch.Tensor = None - seq_lens_cpu: torch.Tensor = None + seq_lens_cpu: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" @@ -27,13 +27,11 @@ class AscendCommonAttentionMetadata: """Number of requests""" num_actual_tokens: int """Total number of tokens in batch""" - max_query_len: int - """Longest query in batch""" actual_seq_lengths_q: list[int] = None - block_table_tensor: torch.Tensor = None - slot_mapping_cpu: torch.Tensor = None + block_table_tensor: torch.Tensor + slot_mapping_cpu: torch.Tensor positions: torch.Tensor = None @@ -47,7 +45,7 @@ class AscendCommonAttentionMetadata: enable_dbo_across_dp: bool = False - is_only_prefill: bool + is_only_prefill: bool = False graph_pad_size: int = -1 @@ -70,8 +68,6 @@ class TorchairCommonAttentionMetadata: attn_mask: torch.Tensor = None spec_attn_mask: torch.Tensor = None - decode_token_per_req: int - graph_pad_size: int = -1 @@ -115,3 +111,12 @@ def split_decodes_and_prefills( num_decode_tokens = query_start_loc[first_prefill].item() num_prefill_tokens = num_tokens - num_decode_tokens return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) + + +def get_decode_token_per_req(speculative_config: SpeculativeConfig): + decode_token_per_req = 1 + if not speculative_config: + return decode_token_per_req + spec_token_num = speculative_config.num_speculative_tokens + assert spec_token_num > 0 + return decode_token_per_req + spec_token_num diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index efd8ec3afd..51529a5524 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -92,7 +92,7 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] @@ -216,7 +216,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): use_mla=self.model_config.use_mla, ) self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) + vllm_config, device) self.attn_mask_builder = AttentionMaskBuilder( 
min(self.model_config.max_model_len, int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype) @@ -229,13 +229,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None self.actual_seq_lengths_q = [] - self.spec_token_num = 0 - self.decode_token_per_req = 1 + self.decode_token_per_req = get_decode_token_per_req(self.speculative_config) if self.speculative_config: self.use_spec_decode = True - self.spec_token_num = self.speculative_config.num_speculative_tokens - assert self.spec_token_num > 0 - self.decode_token_per_req = 1 + self.spec_token_num self.actual_seq_lengths_q = [ len for len in range(self.decode_token_per_req, self.max_num_tokens + From 8047b7e0245d6566a1cf03c55036ae427b9e4f68 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Thu, 14 Aug 2025 20:13:16 +0800 Subject: [PATCH 03/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/utils.py | 6 ++---- vllm_ascend/torchair/torchair_model_runner.py | 1 - vllm_ascend/worker/eagle_proposer_v1.py | 9 +++------ vllm_ascend/worker/model_runner_v1.py | 8 +------- vllm_ascend/worker/mtp_proposer_v1.py | 4 ---- 5 files changed, 6 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 47b472709c..a0c84e4739 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -39,10 +39,8 @@ class AscendCommonAttentionMetadata: spec_attn_mask: torch.Tensor = None attn_state: AscendAttentionState = None - decode_token_per_req: int - - max_num_blocks_per_req: int - + max_query_len: int + enable_dbo_across_dp: bool = False is_only_prefill: bool = False diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index c57b935c9d..d92f9ea1e9 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -77,7 +77,6 @@ def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - decode_token_per_req=self.decode_token_per_req, ) attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) else: diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index ea910d846b..b1448c887a 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -129,20 +129,17 @@ def propose( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.runner.query_start_loc[:batch_size + 1], query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1], - seq_lens=self.runner.seq_lens, seq_lens_cpu=self.runner.seq_lens_cpu, + max_query_len=max_query_len, num_reqs=batch_size, num_actual_tokens=num_tokens, - max_query_len=max_query_len, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), - slot_mapping_cpu=self.runner.slot_mapping_cpu, - positions=self.positions, + slot_mapping_cpu=target_slot_mapping, + positions=target_positions, attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - max_num_blocks_per_req=self.runner.max_num_blocks_per_req, ) # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
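For reference, the per-request decode token count used throughout this series is one scheduled token plus any speculative (draft) tokens. A simplified standalone version of that rule — the function name and the example values are illustrative, not from the patch:

    from typing import Optional

    def decode_tokens_per_request(num_speculative_tokens: Optional[int]) -> int:
        # One token for the next position, plus one per speculative (draft)
        # token when speculative decoding is configured.
        if not num_speculative_tokens:
            return 1
        assert num_speculative_tokens > 0
        return 1 + num_speculative_tokens

    print(decode_tokens_per_request(None))  # 1 -> plain decoding
    print(decode_tokens_per_request(1))     # 2 -> e.g. one deepseek_mtp draft token

With one draft token per request this gives two decode tokens per request, which lines up with the hard-coded num_decode_tokens = 2 in the MLA torchair dummy path for the deepseek_mtp case.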
attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 51529a5524..0127f5a0c0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -798,11 +798,10 @@ def get_eagle_atten_dict( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens, seq_lens_cpu=self.seq_lens_cpu, num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + num_actual_tokens=total_num_scheduled_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, block_table_tensor=self.input_batch.block_table[0].get_device_tensor(), slot_mapping_cpu=self.slot_mapping_cpu, @@ -810,7 +809,6 @@ def get_eagle_atten_dict( attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - decode_token_per_req=self.decode_token_per_req, max_num_blocks_per_req=self.max_num_blocks_per_req, ) attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata) @@ -1218,11 +1216,9 @@ def _process_reqs( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens, seq_lens_cpu=self.seq_lens_cpu, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, block_table_tensor=self.input_batch.block_table[0].get_device_tensor(), slot_mapping_cpu=self.slot_mapping_cpu, @@ -1230,8 +1226,6 @@ def _process_reqs( attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - decode_token_per_req=self.decode_token_per_req, - max_num_blocks_per_req=self.max_num_blocks_per_req, enable_dbo_across_dp=enable_dbo, is_only_prefill=is_only_prefill, graph_pad_size=self.graph_pad_size diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 00eab740c4..840d7de022 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -170,7 +170,6 @@ def propose( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.runner.query_start_loc[:batch_size + 1], query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1], - seq_lens=target_positions[last_token_indices] + 1, seq_lens_cpu=target_positions.cpu()[last_token_indices] + 1, num_reqs=batch_size, num_actual_tokens=num_tokens, @@ -182,8 +181,6 @@ def propose( attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - max_num_blocks_per_req=self.runner.max_num_blocks_per_req, graph_pad_size=extra_builder_kwargs['graph_pad_size'] ) attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata) @@ -302,7 +299,6 @@ def dummy_run(self, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, - decode_token_per_req=self.runner.decode_token_per_req, ) attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) From 677ac78a9b6f4eff1da00024a6d6f2b5d9ab47b8 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Fri, 15 Aug 2025 18:40:11 +0800 Subject: [PATCH 04/22] refact attn 
metadata build Signed-off-by: weiguihua2 --- .../attention/attention_v1_torchair.py | 8 +++---- vllm_ascend/attention/mla_v1.py | 20 ++++++++--------- vllm_ascend/attention/utils.py | 22 ++++++++----------- vllm_ascend/worker/eagle_proposer_v1.py | 3 ++- vllm_ascend/worker/model_runner_v1.py | 16 +++++++++----- vllm_ascend/worker/mtp_proposer_v1.py | 18 ++++++++------- 6 files changed, 45 insertions(+), 42 deletions(-) diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 9f6d6b2cbf..1df834fb99 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -33,7 +33,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d) from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class AscendAttentionTorchairBackend(AttentionBackend): @@ -157,7 +157,7 @@ def __init__(self, self.device = device self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, vllm_config.cache_config.block_size) - self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config) + self.max_blocks = (self.model_config.max_model_len + vllm_config.cache_config.block_size - 1) // vllm_config.cache_config.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -175,7 +175,7 @@ def build_torchair_graph_dummy( self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs - _, max_blocks = self.runner.graph_block_tables.shape + _, max_blocks = self.max_blocks block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32, device=device) @@ -257,7 +257,7 @@ def build(self, pad_value = 0 num_token_pad_size = graph_pad_size - num_actual_tokens num_reqs_pad_size = ( - graph_pad_size // self.decode_token_per_req - + graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs) pad_value = 1 padded_seq_lens = seq_lens.tolist() + [pad_value diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 2a0fe96315..338c07e190 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -25,7 +25,7 @@ from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills, get_decode_token_per_req) +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills) if TYPE_CHECKING: @@ -186,7 +186,6 @@ def __init__(self, scheduler_config = vllm_config.scheduler_config self.block_size = vllm_config.cache_config.block_size self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size - self.decode_token_per_req = get_decode_token_per_req(vllm_config.speculative_config) self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -288,13 +287,13 @@ def build_torchair_graph_dummy( self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs - _, max_blocks = 
self.runner.graph_block_tables.shape + _, max_blocks = self.max_blocks block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32, device=device) block_table = self._get_graph_runner_block_tables( num_reqs, block_table) - num_tokens = num_reqs * self.decode_token_per_req + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) seq_lens_list = [0] * num_reqs input_positions = torch.zeros(num_tokens, @@ -382,8 +381,8 @@ def build( input_positions = common_attn_metadata.positions[:num_actual_tokens].long() if self.cos_cache is None: - self.cos_cache = model.layers[0].self_attn.rotary_emb.cos_cached - self.sin_cache = model.layers[0].self_attn.rotary_emb.sin_cached + self.cos_cache = model.model.layers[0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[0].self_attn.rotary_emb.sin_cached if self.cos_cache.dtype != self.model_config.dtype: # type: ignore self.cos_cache = self.cos_cache.to( # type: ignore self.model_config.dtype) # type: ignore @@ -392,10 +391,9 @@ def build( query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] query_lens = query_seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - - query_seq_lens_cpu) - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + prefill_metadata = None chunked_context_metadata = None if num_prefills > 0: @@ -418,12 +416,12 @@ def build( assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) cu_seq_lens_cpu = torch.zeros(num_chunks, - self._num_prefills + 1, + num_prefills + 1, dtype=torch.int32, pin_memory=True) torch.cumsum(chunk_seq_lens, diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index a0c84e4739..35dd5f21fa 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,7 +1,7 @@ from dataclasses import dataclass +from enum import Enum from vllm.config import SpeculativeConfig -from vllm_ascend.attention.attention_v1 import AscendAttentionState import torch @@ -28,18 +28,20 @@ class AscendCommonAttentionMetadata: num_actual_tokens: int """Total number of tokens in batch""" - actual_seq_lengths_q: list[int] = None + max_query_len: int + + decode_token_per_req: int block_table_tensor: torch.Tensor slot_mapping_cpu: torch.Tensor + actual_seq_lengths_q: list[int] = None + positions: torch.Tensor = None attn_mask: torch.Tensor = None spec_attn_mask: torch.Tensor = None - attn_state: AscendAttentionState = None - - max_query_len: int + attn_state: Enum = None enable_dbo_across_dp: bool = False @@ -61,6 +63,8 @@ class TorchairCommonAttentionMetadata: num_actual_tokens: int """Total number of tokens in batch""" + decode_token_per_req: int + actual_seq_lengths_q: list[int] = None attn_mask: torch.Tensor = None @@ -110,11 +114,3 @@ def split_decodes_and_prefills( num_prefill_tokens = num_tokens - num_decode_tokens return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) - -def get_decode_token_per_req(speculative_config: SpeculativeConfig): - decode_token_per_req = 1 - if not speculative_config: - return 
decode_token_per_req - spec_token_num = speculative_config.num_speculative_tokens - assert spec_token_num > 0 - return decode_token_per_req + spec_token_num diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index b1448c887a..6cf6427f03 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -140,9 +140,10 @@ def propose( attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, ) # FIXME(woosuk): The below two ops cause synchronization. Optimize. - attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata) + attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model) if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0127f5a0c0..a6c74a1c95 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -92,7 +92,7 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, get_decode_token_per_req +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] @@ -229,9 +229,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None self.actual_seq_lengths_q = [] - self.decode_token_per_req = get_decode_token_per_req(self.speculative_config) + self.decode_token_per_req = 1 if self.speculative_config: self.use_spec_decode = True + spec_token_num = self.speculative_config.num_speculative_tokens + assert spec_token_num > 0 + self.decode_token_per_req = 1 + spec_token_num self.actual_seq_lengths_q = [ len for len in range(self.decode_token_per_req, self.max_num_tokens + @@ -810,8 +813,9 @@ def get_eagle_atten_dict( spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, max_num_blocks_per_req=self.max_num_blocks_per_req, + decode_token_per_req=self.decode_token_per_req, ) - attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata) + attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.model) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1228,9 +1232,11 @@ def _process_reqs( attn_state=self.attn_state, enable_dbo_across_dp=enable_dbo, is_only_prefill=is_only_prefill, - graph_pad_size=self.graph_pad_size + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, ) - attn_metadata = self.attn_metadata_builder.build(common_attn_metadata) + attn_metadata = self.attn_metadata_builder.build(common_attn_metadata, self.model) if self.vllm_config.model_config.use_mla: attn_metadata.num_input_tokens = num_input_tokens diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 840d7de022..80eed56f6b 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -181,9 +181,10 @@ def propose( 
attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, - graph_pad_size=extra_builder_kwargs['graph_pad_size'] + graph_pad_size=extra_builder_kwargs['graph_pad_size'], + decode_token_per_req=self.runner.decode_token_per_req, ) - attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata) + attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model) self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -294,12 +295,13 @@ def dummy_run(self, attn_metadata = None else: common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=num_reqs, - num_actual_tokens=1, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - ) + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + decode_token_per_req=self.runner.decode_token_per_req, + ) attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) input_ids = self.input_ids[:num_tokens] From 52f99e6857ba0dc7b8b5c9b66f5a66895f34d323 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Fri, 15 Aug 2025 19:04:26 +0800 Subject: [PATCH 05/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/attention_v1_torchair.py | 3 +-- vllm_ascend/attention/mla_v1.py | 3 +-- vllm_ascend/torchair/torchair_model_runner.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 1df834fb99..c1619c8908 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -175,8 +175,7 @@ def build_torchair_graph_dummy( self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs - _, max_blocks = self.max_blocks - block_table = torch.zeros((num_reqs, max_blocks), + block_table = torch.zeros((num_reqs, self.max_blocks), dtype=torch.int32, device=device) block_table = self._get_graph_runner_block_tables( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 338c07e190..05a0ac7091 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -287,8 +287,7 @@ def build_torchair_graph_dummy( self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs - _, max_blocks = self.max_blocks - block_table = torch.zeros((num_reqs, max_blocks), + block_table = torch.zeros((num_reqs, self.max_blocks), dtype=torch.int32, device=device) block_table = self._get_graph_runner_block_tables( diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index d92f9ea1e9..e0b99c7e82 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -76,7 +76,7 @@ def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): actual_seq_lengths_q=self.actual_seq_lengths_q, attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, + decode_token_per_req=self.decode_token_per_req, ) attn_metadata = 
self.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) else: From ef3b644d86a3f80736a229cc81f73e3bb38094f0 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Sat, 16 Aug 2025 18:06:30 +0800 Subject: [PATCH 06/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/utils.py | 2 +- vllm_ascend/worker/eagle_proposer_v1.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 2 +- vllm_ascend/worker/mtp_proposer_v1.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 35dd5f21fa..9db1696b23 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -106,7 +106,7 @@ def split_decodes_and_prefills( return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] > decode_threshold) + assert torch.all(query_lens[first_prefill:] >= decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index 6cf6427f03..dcf224a0cd 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -128,7 +128,7 @@ def propose( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.runner.query_start_loc[:batch_size + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1], seq_lens_cpu=self.runner.seq_lens_cpu, max_query_len=max_query_len, num_reqs=batch_size, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a6c74a1c95..221c8b4227 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -815,7 +815,7 @@ def get_eagle_atten_dict( max_num_blocks_per_req=self.max_num_blocks_per_req, decode_token_per_req=self.decode_token_per_req, ) - attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.model) + attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.get_model()) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 80eed56f6b..6f28c78bb1 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -168,9 +168,9 @@ def propose( num_input_tokens = num_tokens common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc[:batch_size + 1], - query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1], - seq_lens_cpu=target_positions.cpu()[last_token_indices] + 1, + query_start_loc=cu_num_tokens[:batch_size + 1], + query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), + seq_lens_cpu=seq_lens.cpu(), num_reqs=batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, @@ -184,7 +184,7 @@ def propose( graph_pad_size=extra_builder_kwargs['graph_pad_size'], decode_token_per_req=self.runner.decode_token_per_req, ) - attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model) + attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.get_model()) self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states From 
e28207fe36e4d8a5b2580e12321f98c2bb73e896 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Sun, 17 Aug 2025 22:36:53 +0800 Subject: [PATCH 07/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/attention_v1.py | 35 ++++++---- .../attention/attention_v1_torchair.py | 54 ++++++++++------ vllm_ascend/attention/mla_v1.py | 64 ++++++++++++------- vllm_ascend/attention/utils.py | 17 +++-- vllm_ascend/torchair/torchair_model_runner.py | 5 +- vllm_ascend/worker/eagle_proposer_v1.py | 9 ++- vllm_ascend/worker/model_runner_v1.py | 21 +++--- vllm_ascend/worker/mtp_proposer_v1.py | 24 +++---- 8 files changed, 138 insertions(+), 91 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 458be09e46..87d698509e 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -20,21 +20,21 @@ from typing import List, Optional, Tuple, Type import torch -import torch_npu import torch.nn as nn -from vllm.config import VllmConfig +import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import direct_register_custom_op, cdiv +from vllm.utils import cdiv, direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class AscendAttentionBackend(AttentionBackend): @@ -160,9 +160,11 @@ class AscendMetadata: class AscendAttentionMetadataBuilder: - def __init__(self, + def __init__( + self, vllm_config: VllmConfig, - device: torch.device,): + device: torch.device, + ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device @@ -173,12 +175,16 @@ def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False - def build(self, + def build( + self, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module,): + model: nn.Module, + ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] block_table = common_attn_metadata.block_table_tensor block_table[:num_reqs, :self.max_num_blocks_per_req] = ( @@ -186,11 +192,16 @@ def build(self, query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to( - self.device, non_blocking=True) + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + self.device, + non_blocking= + True) attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] query_start_loc = 
query_start_loc_cpu.to(self.device, non_blocking=True) diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index c1619c8908..ec9523f2dd 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -20,20 +20,20 @@ import numpy as np import torch -import torch_npu import torch.nn as nn -from vllm.config import VllmConfig -from vllm.utils import cdiv +import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState +from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.v1.core.sched.output import SchedulerOutput from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d) from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class AscendAttentionTorchairBackend(AttentionBackend): @@ -149,15 +149,19 @@ class AscendTorchairMetadata: class AscendAttentionTorchairMetadataBuilder: - def __init__(self, + def __init__( + self, vllm_config: VllmConfig, - device: torch.device,): + device: torch.device, + ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, vllm_config.cache_config.block_size) - self.max_blocks = (self.model_config.max_model_len + vllm_config.cache_config.block_size - 1) // vllm_config.cache_config.block_size + self.max_blocks = (self.model_config.max_model_len + + vllm_config.cache_config.block_size - + 1) // vllm_config.cache_config.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -166,13 +170,12 @@ def reorder_batch(self, input_batch: "InputBatch", def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: num_blocks = block_tables.size(1) - if num_blocks <= self.max_blocks: - return block_tables[:num_seqs, :num_blocks] - else: - return block_tables[:num_seqs, :self.max_blocks] + num_blocks = min(num_blocks, self.max_blocks) + return block_tables[:num_seqs, :num_blocks] def build_torchair_graph_dummy( - self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata: + self, common_attn_metadata: AscendCommonAttentionMetadata + ) -> AscendTorchairMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs block_table = torch.zeros((num_reqs, self.max_blocks), @@ -210,9 +213,11 @@ def build_torchair_graph_dummy( decode=decode_metadata) return attn_metadata - def build(self, + def build( + self, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module,): + model: nn.Module, + ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -221,8 +226,11 @@ def build(self, block_table[:num_reqs]) seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to( - self.device, non_blocking=True) + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + self.device, + non_blocking= + True) attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state @@ -230,14 
+238,18 @@ def build(self, mask_nz = nd_to_nz_2d(attn_mask) attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to( # device, non_blocking=True).long() - - input_positions = common_attn_metadata.positions[:num_actual_tokens].long() + + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) decode_metadata = None graph_pad_size = common_attn_metadata.graph_pad_size @@ -256,8 +268,8 @@ def build(self, pad_value = 0 num_token_pad_size = graph_pad_size - num_actual_tokens num_reqs_pad_size = ( - graph_pad_size // common_attn_metadata.decode_token_per_req - - num_reqs) + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) pad_value = 1 padded_seq_lens = seq_lens.tolist() + [pad_value ] * num_reqs_pad_size diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 05a0ac7091..29d5cd812d 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -3,13 +3,13 @@ import numpy as np import torch -import torch_npu import torch.nn as nn +import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import get_current_vllm_config, VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -18,6 +18,8 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn @@ -25,8 +27,6 @@ from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,split_decodes_and_prefills) - if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -185,7 +185,8 @@ def __init__(self, self.device = device scheduler_config = vllm_config.scheduler_config self.block_size = vllm_config.cache_config.block_size - self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -278,13 +279,13 @@ def reorder_batch(self, input_batch: "InputBatch", def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: num_blocks = block_tables.size(1) - if num_blocks <= self.max_blocks: - return block_tables[:num_seqs, 
:num_blocks] - else: - return block_tables[:num_seqs, :self.max_blocks] + num_blocks = min(num_blocks, self.max_blocks) + return block_tables[:num_seqs, :num_blocks] def build_torchair_graph_dummy( - self, common_attn_metadata: AscendCommonAttentionMetadata,) -> AscendMLAMetadata: + self, + common_attn_metadata: AscendCommonAttentionMetadata, + ) -> AscendMLAMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs block_table = torch.zeros((num_reqs, self.max_blocks), @@ -332,7 +333,8 @@ def build_torchair_graph_dummy( seq_lens_list=seq_lens_list, max_seq_lens=1, attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=common_attn_metadata.actual_seq_lengths_q[:num_reqs], + actual_seq_lengths_q=common_attn_metadata. + actual_seq_lengths_q[:num_reqs], sin=sin, cos=cos, ) @@ -362,9 +364,18 @@ def build( num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because @@ -372,16 +383,23 @@ def build( device = self.device block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True) + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + device, + non_blocking= + True) # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to( # device, non_blocking=True).long() - - input_positions = common_attn_metadata.positions[:num_actual_tokens].long() + + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) if self.cos_cache is None: - self.cos_cache = model.model.layers[0].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[0].self_attn.rotary_emb.sin_cached + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached if self.cos_cache.dtype != self.model_config.dtype: # type: ignore self.cos_cache = self.cos_cache.to( # type: ignore self.model_config.dtype) # type: ignore @@ -392,7 +410,7 @@ def build( query_lens = query_seq_lens_cpu[:num_reqs] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] num_computed_tokens_cpu = (seq_lens - query_lens) - + prefill_metadata = None chunked_context_metadata = None if num_prefills > 0: @@ -477,8 +495,8 @@ def build( pad_value = 0 num_token_pad_size = graph_pad_size - num_decode_tokens num_reqs_pad_size = ( - graph_pad_size // common_attn_metadata.decode_token_per_req - - num_reqs) + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) padded_seq_lens = seq_lens.tolist( ) + [pad_value] * num_reqs_pad_size else: @@ -506,8 +524,8 
@@ def build( input_positions = torch.cat( [input_positions, position_padding]) actual_seq_lengths_q = query_start_loc[1:].tolist( - ) + common_attn_metadata.actual_seq_lengths_q[num_reqs:num_reqs + - num_reqs_pad_size] + ) + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] else: seq_lens_list = seq_lens.tolist() # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 9db1696b23..e56c98a28a 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,7 +1,5 @@ from dataclasses import dataclass -from enum import Enum - -from vllm.config import SpeculativeConfig +from typing import Any import torch @@ -29,10 +27,13 @@ class AscendCommonAttentionMetadata: """Total number of tokens in batch""" max_query_len: int + """Max token number of request in batch""" decode_token_per_req: int + """decode token number per request""" block_table_tensor: torch.Tensor + slot_mapping_cpu: torch.Tensor actual_seq_lengths_q: list[int] = None @@ -40,15 +41,18 @@ class AscendCommonAttentionMetadata: positions: torch.Tensor = None attn_mask: torch.Tensor = None + spec_attn_mask: torch.Tensor = None - attn_state: Enum = None - + + attn_state: Any = None + enable_dbo_across_dp: bool = False is_only_prefill: bool = False graph_pad_size: int = -1 + @dataclass class TorchairCommonAttentionMetadata: """ @@ -60,6 +64,7 @@ class TorchairCommonAttentionMetadata: num_reqs: int """Number of requests""" + num_actual_tokens: int """Total number of tokens in batch""" @@ -68,6 +73,7 @@ class TorchairCommonAttentionMetadata: actual_seq_lengths_q: list[int] = None attn_mask: torch.Tensor = None + spec_attn_mask: torch.Tensor = None graph_pad_size: int = -1 @@ -113,4 +119,3 @@ def split_decodes_and_prefills( num_decode_tokens = query_start_loc[first_prefill].item() num_prefill_tokens = num_tokens - num_decode_tokens return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) - diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index e0b99c7e82..9e6d0fbe90 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -25,13 +25,13 @@ from vllm.forward_context import get_forward_context from vllm.logger import logger +from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata from vllm_ascend.platform import NPUPlatform from vllm_ascend.torchair.utils import (check_torchair_cache_exist, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, maybe_converting_weight_acl_format) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner -from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata class NPUTorchairModelRunner(NPUModelRunner): @@ -78,7 +78,8 @@ def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): spec_attn_mask=self.spec_attn_mask, decode_token_per_req=self.decode_token_per_req, ) - attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) + attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) else: attn_metadata = super()._build_attention_metadata( with_prefill, num_reqs, skip_attn) diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index dcf224a0cd..895649327c 100644 --- 
a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -128,13 +128,15 @@ def propose( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.runner.query_start_loc[:batch_size + 1], - query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + + 1], seq_lens_cpu=self.runner.seq_lens_cpu, max_query_len=max_query_len, num_reqs=batch_size, num_actual_tokens=num_tokens, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), slot_mapping_cpu=target_slot_mapping, positions=target_positions, attn_mask=self.runner.attn_mask, @@ -143,7 +145,8 @@ def propose( decode_token_per_req=self.runner.decode_token_per_req, ) # FIXME(woosuk): The below two ops cause synchronization. Optimize. - attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.model) + attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata, self.runner.model) if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 221c8b4227..d45a09e12b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -23,7 +23,6 @@ import os import time import types -import weakref from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast @@ -80,6 +79,7 @@ AscendMetadata) from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, DummyCommImpl, MoECommMethod) @@ -92,7 +92,6 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] @@ -806,7 +805,8 @@ def get_eagle_atten_dict( max_query_len=max_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=self.input_batch.block_table[0].get_device_tensor(), + block_table_tensor=self.input_batch.block_table[0]. 
+ get_device_tensor(), slot_mapping_cpu=self.slot_mapping_cpu, positions=self.positions, attn_mask=self.attn_mask, @@ -815,7 +815,8 @@ def get_eagle_atten_dict( max_num_blocks_per_req=self.max_num_blocks_per_req, decode_token_per_req=self.decode_token_per_req, ) - attn_metadata_i = self.attn_metadata_builder.build(common_attn_metadata, self.get_model()) + attn_metadata_i = self.attn_metadata_builder.build( + common_attn_metadata, self.get_model()) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1181,8 +1182,6 @@ def _process_reqs( attn_state=attn_state) self.attn_state = attn_state # type: ignore - extra_builder_kwargs = {} - self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc[:num_reqs + 1].copy_( @@ -1199,7 +1198,6 @@ def _process_reqs( ] is_only_prefill = bool(np.all(num_valid_tokens != 1)) - extra_builder_kwargs['is_only_prefill'] = is_only_prefill enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, @@ -1208,13 +1206,10 @@ def _process_reqs( (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, enable_dbo) = self._get_forward_metadata_across_dp_and_pad( total_num_scheduled_tokens, with_prefill, enable_dbo) - extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp if self.torchair_graph_enabled and not with_prefill: self.graph_pad_size = padded_num_tokens_across_dp - extra_builder_kwargs[ - 'graph_pad_size'] = self.graph_pad_size # type: ignore else: self.graph_pad_size = -1 common_attn_metadata = AscendCommonAttentionMetadata( @@ -1224,7 +1219,8 @@ def _process_reqs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=self.input_batch.block_table[0].get_device_tensor(), + block_table_tensor=self.input_batch.block_table[0]. + get_device_tensor(), slot_mapping_cpu=self.slot_mapping_cpu, positions=self.positions, attn_mask=self.attn_mask, @@ -1236,7 +1232,8 @@ def _process_reqs( graph_pad_size=self.graph_pad_size, decode_token_per_req=self.decode_token_per_req, ) - attn_metadata = self.attn_metadata_builder.build(common_attn_metadata, self.model) + attn_metadata = self.attn_metadata_builder.build( + common_attn_metadata, self.model) if self.vllm_config.model_config.use_mla: attn_metadata.num_input_tokens = num_input_tokens diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 6f28c78bb1..07599fc8ce 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -16,9 +16,10 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + TorchairCommonAttentionMetadata) from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.utils import ProfileExecuteDuration -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, TorchairCommonAttentionMetadata class MtpProposer: @@ -89,7 +90,7 @@ def prepare_inputs( # FIXME(woosuk): Avoid synchronization. 
num_tokens = cu_num_tokens[-1].item() - token_indices = torch.empty( + token_indices = torch.zeros( num_tokens, dtype=torch.int32, device=cu_num_tokens.device, @@ -137,9 +138,6 @@ def propose( # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] if token_indices is not None and self.runner.torchair_graph_enabled: last_token_indices = token_indices - else: - seq_lens = target_positions[last_token_indices] + 1 - seq_lens = seq_lens.cpu() self.input_ids[last_token_indices] = next_token_ids @@ -156,17 +154,16 @@ def propose( # input_batch=self.runner.input_batch, # scheduler_output=self.runner.scheduler_output, # ) - extra_builder_kwargs = {} - is_running_torchair = self.runner.torchair_graph_enabled and \ not self.runner.with_prefill if is_running_torchair: - extra_builder_kwargs['graph_pad_size'] = self.runner.graph_pad_size num_input_tokens = self.runner.graph_pad_size else: num_input_tokens = num_tokens + seq_lens = target_positions[last_token_indices] + 1 + seq_lens = seq_lens.int() common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=cu_num_tokens[:batch_size + 1], query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), @@ -175,16 +172,18 @@ def propose( num_actual_tokens=num_tokens, max_query_len=max_query_len, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), slot_mapping_cpu=target_slot_mapping, positions=target_positions, attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, - graph_pad_size=extra_builder_kwargs['graph_pad_size'], + graph_pad_size=self.runner.graph_pad_size, decode_token_per_req=self.runner.decode_token_per_req, ) - attn_metadata = self.runner.attn_metadata_builder.build(common_attn_metadata, self.runner.get_model()) + attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata, self.runner.get_model()) self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -302,7 +301,8 @@ def dummy_run(self, spec_attn_mask=self.runner.spec_attn_mask, decode_token_per_req=self.runner.decode_token_per_req, ) - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(common_attn_metadata) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) input_ids = self.input_ids[:num_tokens] positions = self.positions[:num_tokens] From 2acd1b93fe8e1121612566223444f7524faa5dab Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Sun, 17 Aug 2025 22:47:39 +0800 Subject: [PATCH 08/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/utils.py | 6 +++--- vllm_ascend/worker/model_runner_v1.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index e56c98a28a..2c2390d088 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Optional import torch @@ -36,7 +36,7 @@ class AscendCommonAttentionMetadata: slot_mapping_cpu: torch.Tensor - actual_seq_lengths_q: list[int] = None + actual_seq_lengths_q: Optional[list[int]] = None positions: torch.Tensor = None @@ -70,7 +70,7 @@ class TorchairCommonAttentionMetadata: decode_token_per_req: int - actual_seq_lengths_q: list[int] = None + 
actual_seq_lengths_q: Optional[list[int]] = None attn_mask: torch.Tensor = None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d45a09e12b..e0d9ddf755 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -812,7 +812,6 @@ def get_eagle_atten_dict( attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - max_num_blocks_per_req=self.max_num_blocks_per_req, decode_token_per_req=self.decode_token_per_req, ) attn_metadata_i = self.attn_metadata_builder.build( From 68071ea3e31470c7030b4db4d6a91b5d4f871c40 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Sun, 17 Aug 2025 22:59:51 +0800 Subject: [PATCH 09/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 2c2390d088..fc93c26427 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -36,7 +36,7 @@ class AscendCommonAttentionMetadata: slot_mapping_cpu: torch.Tensor - actual_seq_lengths_q: Optional[list[int]] = None + actual_seq_lengths_q: Optional[list[int]] positions: torch.Tensor = None From 2a926bbce3041ce656337a53f4af89127b8889ad Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Sun, 17 Aug 2025 23:07:37 +0800 Subject: [PATCH 10/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index fc93c26427..281310bc29 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import torch @@ -36,7 +36,7 @@ class AscendCommonAttentionMetadata: slot_mapping_cpu: torch.Tensor - actual_seq_lengths_q: Optional[list[int]] + actual_seq_lengths_q: list[int] positions: torch.Tensor = None @@ -70,7 +70,7 @@ class TorchairCommonAttentionMetadata: decode_token_per_req: int - actual_seq_lengths_q: Optional[list[int]] = None + actual_seq_lengths_q: list[int] attn_mask: torch.Tensor = None From 35cb714faa824794129a68492d0e1b3487e236b8 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 12:27:33 +0800 Subject: [PATCH 11/22] refact attn metadata build Signed-off-by: weiguihua2 --- tests/ut/attention/test_attention_v1.py | 115 +++++++------- tests/ut/attention/test_mla_v1.py | 146 ++++++++++-------- .../attention/attention_v1_torchair.py | 23 ++- vllm_ascend/attention/mla_v1.py | 21 ++- 4 files changed, 182 insertions(+), 123 deletions(-) diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index e8fe7ab6bf..36499fc584 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -9,6 +9,7 @@ AscendAttentionState, AscendMetadata, CommonAttentionState) +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class TestAscendAttentionBackend(TestBase): @@ -67,8 +68,11 @@ def test_copy_blocks(self): class TestAscendAttentionMetadataBuilder(TestBase): def setUp(self): - self.mock_runner = MagicMock() - self.builder = AscendAttentionMetadataBuilder(self.mock_runner) + self.mock_vllm_config = MagicMock() + self.mock_vllm_config.model_config.max_model_len = 640 + self.mock_vllm_config.cache_config.block_size = 64 + self.mock_device = 'cpu:0' + 
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, self.mock_device) def test_reorder_batch(self): mock_input_batch = MagicMock() @@ -86,30 +90,31 @@ def test_reorder_batch(self): def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, mock_npu_format_cast, mock_ascend_metadata): - num_reqs = 2 - num_actual_tokens = 10 - max_query_len = 5 - - self.mock_runner.input_batch.block_table = [MagicMock()] - self.mock_runner.input_batch.block_table[ - 0].get_device_tensor.return_value = torch.zeros((10, 10)) - self.mock_runner.max_num_blocks_per_req = 10 - self.mock_runner.query_lens = torch.tensor([3, 4]) - self.mock_runner.seq_lens_cpu = torch.tensor([5, 6]) - self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) - self.mock_runner.device = 'cpu:0' - self.mock_runner.attn_mask = torch.ones((10, 10)) - self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache - self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7]) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 3, 7]), + query_start_loc_cpu=torch.tensor([0, 3, 7]), + seq_lens_cpu=torch.tensor([5, 6]), + num_reqs=2, + num_actual_tokens=10, + max_query_len=5, + decode_token_per_req=torch.tensor([1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1]), + positions=torch.tensor([10, 10]), + attn_mask=torch.ones((10, 10)), + spec_attn_mask=None, + attn_state=AscendAttentionState.PrefillNoCache + ) mock_nz_tensor = MagicMock() + mock_model = MagicMock() mock_nd_to_nz_2d.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor self.builder.build( - num_reqs, - num_actual_tokens, - max_query_len, + common_attn_metadata, + mock_model ) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @@ -120,51 +125,55 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, def test_build_chunked_prefill(self, mock_ascend_attention_state, mock_is_310p, mock_nd_to_nz_spec, mock_npu_format_cast, mock_ascend_metadata): - num_reqs = 3 - num_actual_tokens = 15 - max_query_len = 6 - - self.mock_runner.input_batch.block_table = [MagicMock()] - self.mock_runner.input_batch.block_table[ - 0].get_device_tensor.return_value = torch.zeros((10, 10)) - self.mock_runner.max_num_blocks_per_req = 10 - self.mock_runner.query_lens = torch.tensor([2, 3, 4]) - self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) - self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) - self.mock_runner.device = 'cpu:0' - self.mock_runner.attn_mask = torch.ones((15, 15)) - self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill - self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 2, 5, 9]), + query_start_loc_cpu=torch.tensor([0, 2, 5, 9]), + seq_lens_cpu=torch.tensor([4, 5, 6]), + num_reqs=3, + num_actual_tokens=15, + max_query_len=6, + decode_token_per_req=torch.tensor([1, 1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1, 2]), + positions=torch.tensor([10, 10]), + attn_mask=torch.ones((15, 15)), + spec_attn_mask=None, + attn_state=AscendAttentionState.ChunkedPrefill + ) mock_ascend_attention_state = MagicMock() mock_ascend_attention_state.PrefillNoCache = 0 mock_nz_tensor = MagicMock() + mock_model = MagicMock() mock_nd_to_nz_spec.return_value = mock_nz_tensor 
mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build(num_reqs, num_actual_tokens, max_query_len) + self.builder.build(common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata): - num_reqs = 3 - num_actual_tokens = 15 - max_query_len = 6 - - self.mock_runner.input_batch.block_table = [MagicMock()] - self.mock_runner.input_batch.block_table[ - 0].get_device_tensor.return_value = torch.zeros((10, 10)) - self.mock_runner.max_num_blocks_per_req = 10 - self.mock_runner.query_lens = torch.tensor([2, 3, 4]) - self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) - self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) - self.mock_runner.device = 'cpu:0' - self.mock_runner.attn_mask = torch.ones((15, 15)) - self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill - self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) - - self.builder.build(num_reqs, num_actual_tokens, max_query_len) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 2, 5, 9]), + query_start_loc_cpu=torch.tensor([0, 2, 5, 9]), + seq_lens_cpu=torch.tensor([4, 5, 6]), + num_reqs=3, + num_actual_tokens=15, + max_query_len=6, + decode_token_per_req=torch.tensor([1, 1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1, 2]), + positions=torch.tensor([10, 10]), + attn_mask=torch.ones((15, 15)), + spec_attn_mask=None, + attn_state=AscendAttentionState.ChunkedPrefill + ) + mock_model = MagicMock() + + self.builder.build(common_attn_metadata, mock_model) class TestAscendAttentionBackendImpl(TestBase): diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 497b7b53ab..4cb896ed91 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -12,6 +12,7 @@ AscendMLAImpl, AscendMLAMetadata, AscendMLAMetadataBuilder, AscendMLAPrefillMetadata) +from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata class TestAscendMLABackend(TestBase): @@ -178,40 +179,39 @@ def test_ascend_mla_metadata_default(self): class TestAscendMLAMetadataBuilder(TestBase): def test_ascend_mla_metadata_builder_default(self): - runner = MagicMock() - runner.scheduler_config = MagicMock() - runner.model_config = MagicMock() - runner.scheduler_config.max_num_seqs = 4 - runner.model_config.max_model_len = 1024 - runner.model_config.get_head_size.return_value = 64 - runner.model_config.dtype = torch.float16 - runner.chunked_prefill_enabled = False - runner.device = "cpu" - runner.block_size = 16 - runner.decode_token_per_req = 1 + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' ascend_config = MagicMock() ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(runner) + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - 
self.assertEqual(builder.runner, runner) - self.assertEqual(builder.block_size, runner.block_size) + self.assertEqual(builder.block_size, mock_vllm_config.cache_config.block_size) self.assertEqual(builder.chunked_prefill_enabled, - runner.chunked_prefill_enabled) + mock_vllm_config.scheduler_config.chunked_prefill_enabled) self.assertEqual(builder.torchair_graph_enabled, True) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def test_reorder_batch_with_torchair_graph(self, ascend_config): - runner = MagicMock() - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True - builder = AscendMLAMetadataBuilder(runner) + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) input_batch = MagicMock() input_batch.req_ids = [0, 1, 2, 3] @@ -230,22 +230,23 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config): modified = builder.reorder_batch(input_batch, scheduler_output) self.assertFalse(modified) - self.assertEqual(builder._num_decodes, 4) - self.assertEqual(builder._num_prefills, 0) - self.assertEqual(builder._num_decode_tokens, 7) - self.assertEqual(builder._num_prefill_tokens, 0) input_batch.swap_states.assert_not_called() def test_reorder_batch_without_torchair_graph(self): ascend_config = MagicMock() - runner = MagicMock() - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = False + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(runner) + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) input_batch = MagicMock() input_batch.req_ids = [0, 1, 2, 3] @@ -264,10 +265,6 @@ def test_reorder_batch_without_torchair_graph(self): modified = builder.reorder_batch(input_batch, scheduler_output) self.assertTrue(modified) - self.assertEqual(builder._num_decodes, 2) - self.assertEqual(builder._num_prefills, 2) - self.assertEqual(builder._num_decode_tokens, 2) - self.assertEqual(builder._num_prefill_tokens, 5) input_batch.swap_states.assert_called_once_with(1, 2) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") @@ -275,11 +272,13 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32) - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 - builder = AscendMLAMetadataBuilder(runner=runner) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = 
AscendMLAMetadataBuilder(mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) result = builder._get_graph_runner_block_tables(3, block_tables) @@ -292,11 +291,13 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32) - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 - builder = AscendMLAMetadataBuilder(runner=runner) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 64 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) result = builder._get_graph_runner_block_tables(3, block_tables) @@ -310,11 +311,13 @@ def test_get_graph_runner_block_tables_from_numpy(self, ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32) - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 - builder = AscendMLAMetadataBuilder(runner=runner) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -329,38 +332,55 @@ def test_build_dummy(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.model_config = MagicMock() - runner.device = "cpu" - runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32) - runner.model_config.get_head_size.return_value = 64 - runner.chunked_prefill_enabled = False - runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool) - runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool) - runner.dtype = torch.float16 - runner.decode_token_per_req = 1 - - builder = AscendMLAMetadataBuilder(runner=runner, + # runner = MagicMock() + # runner.model_config = MagicMock() + # runner.device = "cpu" + # runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32) + # runner.model_config.get_head_size.return_value = 64 + # runner.chunked_prefill_enabled = False + # runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool) + # runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool) + # runner.dtype = torch.float16 + # runner.decode_token_per_req = 1 + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_vllm_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, + mock_device, metadata_cls=AscendMLAMetadata) builder.rope_dim = 64 with patch.object(builder, "_get_graph_runner_block_tables", side_effect=lambda x, y: y): - metadata = 
builder.build_torchair_graph_dummy(3, 3) + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=3, + num_actual_tokens=3, + decode_token_per_req=1, + actual_seq_lengths_q=[0,1,2], + attn_mask=torch.zeros((1, 1), dtype=torch.bool), + spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool), + ) + metadata = builder.build_torchair_graph_dummy(common_attn_metadata) sin_golden = torch.ones(3, 1, 1, 64, - dtype=runner.dtype, - device=runner.device) + dtype=torch.float16, + device=mock_device) cos_golden = torch.ones(3, 1, 1, 64, - dtype=runner.dtype, - device=runner.device) + dtype=torch.float16, + device=mock_device) self.assertIsInstance(metadata, AscendMLAMetadata) self.assertEqual(metadata.num_input_tokens, 3) diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index ec9523f2dd..4688739f2e 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -30,7 +30,8 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + TorchairCommonAttentionMetadata) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d) from vllm_ascend.worker.npu_input_batch import InputBatch @@ -169,12 +170,26 @@ def reorder_batch(self, input_batch: "InputBatch", def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + num_blocks = block_tables.size(1) - num_blocks = min(num_blocks, self.max_blocks) - return block_tables[:num_seqs, :num_blocks] + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] def build_torchair_graph_dummy( - self, common_attn_metadata: AscendCommonAttentionMetadata + self, common_attn_metadata: TorchairCommonAttentionMetadata ) -> AscendTorchairMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 29d5cd812d..6f242dfcee 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -19,6 +19,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + TorchairCommonAttentionMetadata, split_decodes_and_prefills) from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context @@ -278,13 +279,27 @@ def reorder_batch(self, input_batch: "InputBatch", def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + num_blocks = block_tables.size(1) - num_blocks = min(num_blocks, self.max_blocks) - return block_tables[:num_seqs, :num_blocks] + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + 
num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] def build_torchair_graph_dummy( self, - common_attn_metadata: AscendCommonAttentionMetadata, + common_attn_metadata: TorchairCommonAttentionMetadata, ) -> AscendMLAMetadata: device = self.device num_reqs = common_attn_metadata.num_reqs From 06aa682416ae306c61f6795c55d6e1b08168c2a5 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 13:03:24 +0800 Subject: [PATCH 12/22] refact attn metadata build Signed-off-by: weiguihua2 --- tests/ut/attention/test_attention_v1.py | 17 ++++++----------- tests/ut/attention/test_mla_v1.py | 11 ++++++----- vllm_ascend/attention/attention_v1_torchair.py | 4 ++-- vllm_ascend/attention/mla_v1.py | 4 ++-- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 36499fc584..ab593414ef 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -72,7 +72,8 @@ def setUp(self): self.mock_vllm_config.model_config.max_model_len = 640 self.mock_vllm_config.cache_config.block_size = 64 self.mock_device = 'cpu:0' - self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, self.mock_device) + self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, + self.mock_device) def test_reorder_batch(self): mock_input_batch = MagicMock() @@ -104,18 +105,14 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, positions=torch.tensor([10, 10]), attn_mask=torch.ones((10, 10)), spec_attn_mask=None, - attn_state=AscendAttentionState.PrefillNoCache - ) + attn_state=AscendAttentionState.PrefillNoCache) mock_nz_tensor = MagicMock() mock_model = MagicMock() mock_nd_to_nz_2d.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build( - common_attn_metadata, - mock_model - ) + self.builder.build(common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('torch_npu.npu_format_cast') @@ -139,8 +136,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state, positions=torch.tensor([10, 10]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill - ) + attn_state=AscendAttentionState.ChunkedPrefill) mock_ascend_attention_state = MagicMock() mock_ascend_attention_state.PrefillNoCache = 0 @@ -169,8 +165,7 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata): positions=torch.tensor([10, 10]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill - ) + attn_state=AscendAttentionState.ChunkedPrefill) mock_model = MagicMock() self.builder.build(common_attn_metadata, mock_model) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 4cb896ed91..ba0f4b6d0c 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -1,6 +1,5 @@ from unittest.mock import MagicMock, patch -import numpy as np import torch from vllm.distributed.parallel_state import GroupCoordinator from vllm.model_executor.layers.linear import LinearBase @@ -195,9 +194,11 @@ def test_ascend_mla_metadata_builder_default(self): return_value=ascend_config): builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - self.assertEqual(builder.block_size, 
mock_vllm_config.cache_config.block_size) - self.assertEqual(builder.chunked_prefill_enabled, - mock_vllm_config.scheduler_config.chunked_prefill_enabled) + self.assertEqual(builder.block_size, + mock_vllm_config.cache_config.block_size) + self.assertEqual( + builder.chunked_prefill_enabled, + mock_vllm_config.scheduler_config.chunked_prefill_enabled) self.assertEqual(builder.torchair_graph_enabled, True) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") @@ -363,7 +364,7 @@ def test_build_dummy(self, mock_ascend_config): num_reqs=3, num_actual_tokens=3, decode_token_per_req=1, - actual_seq_lengths_q=[0,1,2], + actual_seq_lengths_q=[0, 1, 2], attn_mask=torch.zeros((1, 1), dtype=torch.bool), spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool), ) diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 4688739f2e..ec71eaac79 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -173,8 +173,8 @@ def _get_graph_runner_block_tables( max_blocks = self.max_blocks graph_block_tables = torch.zeros((num_seqs, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) + dtype=block_tables.dtype, + device=block_tables.device) num_blocks = block_tables.size(1) if num_blocks <= max_blocks: diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 6f242dfcee..ef9670f568 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -282,8 +282,8 @@ def _get_graph_runner_block_tables( max_blocks = self.max_blocks graph_block_tables = torch.zeros((num_seqs, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) + dtype=block_tables.dtype, + device=block_tables.device) num_blocks = block_tables.size(1) if num_blocks <= max_blocks: From 0e3dc3a45a83bc69f46668d588e37a0d323b51cf Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 15:40:18 +0800 Subject: [PATCH 13/22] refact model runner Signed-off-by: weiguihua2 --- vllm_ascend/torchair/torchair_model_runner.py | 135 +++++++++++++++- vllm_ascend/worker/model_runner_v1.py | 144 ++++-------------- 2 files changed, 161 insertions(+), 118 deletions(-) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 9e6d0fbe90..ad2f8951e4 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -18,21 +18,26 @@ # from typing import Optional +import types import torch import torch_npu +import torch.nn as nn from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import logger +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata from vllm_ascend.platform import NPUPlatform from vllm_ascend.torchair.utils import (check_torchair_cache_exist, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - maybe_converting_weight_acl_format) + maybe_converting_weight_acl_format, is_310p) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +import vllm.envs as envs_vllm + class NPUTorchairModelRunner(NPUModelRunner): @@ -45,7 +50,7 @@ def _get_forward_metadata_across_dp_and_pad( """Override from NPUModelRunner to pad num_tokens""" if self.dp_size == 1: if not with_prefill: - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + maybe_padded_num_tokens = 
self._select_torchair_padded_batch_size( num_tokens) return maybe_padded_num_tokens, None, with_prefill, enable_dbo return num_tokens, None, with_prefill, enable_dbo @@ -55,7 +60,7 @@ def _get_forward_metadata_across_dp_and_pad( if not with_prefill: max_num_token = num_tokens_across_dp.max().item() - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + maybe_padded_num_tokens = self._select_torchair_padded_batch_size( max_num_token) num_tokens_across_dp = torch.full((self.dp_size, ), maybe_padded_num_tokens, @@ -178,3 +183,127 @@ def _capture_model(self): if self.new_kv_cache_bytes > 0: write_kv_cache_bytes_to_file(torch.distributed.get_rank(), self.new_kv_cache_bytes) + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + if not with_prefill: + self.graph_pad_size = graph_pad_size + else: + super()._update_graph_pad_size(with_prefill, graph_pad_size) + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + """Override from NPUModelRunner to update input_ids and positions""" + input_ids, positions = super()._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) + + if not with_prefill: + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = self.positions[:padded_num_tokens_across_dp] + return input_ids, positions + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): + model_kwargs = {"kv_caches": self.kv_caches, "attn_metadata": attn_metadata} + if not with_prefill: + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_NZ) + + compiled_model = self._get_torchair_lazy_compiled_model( + padded_num_tokens_across_dp) + hidden_states = compiled_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + else: + assert self.model is not None + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_ND) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + return hidden_states + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + + if is_310p(): + # on 300I Duo platform, we need to patch broadcast. however, this patch will be + # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. + from vllm_ascend.patch.platform.patch_common.patch_distributed import \ + communication_adaptation_310p + communication_adaptation_310p() + + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to + # disable it on 300I Duo platform now. 
+ config.experimental_config.tiling_schedule_optimize = not is_310p() + config.experimental_config.enable_view_optimize = \ + get_ascend_config().torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] + + def _select_torchair_padded_batch_size(self, batch_size: int): + for padded_batch_size in self.torchair_graph_batch_sizes: + if batch_size <= padded_batch_size: + # we treat batch_size as num of requests + return padded_batch_size + raise ValueError( + f"cur batch_size is invalid, torchair_graph_batch_sizes is " + f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." + ) \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e0d9ddf755..b5f445f6da 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -100,7 +100,6 @@ xgr = LazyLoader("xgr", globals(), "xgrammar") import torch_npu -import vllm.envs as envs_vllm import vllm_ascend.envs as envs_ascend @@ -1207,10 +1206,9 @@ def _process_reqs( total_num_scheduled_tokens, with_prefill, enable_dbo) self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp - if self.torchair_graph_enabled and not with_prefill: - self.graph_pad_size = padded_num_tokens_across_dp - else: - self.graph_pad_size = -1 + + self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp) + common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], @@ -1278,12 +1276,8 @@ def _process_reqs( # then the embedding layer is not included in the ACL graph. 
input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - if self.torchair_graph_enabled and not with_prefill: - input_ids = self.input_ids[:padded_num_tokens_across_dp] - positions = self.positions[:padded_num_tokens_across_dp] + positions, input_ids = self._update_input_ids_and_positions(input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1321,35 +1315,8 @@ def _process_reqs( num_actual_tokens=total_num_scheduled_tokens): with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) - model_kwargs = {} - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.kv_caches - model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled and not with_prefill: - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_NZ) - - compiled_model = self._get_torchair_lazy_compiled_model( - padded_num_tokens_across_dp) - hidden_states = compiled_model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - else: - assert self.model is not None - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) - - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) + + self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -1384,6 +1351,29 @@ def _process_reqs( return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, num_scheduled_tokens, finished_sending, finished_recving) + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + self.graph_pad_size = -1 + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + return input_ids, positions + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): + assert self.model is not None + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_ND) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states def _get_cumsum_and_arange( self, @@ -2118,72 +2108,6 @@ def load_model(self) -> None: logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! 
max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - - patch_for_hcom() - - if is_310p(): - # on 300I Duo platform, we need to patch broadcast. however, this patch will be - # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. - from vllm_ascend.patch.platform.patch_common.patch_distributed import \ - communication_adaptation_310p - communication_adaptation_310p() - - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to - # disable it on 300I Duo platform now. - config.experimental_config.tiling_schedule_optimize = not is_310p() - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] - def _convert_torch_format(self, tensor): tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) return tensor @@ -2700,16 +2624,6 @@ def init_torchair_graph_batch_sizes(self): self.torchair_graph_batch_sizes.append(start_graph_batch_size) start_graph_batch_size *= 2 - def select_torchair_padded_batch_size(self, batch_size: int): - for padded_batch_size in self.torchair_graph_batch_sizes: - if batch_size <= padded_batch_size: - # we treat batch_size as num of requests - return padded_batch_size - raise ValueError( - f"cur batch_size is invalid, torchair_graph_batch_sizes is " - f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." 
- ) - def check_torchair_graph_batch_sizes(self): # return graph_batch_sizes according to the max number of tokens # first pad according to the number of requests From 3a14afab6e49cf3bab0f1db86b6a0e70f4376556 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 15:42:38 +0800 Subject: [PATCH 14/22] refact model runner Signed-off-by: weiguihua2 --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b5f445f6da..6de72b6810 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1316,7 +1316,7 @@ def _process_reqs( with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) - self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) + hidden_states = self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( From c380320cac4024e3a45342e1efd6de138917d31a Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 16:05:17 +0800 Subject: [PATCH 15/22] refact model runner v1 Signed-off-by: weiguihua2 --- vllm_ascend/torchair/torchair_model_runner.py | 32 +++++++++++-------- vllm_ascend/worker/model_runner_v1.py | 26 +++++++++------ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index ad2f8951e4..ce0e6f0634 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -17,12 +17,13 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # -from typing import Optional import types +from typing import Optional import torch -import torch_npu import torch.nn as nn +import torch_npu +import vllm.envs as envs_vllm from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import logger @@ -33,11 +34,9 @@ from vllm_ascend.torchair.utils import (check_torchair_cache_exist, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - maybe_converting_weight_acl_format, is_310p) + is_310p, maybe_converting_weight_acl_format) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner -import vllm.envs as envs_vllm - class NPUTorchairModelRunner(NPUModelRunner): @@ -202,12 +201,19 @@ def _update_input_ids_and_positions(self, input_ids, positions, input_ids = self.input_ids[:padded_num_tokens_across_dp] positions = self.positions[:padded_num_tokens_across_dp] return input_ids, positions - - def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): - model_kwargs = {"kv_caches": self.kv_caches, "attn_metadata": attn_metadata} + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): + model_kwargs = { + "kv_caches": self.kv_caches, + "attn_metadata": attn_metadata + } if not with_prefill: maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_NZ) + ACL_FORMAT_FRACTAL_NZ) 
compiled_model = self._get_torchair_lazy_compiled_model( padded_num_tokens_across_dp) @@ -221,7 +227,7 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padd else: assert self.model is not None maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) + ACL_FORMAT_FRACTAL_ND) hidden_states = self.model( input_ids=input_ids, @@ -231,7 +237,7 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padd **model_kwargs, ) return hidden_states - + def _get_torchair_lazy_compiled_model(self, batch_size: int): if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: raise ValueError( @@ -297,7 +303,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): config=config, ge_cache=False) return self.torchair_compiled_models[batch_size] - + def _select_torchair_padded_batch_size(self, batch_size: int): for padded_batch_size in self.torchair_graph_batch_sizes: if batch_size <= padded_batch_size: @@ -306,4 +312,4 @@ def _select_torchair_padded_batch_size(self, batch_size: int): raise ValueError( f"cur batch_size is invalid, torchair_graph_batch_sizes is " f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." - ) \ No newline at end of file + ) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6de72b6810..02341fd0a8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -22,7 +22,6 @@ import math import os import time -import types from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast @@ -1277,7 +1276,9 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - positions, input_ids = self._update_input_ids_and_positions(input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) + positions, input_ids = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1315,8 +1316,10 @@ def _process_reqs( num_actual_tokens=total_num_scheduled_tokens): with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) - - hidden_states = self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, with_prefill, padded_num_tokens_across_dp, + input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -1351,21 +1354,24 @@ def _process_reqs( return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, num_scheduled_tokens, finished_sending, finished_recving) - + def _update_graph_pad_size(self, with_prefill, graph_pad_size): self.graph_pad_size = -1 - + def _update_input_ids_and_positions(self, input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp): if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions - - def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): + + def 
_generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): assert self.model is not None - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) + maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) hidden_states = self.model( input_ids=input_ids, From 00c4541e5ccb7b2f40b36b80178f155d3e7a1ad9 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 16:38:58 +0800 Subject: [PATCH 16/22] refact model runner v1 Signed-off-by: weiguihua2 --- vllm_ascend/worker/model_runner_v1.py | 2 +- vllm_ascend/worker/mtp_proposer_v1.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 02341fd0a8..25fb25e691 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1276,7 +1276,7 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - positions, input_ids = self._update_input_ids_and_positions( + input_ids, positions = self._update_input_ids_and_positions( input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 07599fc8ce..5064410218 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -188,11 +188,6 @@ def propose( self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states - if attn_metadata.prefill is not None: - attn_metadata.prefill.query_lens = query_lens.cpu() - attn_metadata.prefill.input_positions = target_positions - attn_metadata.prefill.seq_lens = seq_lens - if not self.runner.torchair_graph_enabled: # torch mode need to update num_tokens_across_dp # TODO: adapt enable_dbo later From 2fbf6d9ebad08546aeff6424f2a4746d45031132 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 17:22:07 +0800 Subject: [PATCH 17/22] refact attn metadata build Signed-off-by: weiguihua2 --- tests/ut/attention/test_mla_v1.py | 10 ---------- vllm_ascend/attention/attention_v1_torchair.py | 3 --- 2 files changed, 13 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index ba0f4b6d0c..ee17ad8fb3 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -333,16 +333,6 @@ def test_build_dummy(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - # runner = MagicMock() - # runner.model_config = MagicMock() - # runner.device = "cpu" - # runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32) - # runner.model_config.get_head_size.return_value = 64 - # runner.chunked_prefill_enabled = False - # runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool) - # runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool) - # runner.dtype = torch.float16 - # runner.decode_token_per_req = 1 mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index ec71eaac79..cec2c73f32 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -259,9 +259,6 @@ def build( query_start_loc = 
query_start_loc_cpu.to(self.device, non_blocking=True) query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to( - # device, non_blocking=True).long() - input_positions = common_attn_metadata.positions[: num_actual_tokens].long( ) From 827c285fc0b905b776b1510b04bea4b14528b57e Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 17:26:42 +0800 Subject: [PATCH 18/22] refact attn metadata build Signed-off-by: weiguihua2 --- vllm_ascend/attention/mla_v1.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ef9670f568..7eb7e28bb6 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -403,9 +403,6 @@ def build( device, non_blocking= True) - # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to( - # device, non_blocking=True).long() - input_positions = common_attn_metadata.positions[: num_actual_tokens].long( ) From e74f0cea57b94a57f4aecb93fd95e57a8673f4da Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 15:40:18 +0800 Subject: [PATCH 19/22] refact model runner Signed-off-by: weiguihua2 --- vllm_ascend/torchair/torchair_model_runner.py | 135 +++++++++++++++- vllm_ascend/worker/model_runner_v1.py | 144 ++++-------------- 2 files changed, 161 insertions(+), 118 deletions(-) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 9e6d0fbe90..ad2f8951e4 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -18,21 +18,26 @@ # from typing import Optional +import types import torch import torch_npu +import torch.nn as nn from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import logger +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata from vllm_ascend.platform import NPUPlatform from vllm_ascend.torchair.utils import (check_torchair_cache_exist, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - maybe_converting_weight_acl_format) + maybe_converting_weight_acl_format, is_310p) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +import vllm.envs as envs_vllm + class NPUTorchairModelRunner(NPUModelRunner): @@ -45,7 +50,7 @@ def _get_forward_metadata_across_dp_and_pad( """Override from NPUModelRunner to pad num_tokens""" if self.dp_size == 1: if not with_prefill: - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + maybe_padded_num_tokens = self._select_torchair_padded_batch_size( num_tokens) return maybe_padded_num_tokens, None, with_prefill, enable_dbo return num_tokens, None, with_prefill, enable_dbo @@ -55,7 +60,7 @@ def _get_forward_metadata_across_dp_and_pad( if not with_prefill: max_num_token = num_tokens_across_dp.max().item() - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + maybe_padded_num_tokens = self._select_torchair_padded_batch_size( max_num_token) num_tokens_across_dp = torch.full((self.dp_size, ), maybe_padded_num_tokens, @@ -178,3 +183,127 @@ def _capture_model(self): if self.new_kv_cache_bytes > 0: write_kv_cache_bytes_to_file(torch.distributed.get_rank(), self.new_kv_cache_bytes) + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + if not with_prefill: + self.graph_pad_size = 
graph_pad_size + else: + super()._update_graph_pad_size(with_prefill, graph_pad_size) + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + """Override from NPUModelRunner to update input_ids and positions""" + input_ids, positions = super()._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) + + if not with_prefill: + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = self.positions[:padded_num_tokens_across_dp] + return input_ids, positions + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): + model_kwargs = {"kv_caches": self.kv_caches, "attn_metadata": attn_metadata} + if not with_prefill: + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_NZ) + + compiled_model = self._get_torchair_lazy_compiled_model( + padded_num_tokens_across_dp) + hidden_states = compiled_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + else: + assert self.model is not None + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_ND) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + return hidden_states + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + + if is_310p(): + # on 300I Duo platform, we need to patch broadcast. however, this patch will be + # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. + from vllm_ascend.patch.platform.patch_common.patch_distributed import \ + communication_adaptation_310p + communication_adaptation_310p() + + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to + # disable it on 300I Duo platform now. 
+ config.experimental_config.tiling_schedule_optimize = not is_310p() + config.experimental_config.enable_view_optimize = \ + get_ascend_config().torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] + + def _select_torchair_padded_batch_size(self, batch_size: int): + for padded_batch_size in self.torchair_graph_batch_sizes: + if batch_size <= padded_batch_size: + # we treat batch_size as num of requests + return padded_batch_size + raise ValueError( + f"cur batch_size is invalid, torchair_graph_batch_sizes is " + f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." + ) \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e0d9ddf755..b5f445f6da 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -100,7 +100,6 @@ xgr = LazyLoader("xgr", globals(), "xgrammar") import torch_npu -import vllm.envs as envs_vllm import vllm_ascend.envs as envs_ascend @@ -1207,10 +1206,9 @@ def _process_reqs( total_num_scheduled_tokens, with_prefill, enable_dbo) self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp - if self.torchair_graph_enabled and not with_prefill: - self.graph_pad_size = padded_num_tokens_across_dp - else: - self.graph_pad_size = -1 + + self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp) + common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], @@ -1278,12 +1276,8 @@ def _process_reqs( # then the embedding layer is not included in the ACL graph. 
input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - if self.torchair_graph_enabled and not with_prefill: - input_ids = self.input_ids[:padded_num_tokens_across_dp] - positions = self.positions[:padded_num_tokens_across_dp] + positions, input_ids = self._update_input_ids_and_positions(input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1321,35 +1315,8 @@ def _process_reqs( num_actual_tokens=total_num_scheduled_tokens): with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) - model_kwargs = {} - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.kv_caches - model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled and not with_prefill: - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_NZ) - - compiled_model = self._get_torchair_lazy_compiled_model( - padded_num_tokens_across_dp) - hidden_states = compiled_model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - else: - assert self.model is not None - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) - - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) + + self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -1384,6 +1351,29 @@ def _process_reqs( return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, num_scheduled_tokens, finished_sending, finished_recving) + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + self.graph_pad_size = -1 + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + return input_ids, positions + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): + assert self.model is not None + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_ND) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states def _get_cumsum_and_arange( self, @@ -2118,72 +2108,6 @@ def load_model(self) -> None: logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! 
max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - - patch_for_hcom() - - if is_310p(): - # on 300I Duo platform, we need to patch broadcast. however, this patch will be - # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. - from vllm_ascend.patch.platform.patch_common.patch_distributed import \ - communication_adaptation_310p - communication_adaptation_310p() - - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to - # disable it on 300I Duo platform now. - config.experimental_config.tiling_schedule_optimize = not is_310p() - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] - def _convert_torch_format(self, tensor): tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) return tensor @@ -2700,16 +2624,6 @@ def init_torchair_graph_batch_sizes(self): self.torchair_graph_batch_sizes.append(start_graph_batch_size) start_graph_batch_size *= 2 - def select_torchair_padded_batch_size(self, batch_size: int): - for padded_batch_size in self.torchair_graph_batch_sizes: - if batch_size <= padded_batch_size: - # we treat batch_size as num of requests - return padded_batch_size - raise ValueError( - f"cur batch_size is invalid, torchair_graph_batch_sizes is " - f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." 
- ) - def check_torchair_graph_batch_sizes(self): # return graph_batch_sizes according to the max number of tokens # first pad according to the number of requests From 709c99118b2ba7d20aa44d0aef985a5bc4d6f687 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 15:42:38 +0800 Subject: [PATCH 20/22] refact model runner Signed-off-by: weiguihua2 --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b5f445f6da..6de72b6810 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1316,7 +1316,7 @@ def _process_reqs( with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) - self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) + hidden_states = self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( From 2591659f69367929a8df5698ae8611df7f5cc137 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 16:05:17 +0800 Subject: [PATCH 21/22] refact model runner v1 Signed-off-by: weiguihua2 --- vllm_ascend/torchair/torchair_model_runner.py | 32 +++++++++++-------- vllm_ascend/worker/model_runner_v1.py | 26 +++++++++------ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index ad2f8951e4..ce0e6f0634 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -17,12 +17,13 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # -from typing import Optional import types +from typing import Optional import torch -import torch_npu import torch.nn as nn +import torch_npu +import vllm.envs as envs_vllm from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import logger @@ -33,11 +34,9 @@ from vllm_ascend.torchair.utils import (check_torchair_cache_exist, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - maybe_converting_weight_acl_format, is_310p) + is_310p, maybe_converting_weight_acl_format) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner -import vllm.envs as envs_vllm - class NPUTorchairModelRunner(NPUModelRunner): @@ -202,12 +201,19 @@ def _update_input_ids_and_positions(self, input_ids, positions, input_ids = self.input_ids[:padded_num_tokens_across_dp] positions = self.positions[:padded_num_tokens_across_dp] return input_ids, positions - - def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): - model_kwargs = {"kv_caches": self.kv_caches, "attn_metadata": attn_metadata} + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): + model_kwargs = { + "kv_caches": self.kv_caches, + "attn_metadata": attn_metadata + } if not with_prefill: maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_NZ) + ACL_FORMAT_FRACTAL_NZ) 
compiled_model = self._get_torchair_lazy_compiled_model( padded_num_tokens_across_dp) @@ -221,7 +227,7 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padd else: assert self.model is not None maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) + ACL_FORMAT_FRACTAL_ND) hidden_states = self.model( input_ids=input_ids, @@ -231,7 +237,7 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padd **model_kwargs, ) return hidden_states - + def _get_torchair_lazy_compiled_model(self, batch_size: int): if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: raise ValueError( @@ -297,7 +303,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): config=config, ge_cache=False) return self.torchair_compiled_models[batch_size] - + def _select_torchair_padded_batch_size(self, batch_size: int): for padded_batch_size in self.torchair_graph_batch_sizes: if batch_size <= padded_batch_size: @@ -306,4 +312,4 @@ def _select_torchair_padded_batch_size(self, batch_size: int): raise ValueError( f"cur batch_size is invalid, torchair_graph_batch_sizes is " f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." - ) \ No newline at end of file + ) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6de72b6810..02341fd0a8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -22,7 +22,6 @@ import math import os import time -import types from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast @@ -1277,7 +1276,9 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - positions, input_ids = self._update_input_ids_and_positions(input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) + positions, input_ids = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1315,8 +1316,10 @@ def _process_reqs( num_actual_tokens=total_num_scheduled_tokens): with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) - - hidden_states = self._generate_process_reqs_hidden_states(attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds) + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, with_prefill, padded_num_tokens_across_dp, + input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -1351,21 +1354,24 @@ def _process_reqs( return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, num_scheduled_tokens, finished_sending, finished_recving) - + def _update_graph_pad_size(self, with_prefill, graph_pad_size): self.graph_pad_size = -1 - + def _update_input_ids_and_positions(self, input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp): if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions - - def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, padded_num_tokens_across_dp, input_ids, positions, intermediate_tensors, inputs_embeds): + + def 
_generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): assert self.model is not None - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) + maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) hidden_states = self.model( input_ids=input_ids, From 489e9ce0b1e4903d77eae5e9e090ac9c5f83e49f Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 18 Aug 2025 16:38:58 +0800 Subject: [PATCH 22/22] refact model runner v1 Signed-off-by: weiguihua2 --- vllm_ascend/worker/model_runner_v1.py | 2 +- vllm_ascend/worker/mtp_proposer_v1.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 02341fd0a8..25fb25e691 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1276,7 +1276,7 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - positions, input_ids = self._update_input_ids_and_positions( + input_ids, positions = self._update_input_ids_and_positions( input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 07599fc8ce..5064410218 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -188,11 +188,6 @@ def propose( self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states - if attn_metadata.prefill is not None: - attn_metadata.prefill.query_lens = query_lens.cpu() - attn_metadata.prefill.input_positions = target_positions - attn_metadata.prefill.seq_lens = seq_lens - if not self.runner.torchair_graph_enabled: # torch mode need to update num_tokens_across_dp # TODO: adapt enable_dbo later
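
Note on the torchair helpers touched repeatedly above: `_get_graph_runner_block_tables` pads or truncates each request's block table to a fixed `(num_seqs, max_blocks)` shape so the captured graph always sees the same layout, and `_select_torchair_padded_batch_size` rounds the request count up to the nearest captured graph batch size. The following is a minimal standalone sketch of that behavior; the helper names here (`pad_block_tables`, `select_padded_batch_size`) are illustrative and are not the builder/runner methods themselves.

import torch


def pad_block_tables(num_seqs: int, block_tables: torch.Tensor,
                     max_blocks: int) -> torch.Tensor:
    # Mirrors _get_graph_runner_block_tables: always return a fixed
    # (num_seqs, max_blocks) tensor, zero-padding short tables and
    # truncating long ones so graph capture sees a stable shape.
    graph_block_tables = torch.zeros((num_seqs, max_blocks),
                                     dtype=block_tables.dtype,
                                     device=block_tables.device)
    num_blocks = min(block_tables.size(1), max_blocks)
    graph_block_tables[:num_seqs, :num_blocks] = (
        block_tables[:num_seqs, :num_blocks])
    return graph_block_tables


def select_padded_batch_size(batch_size: int, graph_batch_sizes) -> int:
    # Mirrors _select_torchair_padded_batch_size: round the number of
    # requests up to the smallest captured graph batch size that fits.
    for padded_batch_size in graph_batch_sizes:
        if batch_size <= padded_batch_size:
            return padded_batch_size
    raise ValueError(
        f"batch_size {batch_size} exceeds graph batch sizes {graph_batch_sizes}")


if __name__ == "__main__":
    tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
    assert pad_block_tables(3, tables, max_blocks=4).shape == (3, 4)
    assert pad_block_tables(3, tables, max_blocks=64).shape == (3, 64)
    assert select_padded_batch_size(3, [1, 2, 4, 8]) == 4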
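
The `_get_torchair_lazy_compiled_model` hunks also lean on a code-object renaming trick: each cached compile is handed a forward proxy whose `co_name` encodes the batch size, so dynamo retracing for one batch size does not invalidate the cache entry of another. Below is a hedged sketch of just that renaming step, with a hypothetical `make_forward_proxy` helper and toy model; it deliberately omits the torchair `cache_compile` call itself.

import types

import torch.nn as nn


def make_forward_proxy(model: nn.Module, batch_size: int):
    # Clone the model's forward with a batch-size-specific co_name so each
    # cached compile keys on its own code object instead of a shared one.
    forward_fn = model.forward
    proxy_name = f"{model.__class__.__name__}_forward_with_batch_size_{batch_size}"
    modified_code = forward_fn.__code__.replace(co_name=proxy_name)
    modified_func = types.FunctionType(modified_code,
                                       forward_fn.__globals__,
                                       name=proxy_name,
                                       argdefs=forward_fn.__defaults__)
    # Bind the renamed function back to the model instance, as the patch does
    # before passing it to the compilation cache.
    return modified_func.__get__(model, nn.Module)


if __name__ == "__main__":
    class Toy(nn.Module):
        def forward(self, x):
            return x * 2

    proxy = make_forward_proxy(Toy(), batch_size=4)
    print(proxy.__name__)  # Toy_forward_with_batch_size_4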