[core] Adopt graph rewriter on fx.graph to enable automatic kernel fusion #2389
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a graph rewriter for fx.graph to enable automatic kernel fusion on Ascend hardware. The changes include a new compiler interface, a graph rewrite pass manager, a specific fusion pass for AddRMSNorm and quantization, and corresponding patches and tests. While the overall approach is sound, I've identified several critical issues related to correctness in the fusion logic, configuration, and testing. Specifically, there are errors in handling operator return values, incorrect configuration in the patch files, and incomplete tests that hide bugs. These issues must be addressed to ensure the feature works correctly.
```python
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    return AscendAdaptor
```
The `make_compiler` function should return an instance of the compiler class, not the class itself. The caller expects an object that implements the `CompilerInterface`.
```diff
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-    return AscendAdaptor
+    return AscendAdaptor()
```
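To illustrate, here is a minimal, self-contained sketch of why the instance matters; `CompilerInterface` is a simplified stand-in here, and its `compile` signature is an assumption rather than vLLM's actual API.

```python
# Sketch only: CompilerInterface is a simplified stand-in, not vLLM's class.
class CompilerInterface:
    def compile(self, graph):
        raise NotImplementedError

class AscendAdaptor(CompilerInterface):
    def compile(self, graph):
        return graph

def make_compiler(compilation_config) -> CompilerInterface:
    # Callers use the return value as an object (compiler.compile(graph)),
    # so returning the class itself would break them; return an instance.
    return AscendAdaptor()

compiler = make_compiler(compilation_config=None)
assert isinstance(compiler, CompilerInterface)
```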
PASS_KEY = "graph_rewriter_pass" | ||
inductor_config[PASS_KEY] = self.post_grad_pass_manager |
There are two issues here that will cause the compilation to fail at runtime:

- The key used to store the pass manager in `inductor_config` is `"graph_rewriter_pass"`. However, `AscendAdaptor` expects the key to be `"graph_rewriter_manager"`.
- The value assigned is `self.post_grad_pass_manager`, which is the original vLLM Inductor pass manager. It should be the newly created `self.graph_rewriter_pass_manager`.

This will lead to a `KeyError` and the use of the wrong pass manager.
PASS_KEY = "graph_rewriter_pass" | |
inductor_config[PASS_KEY] = self.post_grad_pass_manager | |
PASS_KEY = "graph_rewriter_manager" | |
inductor_config[PASS_KEY] = self.graph_rewriter_pass_manager |
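For clarity, a hedged sketch of the key handshake described above; only the key names come from this review, and the pass-manager class is a stand-in rather than the PR's real one.

```python
# Sketch only: DummyGraphRewriterPassManager stands in for the PR's class.
PASS_KEY = "graph_rewriter_manager"

class DummyGraphRewriterPassManager:
    def __call__(self, graph):
        return graph

inductor_config = {}
# Producer side: store the graph rewriter pass manager (not vLLM's Inductor
# pass manager) under the key the consumer expects.
inductor_config[PASS_KEY] = DummyGraphRewriterPassManager()

# Consumer side (AscendAdaptor in the PR): a mismatched key such as
# "graph_rewriter_pass" would raise KeyError here.
pass_manager = inductor_config[PASS_KEY]
```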
```python
    offset,
    epsilon=1e-6)
quantized_output = output[0]
residual = output[2]
```
The `replace` function incorrectly uses `output[2]` as the residual. Assuming `torch.ops.npu.npu_add_rms_norm_quant` has a similar return signature to `npu_add_rms_norm`, the residual should be `output[1]`. Using `output[2]` will result in a functionally incorrect fused operation.
```diff
-residual = output[2]
+residual = output[1]
```
```python
self.weight = nn.Parameter(torch.Tensor(hidden_size))
self.bias = nn.Parameter(torch.Tensor(hidden_size))
```
The `weight` and `bias` tensors are created with `torch.Tensor()`, which leaves them with uninitialized data. This can lead to non-deterministic behavior and flaky tests. It's important to initialize these parameters to ensure test reproducibility.
```diff
-self.weight = nn.Parameter(torch.Tensor(hidden_size))
-self.bias = nn.Parameter(torch.Tensor(hidden_size))
+self.weight = nn.Parameter(torch.ones(hidden_size))
+self.bias = nn.Parameter(torch.zeros(hidden_size))
```
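A small, hedged demonstration of the reproducibility point; this is illustrative only and not part of the PR's test.

```python
import torch
import torch.nn as nn

hidden_size = 4
# torch.Tensor(n) allocates memory without initializing it, so the values
# are arbitrary and can differ between runs.
uninitialized = nn.Parameter(torch.Tensor(hidden_size))
# Deterministic initialization keeps the test reproducible.
weight = nn.Parameter(torch.ones(hidden_size))
bias = nn.Parameter(torch.zeros(hidden_size))
print(uninitialized.data)
print(weight.data, bias.data)
```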
```python
from vllm.logger import init_logger
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.compilation.inductor_pass import get_pass_context, InductorPass
from quant_fusion_pass import AscendQuantFusionPass
```
The import `from quant_fusion_pass import AscendQuantFusionPass` is an implicit relative import. This is not recommended, as it is fragile and can fail depending on the execution context. It should be an explicit relative import.
```diff
-from quant_fusion_pass import AscendQuantFusionPass
+from .quant_fusion_pass import AscendQuantFusionPass
```
This pull request has conflicts, please resolve those before we can evaluate the pull request.
👍🏻👍🏻👍🏻 This is a super smart idea – using compiler tricks to make model code way simpler.
vllm_ascend/ascend_config.py (outdated)
```python
    Configuration Object for ascend_compilation_config from additional_config
    """

    def __init__(self, ascend_compilation_config: dict):
```
Suggest explicitly naming the options as args (e.g. `enable_graph_rewriter = True`, ...), and if you want things to stay extensible, I guess `**kwargs` is more pythonic.
Good suggestion, I'll change this
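A hedged sketch of what that could look like; the option names follow the docs table in this PR, but the constructor shape is an assumption, not the final implementation.

```python
class AscendCompilationConfig:
    """Sketch only: explicit keyword args for known options, **kwargs kept
    for forward compatibility."""

    def __init__(self,
                 enable_graph_rewrite: bool = True,
                 enable_quantization_fusion: bool = True,
                 **kwargs):
        self.enable_graph_rewrite = enable_graph_rewrite
        self.enable_quantization_fusion = enable_quantization_fusion
        # Keep unrecognized options around so new flags don't break old configs.
        self.extra_options = kwargs

# Usage with an additional_config-style dict:
config = AscendCompilationConfig(**{"enable_graph_rewrite": False})
```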
**ascend_compilation_config**

| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enable_graph_rewrite` | bool | `True` | Whether to enable the graph rewriter to rewrite the fx graph generated by torch.compile |
Shall we highlight that this option is a primary flag that can turn off all compilation and cause the other compiler options to be ignored?
Got it, I'll emphasize this one.
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enable_graph_rewrite` | bool | `True` | Whether to enable the graph rewriter to rewrite the fx graph generated by torch.compile |
| `enable_quantization_fusion` | bool | `True` | Whether to enable the fusion pass on op + quantize; this should remain enabled by default so all users benefit from the performance boost |
Not sure if we really want to expose this as an official configuration flag here. I feel this complicates the UI for normal users. Some considerations:

- What is the granularity of such configurations? For example, is it better to make each fusion configurable, e.g., naming it `enable_rmsnorm_quant_fusion`?
- Shall we expose this to normal users? Or is it better to make it a private option first, e.g., `_enable_quantization_fusion`, so that it is a private, debugging-only flag?
This mainly follows vLLM's official code. It's true that the granularity here is a bit unclear in this PR, but this flag is more of a safety switch: if anything goes wrong, we can quickly guide customers to bypass the issue.
```python
# Related PR (if no, explain why):
# - We might add a PR to make vllm support a custom compiler interface, but it's not certain yet.
# Future Plan:
# We might push the customized compiler interface to the vllm main repo, and leave the backend selection to the platform itself.
```
Do I understand correctly that vLLM hard-codes "inductor" as the compiler backend for piece-wise graphs? Is there a way to plug in a custom compiler backend instead of "inductor"?
No, vLLM actually writes its own backend called 'VllmBackend', and inside of that vLLM runs its own graph-break and pattern-registration routine to ensure optimization and compatibility with other repos and packages.
```python
    rms_norm_input,
    residual,
    rms_norm_weight,
    1. / scale,
```
I guess this is worth a comment? :)
Sure
```python
super().__init__(vllm_config)
self.patterns = []
# Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list
AddRMSNormQuantPattern(vllm_config).register(self.patterns)
```
Perhaps it is better to use a decorator to register new patterns, following the open/closed principle...
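A hedged sketch of that decorator-based registration; all names here are illustrative, not the PR's actual API.

```python
PATTERN_REGISTRY = []

def register_pattern(cls):
    # Collect pattern classes at import time, so adding a new fusion pattern
    # does not require editing the pass manager (open/closed principle).
    PATTERN_REGISTRY.append(cls)
    return cls

@register_pattern
class AddRMSNormQuantPattern:
    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def register(self, patterns):
        patterns.append(self)

# In the pass manager, instead of hard-coding each pattern:
patterns = []
for pattern_cls in PATTERN_REGISTRY:
    pattern_cls(vllm_config=None).register(patterns)
```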
This pull request has conflicts, please resolve those before we can evaluate the pull request.
What this PR does / why we need it?
The main goal of this PR is to alleviate the high maintenance burden from model duplication when we do model optimization. Some of our optimized models diverge a little from vLLM's modeling but need to rewrite several parts of the original one, which brings a non-negligible maintenance burden to vllm-ascend.
In order to solve that, we propose to leverage `torch.compile` and automatically fuse the patterns we want to merge. For more details, refer to the RFC #2386.

Does this PR introduce any user-facing change?
Yes, we add a new `additional_config` section, `ascend_compilation_config`.
How was this patch tested?