[xpu][mx] Fix NaN scale propagation in RCEIL triton kernel #4271
ugolowic wants to merge 1 commit into pytorch:main from
Conversation
`tl.clamp` and `.to(uint8)` silently destroy NaN on XPU's SPIR-V backend. Apply the NaN correction via `tl.where` after the clamp/cast chain instead of injecting NaN into `max_abs` before it.

Signed-off-by: Ula Golowicz <urszula.golowicz@intel.com>
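The ordering issue can be emulated in plain Python (hypothetical helper names; this models the observed SPIR-V `fmin`/`fmax` and `fptoui` behavior, not the actual Triton kernel):

```python
import math

E8M0_NAN = 255  # biased e8m0 encoding of NaN

def nan_ignoring_clamp(x, lo, hi):
    # Mimics SPIR-V fmin/fmax semantics: NaN operands are ignored,
    # so clamp(NaN, lo, hi) collapses to a bound instead of propagating NaN.
    if math.isnan(x):
        return lo
    return max(lo, min(x, hi))

def fptoui_nan_to_zero(x):
    # Models the observed fptoui(NaN) behavior on this backend:
    # undefined per the LLVM spec, and in practice it yields 0, not 255.
    return 0 if math.isnan(x) else int(x)

def scale_nan_injected_before(max_abs, block_has_nan):
    # Old approach: inject NaN into max_abs BEFORE the chain.
    if block_has_nan:
        max_abs = float("nan")
    e = math.log2(max_abs)                        # NaN propagates here...
    e = e if math.isnan(e) else float(math.ceil(e))
    e = nan_ignoring_clamp(e, -127.0, 127.0)      # ...and is destroyed here
    return fptoui_nan_to_zero(e + 127.0)

def scale_corrected_after(max_abs, block_has_nan):
    # Proposed fix: run the chain on finite data, then overwrite the
    # biased scale for NaN blocks with a tl.where-style select.
    e = nan_ignoring_clamp(float(math.ceil(math.log2(max_abs))), -127.0, 127.0)
    biased = int(e + 127.0)
    return E8M0_NAN if block_has_nan else biased

print(scale_nan_injected_before(4.0, block_has_nan=True))  # 0 (NaN lost)
print(scale_corrected_after(4.0, block_has_nan=True))      # 255 (NaN scale kept)
```

With NaN injected up front, the clamp silently turns it into the lower bound and the final uint8 scale ends up 0; applying the correction after the chain yields the expected e8m0 NaN encoding of 255.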
Hi @ugolowic,
This is the default behavior. Did you try
> `descale_fp = _calculate_reciprocal_scale(scale_e8m0_biased)`
> `# Overwrite scale and descale for NaN rows. This is done after the`
do our existing special value tests in test_kernels.py pass with this? the desired behavior here is that if any value in the 1x32 block is nan, the scale becomes nan and the whole quantized block becomes nan.
we did a lot of work to fix a nan loss issue, and make torch vs triton special value handling consistent and correct: #4201
so i am hesitant to make any numerical changes now that we have confirmed in production the fix works and users are using the kernel again.
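The desired special-value behavior described above could be sketched as a small pure-Python reference (hypothetical helper, not the actual test in `test_kernels.py`; the scale computation is simplified):

```python
import math

E8M0_NAN = 255  # biased e8m0 encoding of NaN

def quantize_block_reference(block):
    # Desired special-value semantics for a 1x32 block: if any element
    # is NaN, the scale becomes the e8m0 NaN encoding (255) and every
    # quantized element becomes NaN.
    if any(math.isnan(v) for v in block):
        return E8M0_NAN, [float("nan")] * len(block)
    max_abs = max(abs(v) for v in block)
    exp = min(max(math.ceil(math.log2(max_abs)), -127), 127)
    return exp + 127, [v / (2.0 ** exp) for v in block]

block = [1.0] * 32
block[7] = float("nan")  # a single NaN poisons the whole block
scale, q = quantize_block_reference(block)
assert scale == E8M0_NAN and all(math.isnan(v) for v in q)
```

Any change to the kernel's NaN handling would need to preserve exactly this block-level poisoning behavior.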
In the context of #3576.
Fix NaN scale propagation in RCEIL triton kernel for SPIR-V backends
On XPU's SPIR-V backend, `tl.clamp` and `.to(uint8)` silently destroy NaN values:

- `tl.clamp` uses NaN-ignoring `fmin`/`fmax` (returning -127.0 instead of NaN) - see intel/intel-xpu-backend-for-triton#5003 and KhronosGroup/SPIRV-LLVM-Translator#3282
- `fptoui(NaN)`, used by `.to(uint8)`, is undefined behavior per the LLVM spec (returning 0 instead of 255) - see the `fptoui` semantics section of the LLVM LangRef

The existing NaN workaround in `_triton_calculate_scale_rceil` correctly detected NaN blocks via `x != x` + `tl.max(nan_mask)`, but then injected NaN into `max_abs`, which was subsequently killed by `tl.clamp` further down the chain.

The proposed solution is to move the NaN correction after the `log2 -> ceil -> clamp -> +127 -> to(uint8)` chain.