Conversation

@AdvancedCompiler
Contributor

PR Category

Operator

Type of Change

Performance Optimization

Description

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by unit tests.

Performance

(benchmark screenshots attached)

Contributor

@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @AdvancedCompiler, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the argmin operator by introducing specialized Triton kernels. It refactors the existing argmin implementation to dynamically select an optimized kernel based on the input tensor's K dimension, alongside new heuristics to fine-tune execution parameters for improved efficiency on NVIDIA GPUs.

Highlights

  • Kernel Specialization: The argmin Triton kernel has been refactored into two specialized kernels: argmin_kernel_non_inner and argmin_kernel_inner. This allows for tailored optimizations based on the input tensor's dimensions.
  • Optimized Heuristics: New heuristic functions have been introduced to provide fine-grained control over tiling, CTA configuration, and warp counts for the newly specialized argmin kernels, aiming for improved performance on NVIDIA GPUs.
  • Dynamic Kernel Dispatch: The Python argmin function now includes logic to dynamically dispatch to either argmin_kernel_non_inner (when K > 1) or argmin_kernel_inner (when K == 1), ensuring the most efficient kernel is used for a given input.
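
The dispatch described in the last highlight can be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual code: `select_argmin_kernel` and the (M, N, K) shape factorization are assumed names, following the common Triton reduction convention where K is the product of the dimensions after the reduced axis (K == 1 means the reduction runs over the innermost, contiguous axis).

```python
# Hypothetical sketch of the kernel-dispatch decision (not the PR's code).
# The reduced tensor is viewed as (M, N, K): M = batch before the reduced
# axis, N = length of the reduced axis, K = stride after it.
def select_argmin_kernel(shape, dim):
    dim = dim % len(shape)
    N = shape[dim]                  # length of the reduced axis
    M = 1
    for s in shape[:dim]:
        M *= s                      # batch size before the axis
    K = 1
    for s in shape[dim + 1:]:
        K *= s                      # elements after the axis
    # K == 1: the reduced axis is innermost -> specialized "inner" kernel
    kernel = "argmin_kernel_inner" if K == 1 else "argmin_kernel_non_inner"
    return kernel, M, N, K

print(select_argmin_kernel((8, 1024), dim=1))     # innermost reduction
print(select_argmin_kernel((8, 1024, 4), dim=1))  # non-inner reduction
```
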
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command              | Description
Code Review          | /gemini review       | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary      | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist  | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help         | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request optimizes the argmin operator by splitting the implementation into two specialized Triton kernels for inner and non-inner reduction dimensions, which is a solid performance optimization strategy. The new kernels and their corresponding heuristics are well-structured. However, I've identified a significant issue in the heuristic functions where the input data type is not correctly inferred, potentially leading to suboptimal performance. Additionally, I've provided suggestions to improve code clarity and maintainability in one of the new kernels and a heuristic function.


    K = args["K"]
    M = args["M"]
    dtype = args.get("dtype", "fp16")
Contributor


Severity: high

The dtype is fetched using args.get("dtype", "fp16"), but the kernel call from argmin does not pass a dtype argument. This will cause the heuristic to always use the default value "fp16", leading to suboptimal performance for other data types like float32. You should get the dtype from the input tensor inp which is available in args.

Suggested change:
-    dtype = args.get("dtype", "fp16")
+    dtype = "fp32" if args["inp"].dtype == torch.float32 else "fp16"

Collaborator


It seems there is no argument named dtype.
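
A minimal sketch of the fix both reviewers converge on: read the dtype off the input tensor that is actually present in `args`, instead of a "dtype" key that the launcher never passes (so the old `args.get` always fell back to "fp16"). `infer_heur_dtype` is a hypothetical helper, and `SimpleNamespace` stands in for a torch tensor so the snippet runs without torch or a GPU:

```python
from types import SimpleNamespace

# Hypothetical helper illustrating the reviewer's suggestion (not the PR's
# code): derive the heuristic dtype tag from args["inp"].dtype rather than
# a non-existent "dtype" kwarg, which silently defaulted to "fp16".
def infer_heur_dtype(args):
    # In the real kernel this would compare against torch.float32;
    # here a string stands in for the torch dtype object.
    return "fp32" if str(args["inp"].dtype) == "torch.float32" else "fp16"

fp32_inp = SimpleNamespace(dtype="torch.float32")  # mock float32 tensor
fp16_inp = SimpleNamespace(dtype="torch.float16")  # mock float16 tensor
print(infer_heur_dtype({"inp": fp32_inp}))  # fp32
print(infer_heur_dtype({"inp": fp16_inp}))  # fp16
```
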


def argmin_heur_num_warps_non_inner(args):
    tile_n = args["TILE_N"]
    dtype = args.get("dtype", "fp16")
Contributor


Severity: high

The dtype is fetched using args.get("dtype", "fp16"), but the kernel call from argmin does not pass a dtype argument. This will cause the heuristic to always use the default value "fp16", leading to suboptimal performance for other data types like float32. You should get the dtype from the input tensor inp which is available in args.

Suggested change:
-    dtype = args.get("dtype", "fp16")
+    dtype = "fp32" if args["inp"].dtype == torch.float32 else "fp16"

Comment on lines 30 to 34
    if K <= 128:
        tile_k = 1
        while tile_k * 2 <= K:
            tile_k *= 2
        return tile_k
Contributor


Severity: medium

This while loop to find the largest power of two less than or equal to K can be simplified. Using 1 << (K.bit_length() - 1) is more concise and achieves the same result for K > 0.

Suggested change:
-    if K <= 128:
-        tile_k = 1
-        while tile_k * 2 <= K:
-            tile_k *= 2
-        return tile_k
+    if K <= 128:
+        return 1 << (K.bit_length() - 1) if K > 0 else 1

Collaborator


The suggested code is better.
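
The equivalence of the loop and the `bit_length()` one-liner is easy to check directly: for K ≥ 1, both return the largest power of two not exceeding K. The two standalone functions below are hypothetical, written only to demonstrate the equivalence, not copied from the PR:

```python
# Original approach from the PR: doubling loop to find the
# largest power of two <= K.
def tile_k_loop(K):
    tile_k = 1
    while tile_k * 2 <= K:
        tile_k *= 2
    return tile_k

# Reviewer's suggestion: for K >= 1, K.bit_length() - 1 is the index of
# the highest set bit, so shifting 1 left by it gives the same value.
def tile_k_bitlength(K):
    return 1 << (K.bit_length() - 1)

# Verify the two agree over the range the heuristic handles (K <= 128).
for K in range(1, 129):
    assert tile_k_loop(K) == tile_k_bitlength(K)
print("equivalent for K in 1..128")
```
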

@StrongSpoon
Collaborator

I believe the points mentioned above also apply to the optimization of argmax. Please update the code in accordance with the suggestions provided in both pull requests.

ph0375 and others added 2 commits on September 11, 2025 at 18:07
StrongSpoon previously approved these changes Sep 18, 2025
