[xpu][mx] Fix NaN scale propagation in RCEIL triton kernel#4271

Open
ugolowic wants to merge 1 commit into pytorch:main from ugolowic:nan_in_mx_scale_calculation
Conversation

@ugolowic (Contributor)

In the context of #3576.

Fix NaN scale propagation in RCEIL triton kernel for SPIR-V backends

On XPU's SPIR-V backend, tl.clamp and .to(uint8) silently destroy NaN values.

The existing NaN workaround in _triton_calculate_scale_rceil correctly detected NaN blocks (an x != x mask reduced with tl.max), but then injected NaN into max_abs, which was subsequently destroyed by tl.clamp further down the chain.

The proposed solution is to move the NaN correction after the log2 -> ceil -> clamp -> +127 -> to(uint8) chain.
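The difference can be sketched with a scalar emulation in plain Python. This is only an illustration of the ordering change, not the Triton kernel itself: the real code operates on blocks with tl.* ops, and the clamp bounds (-127, 127) and the 255 NaN encoding for the biased e8m0 scale are assumptions here.

```python
import math

E8M0_NAN_BIASED = 255  # assumed NaN encoding for the biased e8m0 scale

def ceil_f(x):
    # Float ceil that passes NaN through (math.ceil raises on NaN).
    return x if math.isnan(x) else float(math.ceil(x))

def clamp_nan_ignoring(x, lo, hi):
    # Emulates a clamp built on NaN-ignoring fmin/fmax (the SPIR-V
    # default): a NaN input silently collapses to the lower bound.
    return lo if math.isnan(x) else max(lo, min(hi, x))

def scale_old(max_abs, block_has_nan):
    # Old workaround: inject NaN into max_abs *before* the chain.
    if block_has_nan:
        max_abs = float("nan")
    exp = clamp_nan_ignoring(ceil_f(math.log2(max_abs)), -127.0, 127.0)
    return int(exp) + 127  # NaN destroyed by the clamp: returns 0

def scale_new(max_abs, block_has_nan):
    # Proposed fix: run the chain on finite data, then correct for NaN
    # afterwards (a tl.where analogue).
    exp = clamp_nan_ignoring(ceil_f(math.log2(max_abs)), -127.0, 127.0)
    biased = int(exp) + 127
    return E8M0_NAN_BIASED if block_has_nan else biased
```

In the old ordering the injected NaN collapses to the lower clamp bound and the NaN block silently gets an ordinary scale; in the new ordering the correction happens after every NaN-destroying op has already run.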

tl.clamp and .to(uint8) silently destroy NaN on XPU's SPIR-V backend.
Apply NaN correction via tl.where after the clamp/cast chain instead
of injecting NaN into max_abs before it.

Signed-off-by: Ula Golowicz <urszula.golowicz@intel.com>
pytorch-bot Bot commented Apr 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4271

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 13, 2026
@anmyachev

Hi @ugolowic,

tl.clamp uses NaN-ignoring fmin/fmax (returning -127.0 instead of NaN) - see intel/intel-xpu-backend-for-triton#5003 and KhronosGroup/SPIRV-LLVM-Translator#3282

This is the default behavior. Did you try tl.clamp(..., propagate_nan=tl.PropagateNan.ALL)?
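The two clamp semantics being contrasted here can be emulated in plain Python (a scalar sketch for illustration, not the Triton implementation):

```python
import math

def clamp_ignore_nan(x, lo, hi):
    # Emulates a clamp built on NaN-ignoring fmin/fmax (the SPIR-V
    # default behavior): a NaN input silently becomes the lower bound.
    return lo if math.isnan(x) else max(lo, min(hi, x))

def clamp_propagate_nan(x, lo, hi):
    # Emulates the NaN-propagating variant, analogous to
    # tl.clamp(..., propagate_nan=tl.PropagateNan.ALL): NaN in, NaN out.
    return float("nan") if math.isnan(x) else max(lo, min(hi, x))
```

With the ignoring variant a NaN sentinel is lost (it comes out as the lower bound); with the propagating variant it survives the clamp.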


descale_fp = _calculate_reciprocal_scale(scale_e8m0_biased)

# Overwrite scale and descale for NaN rows. This is done after the
@danielvegamyhre (Contributor) commented Apr 15, 2026

do our existing special value tests in test_kernels.py pass with this? the desired behavior here is that if any value in the 1x32 block is nan, the scale becomes nan and the whole quantized block becomes nan.

we did a lot of work to fix a nan loss issue, and make torch vs triton special value handling consistent and correct: #4201

so i am hesitant to make any numerical changes now that we have confirmed in production the fix works and users are using the kernel again.
