
Conversation

@mohbasit

This PR introduces ROCm support for the prime framework, enabling it to run seamlessly on both AMD and NVIDIA GPUs. The change breaks no existing functionality and has been tested on AMD MI250 and MI300. The changes span GPU configuration files and Triton annotation adjustments, and they preserve backward compatibility throughout.

rocm = torch.version.hip is not None
if not rocm:
    # num_warps=32 is only valid on NVIDIA; AMD wavefronts are 64 threads wide
    warp_config.append(triton.Config({}, num_warps=32))

Author


Explanation:
The kernel's setting of 32 warps per block is incompatible with AMD GPUs because of AMD's larger warp (wavefront) size: 64 threads per warp versus NVIDIA's 32. With num_warps=32, a block on an AMD GPU would request 32 × 64 = 2048 threads, exceeding the 1024-thread workgroup limit.

To resolve this, the number of warps configured on AMD GPUs should not exceed 16, so the line setting num_warps to 32 is disabled when ROCm is detected.

A future Triton release will handle this automatically, but keeping this guard in place also gives us backward compatibility.
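For context, here is a minimal sketch of how the guard could be folded into an autotune config list. The helper name build_warp_configs and the specific warp counts are illustrative assumptions, not code from the PR.

import torch
import triton

def build_warp_configs():
    # Hypothetical helper: builds Triton autotune configs and caps
    # num_warps at 16 on ROCm, where a warp (wavefront) is 64 threads.
    rocm = torch.version.hip is not None
    configs = [triton.Config({}, num_warps=w) for w in (1, 2, 4, 8, 16)]
    if not rocm:
        # 32 warps x 32 threads = 1024 threads per block, fine on NVIDIA;
        # on AMD it would be 32 x 64 = 2048 threads and exceed the limit.
        configs.append(triton.Config({}, num_warps=32))
    return configs

Once a Triton release accounts for the wavefront difference internally, the guard can be dropped without changing behavior on NVIDIA GPUs.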

