Skip to content

fix(metal): gate bf16 support to Apple Silicon and fix bf16 literals#1207

Open
holg wants to merge 3 commits intotracel-ai:mainfrom
holg:fix/metal-bf16-literals
Open

fix(metal): gate bf16 support to Apple Silicon and fix bf16 literals#1207
holg wants to merge 3 commits intotracel-ai:mainfrom
holg:fix/metal-bf16-literals

Conversation

@holg
Copy link

@holg holg commented Mar 2, 2026

Follow-up to #1203 — addresses review feedback about unconditional bf16 registration on Metal.

Summary

#1203 added bf16 transcendental codegen support for Metal but registered FloatKind::BF16
unconditionally for all Metal devices. MSL's bfloat type is only available on Apple Silicon
(Apple7+ GPU family — M1 and later). Intel and AMD-based Macs do not support it, which would
cause shader compilation failures at runtime.

This PR:

  • Gates bf16 type registration behind #[cfg(apple_silicon)] (target_os = "macos" +
    target_arch = "aarch64"), which maps 1:1 to Apple7+ hardware.
  • Fixes bf16 zero-literal initialization in Magnitude/Normalize — uses typed constructors
    (bfloat(0.0)) instead of bare 0.0 which the Metal compiler rejects for bfloat. This fixes
    2 previously failing bf16 unary tests.
  • Adds MSL overrides for hypot/rhypot (MSL uses overloaded names, not CUDA-style hypotf).

Test plan

  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty::unary — 22/22
    pass (up from 20/22 before the literal fix)
  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty:: — 85/131 pass
    (up from 83/131; 41 failures are pre-existing plane/warp/vector issues, 5 ignored)
  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f16_ty::unary — 22/22
    pass (no regressions)
  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f32_ty::unary — 22/22
    pass (no regressions)

holg added 2 commits February 28, 2026 14:22
Metal's `bfloat` type has no native transcendental functions (exp, sin,
cos, log, sqrt, etc.). The GPU's Special Function Units only support
`float` (f32) and `half` (f16). Previously, the MSL codegen grouped
bf16 with f16, generating invalid calls like `exp(bfloat_value)` that
fail to compile.

Add `bf16_has_native_math_functions()` to `DialectInstructions` trait
(default `true` for CUDA/HIP, overridden to `false` for Metal). When
false, bf16 operands are cast through f32: `bfloat(exp(float(x)))`.

Also fixes:
- `safe_tanh_scalar` for bf16 on Metal (cast through float)
- `log1p` for bf16 on Metal (cast through float, not half)
- `Remainder::floor()` for bf16 (cast through f32)
- Register `FloatKind::BF16` as supported type for Metal backend
- Add bf16 to MSL test suite (20/22 unary tests pass;
  2 failures are pre-existing vector math literal issues)
- Gate bf16 type registration behind #[cfg(apple_silicon)] since MSL
  bfloat is only available on Apple7+ (M1 and later). Intel/AMD Macs
  would get runtime shader compilation failures.
- Fix bf16 zero-literal initialization in Magnitude/Normalize: use
  typed constructors (bfloat(0.0)) instead of bare 0.0.
- Add MSL overrides for hypot/rhypot (MSL uses overloaded names,
  not CUDA-style hypotf/rhypotf).
Copilot AI review requested due to automatic review settings March 2, 2026 18:32
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Follow-up to #1203 to make Metal bf16 support safe/correct by restricting bf16 usage to Apple Silicon and fixing MSL codegen edge cases for bf16 math and literals.

Changes:

  • Gate Metal BF16 type registration behind #[cfg(apple_silicon)] to avoid runtime MSL compilation failures on Intel/AMD Macs.
  • Adjust shared C++ codegen to treat bf16 “native math support” as dialect-dependent (Metal casts bf16 transcendentals through f32).
  • Fix/extend Metal-specific codegen for bf16 literals and add MSL overrides for hypot/rhypot.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
crates/cubecl-wgpu/src/lib.rs Adds bf16 to the MSL test matrix.
crates/cubecl-wgpu/src/backend/metal.rs Registers BF16 only when building for Apple Silicon.
crates/cubecl-cpp/src/shared/dialect.rs Introduces bf16_has_native_math_functions() dialect hook (default true).
crates/cubecl-cpp/src/shared/unary.rs Uses the new bf16-native-math hook to pick prefixes/cast strategy for unary math formatting.
crates/cubecl-cpp/src/shared/instruction.rs Updates remainder/floor handling for bf16 without native math; fixes bf16 zero initialization in magnitude/normalize.
crates/cubecl-cpp/src/metal/extension.rs Updates safe_tanh formatting to cast bf16 through float and return typed literals.
crates/cubecl-cpp/src/metal/dialect.rs Adds MSL overrides for hypot/rhypot and marks bf16 native math as unavailable.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 60 to 64
pub type TestRuntime = crate::WgpuRuntime;
use half::f16;
use half::{bf16, f16};

cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]);
cubecl_core::testgen_all!(f32: [bf16, f16, f32], i32: [i16, i32], u32: [u16, u32]);
cubecl_std::testgen!();
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testgen_all! now includes bf16 for the MSL test suite unconditionally. On x86_64 macOS this will still build, but register_types() no longer registers FloatKind::BF16, and the unary tests don’t appear to skip unsupported float types, so bf16 MSL tests are likely to fail at runtime. Consider gating the bf16 type list behind #[cfg(apple_silicon)] (or otherwise skipping bf16 tests when the device properties don’t advertise BF16 support).

Copilot uses AI. Check for mistakes.
Comment on lines +976 to +978
_elem: Elem<Self>,
) -> std::fmt::Result {
write!(f, "1.0 / hypot({lhs}, {rhs})")
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compile_instruction_rhypot always emits 1.0 / hypot(...). For f32 call sites this promotes the expression to double (and may rely on double support / implicit narrowing back to float), which is risky in MSL and can cause compilation failures. Use an f32 literal (1.0f) or otherwise respect the _elem parameter to choose the correct literal/cast for F32 vs F64.

Suggested change
_elem: Elem<Self>,
) -> std::fmt::Result {
write!(f, "1.0 / hypot({lhs}, {rhs})")
elem: Elem<Self>,
) -> std::fmt::Result {
let one_literal = match elem {
Elem::F32 => "1.0f",
_ => "1.0",
};
write!(f, "{one_literal} / hypot({lhs}, {rhs})")

Copilot uses AI. Check for mistakes.
- Gate bf16 test matrix behind #[cfg(apple_silicon)] so Intel Mac
  builds don't generate bf16 tests that would fail at runtime.
- Use 1.0f literal in rhypot to avoid implicit promotion to double
  in MSL.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants