
[core] Adopt graph rewriter on fx.graph to enable automatic kernel fusion #2389


Open · wants to merge 11 commits into base: main

Conversation

ganyi1996ppo
Collaborator

@ganyi1996ppo commented Aug 15, 2025

What this PR does / why we need it?

The main goal of this PR is to alleviate the high maintenance burden caused by model duplication during model optimization. Some of our optimized models diverge only slightly from vLLM's modeling code, yet they require rewriting several parts of the original, which brings a non-negligible maintenance burden to vllm-ascend.

In order to solve that, we propose to leverage torch.compile and automatically fuse the patterns we want to merge. For more details, refer to the RFC #2386.
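For a rough picture of the mechanism, below is a minimal, self-contained sketch of pattern-based graph rewriting using torch's Inductor pattern matcher. It is illustrative only: the PR's actual pass targets npu_add_rms_norm plus quantization, not the toy add+relu pattern used here.

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

patterns = PatternMatcherPass()

def search(x, y):
    # The subgraph to find: add followed by relu.
    return torch.relu(x + y)

def replace(x, y):
    # The replacement (here merely an equivalent composite, standing in for
    # a fused kernel).
    return torch.clamp(x + y, min=0.0)

# Example inputs drive the tracing of both functions into fx patterns.
example_inputs = [torch.randn(8), torch.randn(8)]
register_replacement(search, replace, example_inputs, fwd_only, patterns)

# patterns.apply(graph_module.graph) then rewrites every match in an fx graph.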

Does this PR introduce any user-facing change?

Yes, we add a new `additional_config` option.
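For illustration, a sketch of what this could look like from the user side. The flag name is taken from the docs table later in this PR; the model name is a placeholder, and this assumes `additional_config` is plumbed through vLLM's engine args to the Ascend platform:

from vllm import LLM

llm = LLM(
    model="your-model",  # placeholder
    additional_config={
        "ascend_compilation_config": {"enable_graph_rewrite": True}
    },
)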

How was this patch tested?


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a graph rewriter for fx.graph to enable automatic kernel fusion on Ascend hardware. The changes include a new compiler interface, a graph rewrite pass manager, a specific fusion pass for AddRMSNorm and quantization, and corresponding patches and tests. While the overall approach is sound, I've identified several critical issues related to correctness in the fusion logic, configuration, and testing. Specifically, there are errors in handling operator return values, incorrect configuration in the patch files, and incomplete tests that hide bugs. These issues must be addressed to ensure the feature works correctly.



def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
return AscendAdaptor
Contributor


critical

The make_compiler function should return an instance of the compiler class, not the class itself. The caller expects an object that implements the CompilerInterface.

Suggested change
  def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-     return AscendAdaptor
+     return AscendAdaptor()
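To illustrate why the instance matters, here is a small self-contained demo of the failure mode (the class and method names are stand-ins, not vLLM's actual interface): when the class itself is returned, the first positional argument binds to `self` instead of the intended parameter.

class Interface:
    def compute_hash(self, config: dict) -> str:
        return str(sorted(config))

def make_compiler_wrong():
    return Interface      # returns the class itself

def make_compiler_right():
    return Interface()    # returns a usable instance

print(make_compiler_right().compute_hash({"level": 3}))   # works
# make_compiler_wrong().compute_hash({"level": 3}) raises TypeError:
# the dict binds to `self`, and `config` goes missing.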

Comment on lines 34 to 35
PASS_KEY = "graph_rewriter_pass"
inductor_config[PASS_KEY] = self.post_grad_pass_manager
Contributor


critical

There are two issues here that will cause the compilation to fail at runtime:

  1. The key used to store the pass manager in inductor_config is "graph_rewriter_pass". However, AscendAdaptor expects the key to be "graph_rewriter_manager".
  2. The value assigned is self.post_grad_pass_manager, which is the original vLLM Inductor pass manager. It should be the newly created self.graph_rewriter_pass_manager.

This will lead to a KeyError and the use of the wrong pass manager.

Suggested change
- PASS_KEY = "graph_rewriter_pass"
- inductor_config[PASS_KEY] = self.post_grad_pass_manager
+ PASS_KEY = "graph_rewriter_manager"
+ inductor_config[PASS_KEY] = self.graph_rewriter_pass_manager
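A minimal stand-alone demonstration of the failure mode described above (a plain dict stands in for inductor_config; key names from the review comment):

inductor_config: dict = {}
inductor_config["graph_rewriter_pass"] = object()   # producer writes one key
try:
    inductor_config["graph_rewriter_manager"]       # consumer reads another
except KeyError:
    print("producer and consumer must agree on the same PASS_KEY")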

offset,
epsilon=1e-6)
quantized_output = output[0]
residual = output[2]
Contributor


critical

The replace function incorrectly uses output[2] as the residual. Assuming torch.ops.npu.npu_add_rms_norm_quant has a similar return signature to npu_add_rms_norm, the residual should be output[1]. Using output[2] will result in a functionally incorrect fused operation.

Suggested change
- residual = output[2]
+ residual = output[1]

Comment on lines +18 to +45
self.weight = nn.Parameter(torch.Tensor(hidden_size))
self.bias = nn.Parameter(torch.Tensor(hidden_size))
Contributor


high

The weight and bias tensors are created with torch.Tensor(), which leaves them with uninitialized data. This can lead to non-deterministic behavior and flaky tests. It's important to initialize these parameters to ensure test reproducibility.

Suggested change
- self.weight = nn.Parameter(torch.Tensor(hidden_size))
- self.bias = nn.Parameter(torch.Tensor(hidden_size))
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.bias = nn.Parameter(torch.zeros(hidden_size))

from vllm.logger import init_logger
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.compilation.inductor_pass import get_pass_context, InductorPass
from quant_fusion_pass import AscendQuantFusionPass
Contributor


high

The import from quant_fusion_pass import AscendQuantFusionPass is an implicit relative import. This is not recommended as it is fragile and can fail depending on the execution context. It should be an explicit relative import.

Suggested change
- from quant_fusion_pass import AscendQuantFusionPass
+ from .quant_fusion_pass import AscendQuantFusionPass


This pull request has conflicts, please resolve those before we can evaluate the pull request.

Collaborator

ApsarasX commented Aug 15, 2025

👍🏻👍🏻👍🏻 This is a super smart idea – using compiler tricks to make model code way simpler.

@github-actions github-actions bot added documentation Improvements or additions to documentation module:core labels Aug 15, 2025
@ganyi1996ppo ganyi1996ppo marked this pull request as ready for review August 18, 2025 06:13
Configuration Object for ascend_compilation_config from additional_config
"""

def __init__(self, ascend_compilation_config: dict):
Collaborator


I suggest explicitly naming the options as args (e.g. enable_graph_rewriter=True ...), and if you want things extensible, I guess **kwargs is more pythonic.

Collaborator Author


Good suggestion, I'll change this

**ascend_compilation_config**
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enable_graph_rewrite` | bool | `True` | Whether to enable the graph rewriter to rewrite the fx graph generated by torch.compile |
Collaborator


Shall we highlight that this option is a primary flag that can turn off all compilation and cause the other compiler options to be ignored?

Collaborator Author


Got it, I'll emphasize this one.

| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enable_graph_rewrite` | bool | `True` | Whether to enable the graph rewriter to rewrite the fx graph generated by torch.compile |
| `enable_quantization_fusion` | bool | `True` | Whether to enable the fusion pass on op + quantize; this should remain enabled by default so that all users benefit from the performance boost |
Collaborator


Not sure if we really want to expose this as an official configuration flag here. I feel this complicates the UI for normal users. Some considerations:

  1. What is the granularity of such configurations? For example, is it better to make each fusion configurable, e.g., naming it enable_rmsnorm_quant_fusion?
  2. Shall we expose this to normal users? Or, is it better to make it a private option first, e.g., _enable_quantization_fusion so that we make it as a private and debugging-only flag?

Collaborator Author


This mainly follows vLLM's official code. It's true that the granularity here is quite coarse in this PR, but this flag is more of a safety switch to ensure that if anything goes wrong, we can quickly guide our customers to bypass the issue.
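As a concrete example of the safety-switch use, a user could keep the rewriter on while bypassing only the fusion (flag names taken from the docs table in this PR):

additional_config = {
    "ascend_compilation_config": {
        "enable_graph_rewrite": True,
        "enable_quantization_fusion": False,  # bypass a misbehaving fusion
    }
}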

# Related PR (if no, explain why):
# - We might add a PR to make vllm support a custom compiler interface. But it's not decided yet.
# Future Plan:
# We might push the customized compiler interface to the vllm main repo, and leave the backend selection to the platform itself.
Collaborator


Do I understand correctly that vLLM hard-coded "inductor" as the compiler backend for piece-wise graphs? Is there a way to plug in a custom compiler backend instead of "inductor"?

Collaborator Author


No, vLLM actually writes its own backend called 'VllmBackend', and inside of that, vLLM runs its own graph-break and pattern-registration routine to ensure optimization and compatibility with other repos and packages.
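A rough, self-contained mental model of that flow (illustrative only; names and structure are simplified and are not vLLM's actual code):

from typing import Callable, List

class PieceCompiler:
    """Stands in for the per-platform compiler (Inductor, AscendAdaptor, ...)."""
    def compile(self, piece: Callable) -> Callable:
        return piece  # a real backend would lower/fuse the fx graph here

def vllm_backend_style(pieces: List[Callable], compiler: PieceCompiler) -> List[Callable]:
    # VllmBackend splits the model into piece-wise graphs itself,
    # then hands each piece to the configured compiler.
    return [compiler.compile(p) for p in pieces]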

rms_norm_input,
residual,
rms_norm_weight,
1. / scale,
Collaborator


I guess this is worth some comments? :)

Collaborator Author


Sure
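For example, the requested comment might look like this. The arguments echo the snippet under review; the reciprocal-scale rationale is an assumption about the kernel's contract, so treat this as a sketch rather than the final comment:

torch.ops.npu.npu_add_rms_norm_quant(
    rms_norm_input,
    residual,
    rms_norm_weight,
    1. / scale,  # assumed: the kernel multiplies by this factor, so we pass the
                 # reciprocal of the quantization scale rather than the scale
    offset,
    epsilon=1e-6)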

super().__init__(vllm_config)
self.patterns = []
# Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list
AddRMSNormQuantPattern(vllm_config).register(self.patterns)
Collaborator


Perhaps it is better to use some decorator to register new patterns, following the open/closed principle...
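A minimal sketch of the decorator-based registry this suggests (names are illustrative, not from this PR):

_PATTERN_REGISTRY: list = []

def register_pattern(cls):
    """Collect pattern classes at import time (open/closed: a new pattern
    just adds the decorator; the pass manager code never changes)."""
    _PATTERN_REGISTRY.append(cls)
    return cls

@register_pattern
class AddRMSNormQuantPattern:
    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def register(self, patterns):
        patterns.append(self)

# Pass manager side:
# for cls in _PATTERN_REGISTRY:
#     cls(vllm_config).register(self.patterns)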


This pull request has conflicts, please resolve those before we can evaluate the pull request.
