[AdvancedCompiler]Optimize argmin #954
base: master
Conversation
Summary of Changes
Hello @AdvancedCompiler, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the performance of the argmin operator by introducing specialized Triton kernels. It refactors the existing argmin implementation to dynamically select an optimized kernel based on the input tensor's K dimension, alongside new heuristics to fine-tune execution parameters for improved efficiency on NVIDIA GPUs.
Highlights
- **Kernel Specialization**: The `argmin` Triton kernel has been refactored into two specialized kernels: `argmin_kernel_non_inner` and `argmin_kernel_inner`. This allows for tailored optimizations based on the input tensor's dimensions.
- **Optimized Heuristics**: New heuristic functions have been introduced to provide fine-grained control over tiling, CTA configuration, and warp counts for the newly specialized `argmin` kernels, aiming for improved performance on NVIDIA GPUs.
- **Dynamic Kernel Dispatch**: The Python `argmin` function now includes logic to dynamically dispatch to either `argmin_kernel_non_inner` (when K > 1) or `argmin_kernel_inner` (when K == 1), ensuring the most efficient kernel is used for a given input.
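The dispatch rule summarized above can be sketched as follows. The kernel names come from this PR, but the selector function itself is a hypothetical illustration, not the actual wrapper code:

```python
# Hypothetical sketch of the kernel-dispatch rule described in the summary.
# The kernel names match the PR; this selector function is illustrative only.
def select_argmin_kernel(K: int) -> str:
    """Return the name of the specialized kernel for a given reduced size K."""
    if K > 1:
        return "argmin_kernel_non_inner"  # reduction axis is not the innermost
    return "argmin_kernel_inner"          # K == 1: innermost-axis reduction
```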
Code Review
This pull request optimizes the argmin operator by splitting the implementation into two specialized Triton kernels for inner and non-inner reduction dimensions, which is a solid performance optimization strategy. The new kernels and their corresponding heuristics are well-structured. However, I've identified a significant issue in the heuristic functions where the input data type is not correctly inferred, potentially leading to suboptimal performance. Additionally, I've provided suggestions to improve code clarity and maintainability in one of the new kernels and a heuristic function.
```python
K = args["K"]
M = args["M"]
dtype = args.get("dtype", "fp16")
```
The `dtype` is fetched using `args.get("dtype", "fp16")`, but the kernel call from `argmin` does not pass a `dtype` argument. This will cause the heuristic to always use the default value `"fp16"`, leading to suboptimal performance for other data types like `float32`. You should get the dtype from the input tensor `inp`, which is available in `args`.
```diff
- dtype = args.get("dtype", "fp16")
+ dtype = "fp32" if args["inp"].dtype == torch.float32 else "fp16"
```
It seems there is no argument named `dtype`.
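To make the failure mode concrete, here is a minimal standalone sketch (the heuristic name and argument dict are illustrative, not the PR's code) showing that `args.get("dtype", "fp16")` always returns the default when the launch never passes a `dtype` key:

```python
def argmin_heur_example(args):
    # If the kernel launch never supplies "dtype", this default always wins,
    # so float32 inputs get tuned as if they were fp16.
    return args.get("dtype", "fp16")

# Heuristic args as assembled from a launch that passes only M and K:
launch_args = {"M": 64, "K": 256}
print(argmin_heur_example(launch_args))  # "fp16", regardless of input dtype
```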
```python
def argmin_heur_num_warps_non_inner(args):
    tile_n = args["TILE_N"]
    dtype = args.get("dtype", "fp16")
```
The `dtype` is fetched using `args.get("dtype", "fp16")`, but the kernel call from `argmin` does not pass a `dtype` argument. This will cause the heuristic to always use the default value `"fp16"`, leading to suboptimal performance for other data types like `float32`. You should get the dtype from the input tensor `inp`, which is available in `args`.
```diff
- dtype = args.get("dtype", "fp16")
+ dtype = "fp32" if args["inp"].dtype == torch.float32 else "fp16"
```
```python
if K <= 128:
    tile_k = 1
    while tile_k * 2 <= K:
        tile_k *= 2
    return tile_k
```
This `while` loop to find the largest power of two less than or equal to `K` can be simplified. Using `1 << (K.bit_length() - 1)` is more concise and achieves the same result for `K > 0`.
```diff
- if K <= 128:
-     tile_k = 1
-     while tile_k * 2 <= K:
-         tile_k *= 2
-     return tile_k
+ if K <= 128:
+     return 1 << (K.bit_length() - 1) if K > 0 else 1
```
The suggested code is better.
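The equivalence is easy to verify: for every `K` in the range guarded by `K <= 128`, the loop and the bit-twiddling form agree. A standalone sketch (not the PR's code):

```python
def largest_pow2_le_loop(K):
    # Original form: double tile_k until the next doubling would exceed K.
    tile_k = 1
    while tile_k * 2 <= K:
        tile_k *= 2
    return tile_k

def largest_pow2_le_bits(K):
    # Suggested form: the highest set bit of K is the largest power of two <= K.
    return 1 << (K.bit_length() - 1) if K > 0 else 1

assert all(largest_pow2_le_loop(K) == largest_pow2_le_bits(K)
           for K in range(1, 129))
```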
I believe the points mentioned above also apply to the optimization of `argmax`. Please update the code in accordance with the suggestions provided in both pull requests.
**PR Category**: Operator

**Type of Change**: Performance Optimization

**Description**

**Issue**

**Progress**

**Performance**