fix(metal): gate bf16 support to Apple Silicon and fix bf16 literals #1207
holg wants to merge 3 commits into `tracel-ai:main`
Conversation
Metal's `bfloat` type has no native transcendental functions (exp, sin, cos, log, sqrt, etc.). The GPU's Special Function Units only support `float` (f32) and `half` (f16). Previously, the MSL codegen grouped bf16 with f16, generating invalid calls like `exp(bfloat_value)` that fail to compile.

Add `bf16_has_native_math_functions()` to the `DialectInstructions` trait (default `true` for CUDA/HIP, overridden to `false` for Metal). When false, bf16 operands are cast through f32: `bfloat(exp(float(x)))`.

Also fixes:
- `safe_tanh_scalar` for bf16 on Metal (cast through float)
- `log1p` for bf16 on Metal (cast through float, not half)
- `Remainder::floor()` for bf16 (cast through f32)
- Register `FloatKind::BF16` as a supported type for the Metal backend
- Add bf16 to the MSL test suite (20/22 unary tests pass; the 2 failures are pre-existing vector math literal issues)
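The dialect-hook idea above can be sketched as follows. This is a minimal, self-contained stand-in for cubecl's `DialectInstructions` trait and its unary formatting, not the real API; the names mirror the PR, but the surrounding types are simplified assumptions.

```rust
// Sketch: a dialect hook deciding whether bf16 transcendentals can be
// emitted directly or must be routed through f32.
trait DialectInstructions {
    // Default: CUDA/HIP expose native bf16 math functions.
    fn bf16_has_native_math_functions(&self) -> bool {
        true
    }
}

struct Cuda;
impl DialectInstructions for Cuda {}

struct Msl;
impl DialectInstructions for Msl {
    // Metal's `bfloat` has no native exp/sin/cos/log/sqrt; the SFUs only
    // handle float and half, so bf16 must be cast through f32.
    fn bf16_has_native_math_functions(&self) -> bool {
        false
    }
}

// Format a unary transcendental call for a bf16 operand.
fn format_bf16_unary(dialect: &dyn DialectInstructions, func: &str, input: &str) -> String {
    if dialect.bf16_has_native_math_functions() {
        format!("{func}({input})")
    } else {
        // Cast through f32 and narrow back to bfloat.
        format!("bfloat({func}(float({input})))")
    }
}

fn main() {
    println!("{}", format_bf16_unary(&Cuda, "exp", "x")); // exp(x)
    println!("{}", format_bf16_unary(&Msl, "exp", "x")); // bfloat(exp(float(x)))
}
```

The same hook lets the shared C++ codegen stay dialect-agnostic: only the Metal dialect opts out of native bf16 math.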
- Gate bf16 type registration behind `#[cfg(apple_silicon)]`, since MSL `bfloat` is only available on Apple7+ (M1 and later). Intel/AMD Macs would get runtime shader compilation failures.
- Fix bf16 zero-literal initialization in `Magnitude`/`Normalize`: use typed constructors (`bfloat(0.0)`) instead of bare `0.0`.
- Add MSL overrides for `hypot`/`rhypot` (MSL uses overloaded names, not CUDA-style `hypotf`/`rhypotf`).
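One plausible way to wire up the `apple_silicon` cfg alias is a build script. The predicate (`target_os = "macos"` + `target_arch = "aarch64"`) is taken from the PR description; the exact build-script mechanism shown here is illustrative and may differ from how cubecl-wgpu actually defines the cfg.

```rust
// build.rs sketch: define an `apple_silicon` cfg that maps 1:1 to
// Apple7+ GPUs (M1 and later), where MSL `bfloat` is available.
fn is_apple_silicon(target_os: &str, target_arch: &str) -> bool {
    target_os == "macos" && target_arch == "aarch64"
}

fn main() {
    let os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
    let arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
    if is_apple_silicon(&os, &arch) {
        // Enables `#[cfg(apple_silicon)]` items, e.g. BF16 type registration.
        println!("cargo:rustc-cfg=apple_silicon");
    }
}
```

Because the check is compile-time, x86_64 macOS builds never register `FloatKind::BF16`, so the runtime shader compiler is never asked for a type the GPU family lacks.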
Pull request overview
Follow-up to #1203 to make Metal bf16 support safe/correct by restricting bf16 usage to Apple Silicon and fixing MSL codegen edge cases for bf16 math and literals.
Changes:
- Gate Metal BF16 type registration behind `#[cfg(apple_silicon)]` to avoid runtime MSL compilation failures on Intel/AMD Macs.
- Adjust shared C++ codegen to treat bf16 "native math support" as dialect-dependent (Metal casts bf16 transcendentals through f32).
- Fix/extend Metal-specific codegen for bf16 literals and add MSL overrides for `hypot`/`rhypot`.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| crates/cubecl-wgpu/src/lib.rs | Adds bf16 to the MSL test matrix. |
| crates/cubecl-wgpu/src/backend/metal.rs | Registers BF16 only when building for Apple Silicon. |
| crates/cubecl-cpp/src/shared/dialect.rs | Introduces bf16_has_native_math_functions() dialect hook (default true). |
| crates/cubecl-cpp/src/shared/unary.rs | Uses the new bf16-native-math hook to pick prefixes/cast strategy for unary math formatting. |
| crates/cubecl-cpp/src/shared/instruction.rs | Updates remainder/floor handling for bf16 without native math; fixes bf16 zero initialization in magnitude/normalize. |
| crates/cubecl-cpp/src/metal/extension.rs | Updates safe_tanh formatting to cast bf16 through float and return typed literals. |
| crates/cubecl-cpp/src/metal/dialect.rs | Adds MSL overrides for hypot/rhypot and marks bf16 native math as unavailable. |
```diff
 pub type TestRuntime = crate::WgpuRuntime;
-use half::f16;
+use half::{bf16, f16};

-cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]);
+cubecl_core::testgen_all!(f32: [bf16, f16, f32], i32: [i16, i32], u32: [u16, u32]);
 cubecl_std::testgen!();
```
`testgen_all!` now includes bf16 for the MSL test suite unconditionally. On x86_64 macOS this will still build, but `register_types()` no longer registers `FloatKind::BF16`, and the unary tests don't appear to skip unsupported float types, so bf16 MSL tests are likely to fail at runtime. Consider gating the bf16 type list behind `#[cfg(apple_silicon)]` (or otherwise skipping bf16 tests when the device properties don't advertise BF16 support).
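The reviewer's suggestion amounts to selecting the float test matrix at compile time. A minimal sketch, assuming a hypothetical `msl_float_test_types` helper (not cubecl's API; in the real crate this would gate the `testgen_all!` invocation itself):

```rust
// Sketch: pick the MSL float test matrix from the compile-time cfg so
// non-Apple-Silicon macOS builds never generate bf16 tests.
fn msl_float_test_types() -> &'static [&'static str] {
    #[cfg(apple_silicon)]
    {
        &["bf16", "f16", "f32"]
    }
    #[cfg(not(apple_silicon))]
    {
        &["f16", "f32"]
    }
}
```

An alternative, also mentioned in the comment, is a runtime skip based on the device properties advertising BF16 support; the compile-time gate is simpler but means bf16 coverage silently disappears on non-Apple-Silicon hosts.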
```rust
    _elem: Elem<Self>,
) -> std::fmt::Result {
    write!(f, "1.0 / hypot({lhs}, {rhs})")
```
`compile_instruction_rhypot` always emits `1.0 / hypot(...)`. For f32 call sites this promotes the expression to double (and may rely on double support / implicit narrowing back to float), which is risky in MSL and can cause compilation failures. Use an f32 literal (`1.0f`) or otherwise respect the `_elem` parameter to choose the correct literal/cast for F32 vs F64.
```diff
-    _elem: Elem<Self>,
-) -> std::fmt::Result {
-    write!(f, "1.0 / hypot({lhs}, {rhs})")
+    elem: Elem<Self>,
+) -> std::fmt::Result {
+    let one_literal = match elem {
+        Elem::F32 => "1.0f",
+        _ => "1.0",
+    };
+    write!(f, "{one_literal} / hypot({lhs}, {rhs})")
```
- Gate the bf16 test matrix behind `#[cfg(apple_silicon)]` so Intel Mac builds don't generate bf16 tests that would fail at runtime.
- Use a `1.0f` literal in rhypot to avoid implicit promotion to double in MSL.
Follow-up to #1203; addresses review feedback about unconditional bf16 registration on Metal.

Summary

#1203 added bf16 transcendental codegen support for Metal but registered `FloatKind::BF16` unconditionally for all Metal devices. MSL's `bfloat` type is only available on Apple Silicon (Apple7+ GPU family, M1 and later). Intel and AMD-based Macs do not support it, which would cause shader compilation failures at runtime.
This PR:

- Gates bf16 type registration behind `#[cfg(apple_silicon)]` (`target_os = "macos"` + `target_arch = "aarch64"`), which maps 1:1 to Apple7+ hardware.
- Fixes bf16 zero-literal initialization in `Magnitude`/`Normalize`: uses typed constructors (`bfloat(0.0)`) instead of bare `0.0`, which the Metal compiler rejects for bfloat. This fixes 2 previously failing bf16 unary tests.
- Adds MSL overrides for `hypot`/`rhypot` (MSL uses overloaded names, not CUDA-style `hypotf`).

Test plan
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty::unary`: 22/22 pass (up from 20/22 before the literal fix)
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty::`: 85/131 pass (up from 83/131; the 41 failures are pre-existing plane/warp/vector issues, 5 ignored)
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f16_ty::unary`: 22/22 pass (no regressions)
- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f32_ty::unary`: 22/22 pass (no regressions)