fix(metal): bf16 transcendental codegen — cast through f32 #1203
holg wants to merge 1 commit into tracel-ai:main
Conversation
Metal's `bfloat` type has no native transcendental functions (exp, sin, cos, log, sqrt, etc.); the GPU's Special Function Units only support `float` (f32) and `half` (f16). Previously, the MSL codegen grouped bf16 with f16, generating invalid calls like `exp(bfloat_value)` that fail to compile.

This PR adds `bf16_has_native_math_functions()` to the `DialectInstructions` trait (default `true` for CUDA/HIP, overridden to `false` for Metal). When `false`, bf16 operands are cast through f32: `bfloat(exp(float(x)))`.

Also fixes:
- `safe_tanh_scalar` for bf16 on Metal (cast through float)
- `log1p` for bf16 on Metal (cast through float, not half)
- `Remainder::floor()` for bf16 (cast through f32)
- Register `FloatKind::BF16` as a supported type for the Metal backend
- Add bf16 to the MSL test suite (20/22 unary tests pass; the 2 failures are pre-existing vector math literal issues)
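A minimal sketch of the mechanism described above, with names simplified (`format_unary` and the structs here are illustrative, not the crate's actual API — the real trait lives in `crates/cubecl-cpp/src/shared/dialect.rs`):

```rust
// Hypothetical, simplified sketch of the dialect capability flag this PR adds.
trait DialectInstructions {
    // CUDA/HIP have native bf16 math intrinsics, so the default is true.
    fn bf16_has_native_math_functions(&self) -> bool {
        true
    }
}

struct Cuda;
impl DialectInstructions for Cuda {}

struct Metal;
impl DialectInstructions for Metal {
    // MSL has no bfloat overloads for exp/sin/cos/..., so override to false.
    fn bf16_has_native_math_functions(&self) -> bool {
        false
    }
}

// Shared codegen: emit a transcendental call, casting bf16 operands through
// f32 when the dialect lacks native bf16 math.
fn format_unary(d: &dyn DialectInstructions, func: &str, input: &str, is_bf16: bool) -> String {
    if is_bf16 && !d.bf16_has_native_math_functions() {
        format!("bfloat({func}(float({input})))")
    } else {
        format!("{func}({input})")
    }
}

fn main() {
    // Metal routes bf16 through f32; CUDA calls the intrinsic directly.
    println!("{}", format_unary(&Metal, "exp", "x", true)); // bfloat(exp(float(x)))
    println!("{}", format_unary(&Cuda, "exp", "x", true)); // exp(x)
}
```

The key point is that only the dialect answers the capability question; the shared unary formatter stays backend-agnostic.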
Pull request overview
This PR fixes Metal (MSL) shader generation for bf16 transcendental math by avoiding invalid calls like exp(bfloat) and instead casting bf16 operands through f32 when the backend lacks native bf16 math support.
Changes:
- Add a dialect capability flag (`bf16_has_native_math_functions`) and use it to decide when bf16 transcendentals must cast through f32 in shared codegen.
- Fix bf16 behavior for `safe_tanh`, `log1p`, and `Remainder::floor()` on Metal by routing through f32 where needed.
- Register bf16 as a supported Metal type in the wgpu backend and add bf16 to the MSL test generation matrix.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| crates/cubecl-wgpu/src/lib.rs | Adds bf16 to the MSL test type matrix. |
| crates/cubecl-wgpu/src/backend/metal.rs | Registers bf16 as a supported scalar type for the Metal backend. |
| crates/cubecl-cpp/src/shared/unary.rs | Updates unary math formatting to cast bf16 through f32 when the dialect indicates no native bf16 math. |
| crates/cubecl-cpp/src/shared/instruction.rs | Adjusts remainder/floor formatting to route bf16 through f32 when required. |
| crates/cubecl-cpp/src/shared/dialect.rs | Introduces bf16_has_native_math_functions() on DialectInstructions (default true). |
| crates/cubecl-cpp/src/metal/extension.rs | Fixes Metal safe_tanh_scalar generation for bf16 by casting through float. |
| crates/cubecl-cpp/src/metal/dialect.rs | Overrides bf16 native-math capability to false and fixes log1p emission for bf16. |
```rust
ElemType::Int(IntKind::I16),
ElemType::Int(IntKind::I32),
ElemType::Int(IntKind::I64),
ElemType::Float(FloatKind::BF16),
```
`FloatKind::BF16` is registered unconditionally for Metal. Since MSL `bfloat` support is not guaranteed on all Metal devices/OS versions, this can make `DeviceProperties` claim bf16 is available and lead to runtime shader compilation failures on unsupported hardware. Consider gating bf16 registration behind an explicit capability check (e.g., a wgpu/Metal feature or adapter/hal query for `bfloat` support), and only register BF16 when supported.
```rust
ElemType::Float(FloatKind::BF16),
```
I guess it would be important to validate if the device has bf16 support. I think older Macs with Intel or AMD gpus don't support bf16.
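The gating both reviewers suggest could look roughly like the sketch below. The function and flag names here are hypothetical — wgpu does not expose Metal `bfloat` support under exactly these names, and the real check would query the adapter or OS version before registering the type:

```rust
// Hypothetical capability-gated registration: only advertise BF16 when the
// device can actually compile `bfloat` MSL, instead of unconditionally.
fn register_float_kinds(device_supports_bf16: bool) -> Vec<&'static str> {
    // F16 and F32 are always available on Metal.
    let mut kinds = vec!["F16", "F32"];
    if device_supports_bf16 {
        // Skipping this on older Intel/AMD-GPU Macs avoids runtime shader
        // compilation failures when a kernel actually uses bf16.
        kinds.push("BF16");
    }
    kinds
}
```

With this shape, `DeviceProperties` would only claim bf16 on hardware where the Metal compiler accepts it.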
The review feedback about unconditional bf16 registration is addressed in follow-up PR #1207: bf16 unary tests go from 20/22 → 22/22 with those fixes applied on top.
Closes #1200
Summary
Metal's `bfloat` type has no native transcendental functions (exp, sin, cos, log, sqrt, etc.). The GPU's Special Function Units only support `float` (f32) and `half` (f16). Previously, the MSL codegen grouped bf16 with f16, generating invalid calls like `exp(bfloat_value)` that fail to compile.
Adds `bf16_has_native_math_functions()` to the `DialectInstructions` trait (default `true` for CUDA/HIP, overridden to `false` for Metal). When `false`, bf16 operands are cast through f32: `bfloat(exp(float(x)))`.

Also fixes `safe_tanh`, `log1p`, and `Remainder::floor()` for bf16 on Metal, and registers `FloatKind::BF16` as a supported type for the Metal backend.

Test plan
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty::unary` — 20/22 pass (2 failures are pre-existing vector math issues, not transcendentals)
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty::` — 83/131 pass (43 failures in pre-existing plane/warp/vector areas)
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f16_ty::unary` — 22/22 pass (no regressions)
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f32_ty::unary` — 22/22 pass (no regressions)

Context
Discovered while building z-image, a diffusion image generation pipeline running on Apple Silicon (M2 Max) via burn's wgpu/Metal backend. bf16 models couldn't be used directly on Metal because every transcendental operation in the diffusion pipeline (timestep embeddings, RoPE, attention softmax) generated invalid MSL.
With this fix, individual bf16 transcendental operations work correctly. Full end-to-end bf16 inference still requires additional work in the fusion kernel codegen path (`elemwise_fuse` generates bf16 MSL patterns that the Metal compiler rejects).
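For a concrete picture of the `safe_tanh` fix mentioned above, here is a hedged sketch of how the Metal extension can be emitted per element type. This is not the crate's actual codegen API, and the real extension may include NaN guards not shown here; the sketch only shows the bf16 cast-through-float wrapping this PR introduces:

```rust
// Illustrative emitter: bf16 has no `tanh(bfloat)` overload in MSL, so the
// math runs in `float` and only the result is narrowed back to `bfloat`.
fn emit_safe_tanh(elem: &str) -> String {
    match elem {
        "bfloat" => {
            "bfloat safe_tanh_scalar(bfloat x) { return bfloat(tanh(float(x))); }".to_string()
        }
        // f16/f32 have native `tanh` overloads and can call it directly.
        other => format!("{other} safe_tanh_scalar({other} x) {{ return tanh(x); }}"),
    }
}

fn main() {
    println!("{}", emit_safe_tanh("bfloat"));
    println!("{}", emit_safe_tanh("half"));
}
```

The same cast-in, cast-out pattern is what the `log1p` and `Remainder::floor()` fixes apply.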