Skip to content

Commit e7d8a01

Browse files
committed
add comments, license, and remove some non necessary code
Signed-off-by: ganyi <[email protected]>
1 parent 5701316 commit e7d8a01

File tree

7 files changed

+93
-196
lines changed

7 files changed

+93
-196
lines changed

tests/e2e/singlecard/test_graph_rewriter.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
119
import torch
220
import torch.nn as nn
321
import torch_npu
@@ -24,9 +42,6 @@ def forward(self, x):
2442
hidden_states = self.former_linear(x)
2543
x, residual = torch_npu.npu_add_rms_norm(hidden_states, x, self.weight, self.eps)
2644
quantized_output = quant_per_tensor(x, self.quant_scale, self.quant_offset)
27-
# output = torch_npu.npu_quant_matmul(
28-
# quantized_output, self.post_linear.weight.transpose(1, 0), self.post_linear.bias,
29-
# self.deq_scale, None, torch.int8)
3045
return quantized_output, residual
3146

3247

vllm_ascend/compilation/compiler_interface.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
119
import copy
220
import hashlib
321
import os

vllm_ascend/compilation/graph_rewrite_pass_manager.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
119
from torch import fx as fx
220

321
from vllm.config import VllmConfig

vllm_ascend/compilation/quant_fusion_pass.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
119
import torch
220
from torch.fx.subgraph_rewriter import replace_pattern
321
import torch_npu

vllm_ascend/models/qwen3.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

vllm_ascend/ops/layernorm.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,6 @@
2121
from vllm.model_executor.layers.layernorm import RMSNorm
2222

2323

24-
class AddRMSNormW8A8Quant(RMSNorm):
25-
# Fuse AddRmsNorm and W8A8 quantization ops together
26-
27-
def __init__(
28-
self,
29-
hidden_size: int,
30-
layer: torch.nn.Module,
31-
eps: float = 1e-6,
32-
var_hidden_size: Optional[int] = None,
33-
has_weight: bool = True,
34-
dtype: Optional[torch.dtype] = None,
35-
) -> None:
36-
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
37-
self.layer = layer
38-
39-
def forward(
40-
self,
41-
x: torch.Tensor,
42-
residual: Optional[torch.Tensor] = None,
43-
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
44-
import torch_npu
45-
46-
if residual is not None:
47-
x, _, residual = torch_npu.npu_add_rms_norm_quant(
48-
x,
49-
residual,
50-
self.weight,
51-
self.layer.aclnn_input_scale,
52-
self.layer.aclnn_input_offset,
53-
epsilon=self.variance_epsilon)
54-
return x, residual
55-
56-
x, residual = torch_npu.npu_rms_norm(x, self.weight,
57-
self.variance_epsilon)
58-
return x
59-
60-
6124
class AscendRMSNorm(RMSNorm):
6225

6326
def forward_oot(

vllm_ascend/patch/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,24 @@
102102
# Validate more models in all kinds of scenario,
103103
# if performance is always improved, we can enable this patch by default and remove the env
104104
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.
105+
# ** File: worker/patch_common/patch_compilation.py **
106+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
107+
# 1. `vllm.compilation.backends.make_compiler`
108+
# Why:
109+
# We need to use the `GraphRewriterPassManager` at the compiler interface to actually modify the graph. Therefore we
110+
# need to implement our own compiler interface to do our customized operations.
111+
# How:
112+
# We implement our own compiler interface `AscendAdaptor` to fetch the `GraphRewriterPassManager` and use it to rewrite the
113+
# piecewise graph cut by vllm's own backend. This function will just return the `AscendAdaptor` to the `VllmBackend`.
114+
# Related PR (if no, explain why):
115+
# - We might open a PR to make vllm support a custom compiler interface, but this is not decided yet.
116+
# Future Plan:
117+
# We might push the customized compiler interface to the vllm main repo, and leave the backend selection to the platform itself.
118+
# 2. `vllm.compilation.backends.VllmBackend.configure_post_pass`
119+
# Why:
120+
# We need to register the `GraphRewriterPassManager` with the `VllmBackend` and enable it during
121+
# the compilation. We can't directly adopt vllm's inductor pass because of torch_npu's limited support for
122+
# triton and inductor, so we need to patch this function into the `VllmBackend` to use the `GraphRewriterPassManager`.
123+
# How:
124+
# This function will inject the `GraphRewriterPassManager` into the inductor config, which is a parameter passed to the compiler interface.
125+
# In our customized compiler interface, `AscendAdaptor`, we will use it to rewrite the fx graph.

0 commit comments

Comments
 (0)