
fix(metal): bf16 transcendental codegen — cast through f32 #1203

Open

holg wants to merge 1 commit into tracel-ai:main from holg:fix/metal-bf16-transcendentals

Conversation

@holg commented Feb 28, 2026

Closes #1200

Summary

  • Metal's bfloat type has no native transcendental functions (exp, sin, cos, log, sqrt, etc.). The GPU's Special Function Units only support
    float (f32) and half (f16). Previously, the MSL codegen grouped bf16 with f16, generating invalid calls like exp(bfloat_value) that fail to
    compile.
  • Added bf16_has_native_math_functions() to DialectInstructions trait (default true for CUDA/HIP, overridden to false for Metal). When
    false, bf16 operands are cast through f32: bfloat(exp(float(x))).
  • Fixed safe_tanh, log1p, and Remainder::floor() for bf16 on Metal.
  • Registered FloatKind::BF16 as a supported type for the Metal backend.
  • Added bf16 to the MSL test suite.
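The capability-flag mechanism above can be sketched as follows. The name `bf16_has_native_math_functions` and the `bfloat(exp(float(x)))` output come from the PR; the trait shape, the `Cuda`/`Metal` stand-in types, and the `format_unary` helper are simplified illustrations, not the real cubecl-cpp API.

```rust
/// Simplified stand-in for the capability flag on `DialectInstructions`.
trait DialectInstructions {
    /// Default: assume native bf16 math (matches CUDA/HIP in the PR).
    fn bf16_has_native_math_functions(&self) -> bool {
        true
    }
}

struct Cuda;
impl DialectInstructions for Cuda {}

struct Metal;
impl DialectInstructions for Metal {
    // Metal's Special Function Units only handle float/half,
    // so bf16 operands must be cast through f32.
    fn bf16_has_native_math_functions(&self) -> bool {
        false
    }
}

/// Emit a unary transcendental call, wrapping bf16 operands in
/// float casts when the dialect lacks native bf16 math.
fn format_unary(
    dialect: &dyn DialectInstructions,
    func: &str,
    input: &str,
    is_bf16: bool,
) -> String {
    if is_bf16 && !dialect.bf16_has_native_math_functions() {
        format!("bfloat({func}(float({input})))")
    } else {
        format!("{func}({input})")
    }
}

fn main() {
    // CUDA keeps the direct call; Metal casts bf16 through f32.
    assert_eq!(format_unary(&Cuda, "exp", "x", true), "exp(x)");
    assert_eq!(format_unary(&Metal, "exp", "x", true), "bfloat(exp(float(x)))");
    // Non-bf16 operands are unaffected on Metal.
    assert_eq!(format_unary(&Metal, "exp", "x", false), "exp(x)");
    println!("ok");
}
```

The same branch point covers every transcendental listed above (exp, sin, cos, log, sqrt), since only the function name changes.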

Test plan

  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty::unary — 20/22 pass (2 failures are pre-existing vector math issues,
    not transcendentals)
  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::bf16_ty:: — 83/131 pass (43 failures in pre-existing plane/warp/vector areas)
  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f16_ty::unary — 22/22 pass (no regressions)
  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f32_ty::unary — 22/22 pass (no regressions)
  • End-to-end diffusion inference with bf16 model weights on wgpu/Metal (via f32 dtype conversion) — correct images produced

Context

Discovered while building z-image, a diffusion image generation pipeline running on Apple Silicon (M2 Max)
via burn's wgpu/Metal backend. bf16 models couldn't be used directly on Metal because every transcendental operation in the diffusion pipeline
(timestep embeddings, RoPE, attention softmax) generated invalid MSL.

With this fix, individual bf16 transcendental operations work correctly. Full end-to-end bf16 inference still requires additional work in the
fusion kernel codegen path (elemwise_fuse generates bf16 MSL patterns that the Metal compiler rejects).

Metal's `bfloat` type has no native transcendental functions (exp, sin,
cos, log, sqrt, etc.). The GPU's Special Function Units only support
`float` (f32) and `half` (f16). Previously, the MSL codegen grouped
bf16 with f16, generating invalid calls like `exp(bfloat_value)` that
fail to compile.

Add `bf16_has_native_math_functions()` to `DialectInstructions` trait
(default `true` for CUDA/HIP, overridden to `false` for Metal). When
false, bf16 operands are cast through f32: `bfloat(exp(float(x)))`.

Also fixes:
- `safe_tanh_scalar` for bf16 on Metal (cast through float)
- `log1p` for bf16 on Metal (cast through float, not half)
- `Remainder::floor()` for bf16 (cast through f32)
- Register `FloatKind::BF16` as supported type for Metal backend
- Add bf16 to MSL test suite (20/22 unary tests pass;
  2 failures are pre-existing vector math literal issues)
Copilot AI review requested due to automatic review settings on February 28, 2026 at 14:06.
Copilot AI left a comment

Pull request overview

This PR fixes Metal (MSL) shader generation for bf16 transcendental math by avoiding invalid calls like exp(bfloat) and instead casting bf16 operands through f32 when the backend lacks native bf16 math support.

Changes:

  • Add a dialect capability flag (bf16_has_native_math_functions) and use it to decide when bf16 transcendentals must cast through f32 in shared codegen.
  • Fix bf16 behavior for safe_tanh, log1p, and Remainder::floor() on Metal by routing through f32 where needed.
  • Register bf16 as a supported Metal type in the wgpu backend and add bf16 to the MSL test generation matrix.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

Summary per file:

  • crates/cubecl-wgpu/src/lib.rs: adds bf16 to the MSL test type matrix.
  • crates/cubecl-wgpu/src/backend/metal.rs: registers bf16 as a supported scalar type for the Metal backend.
  • crates/cubecl-cpp/src/shared/unary.rs: updates unary math formatting to cast bf16 through f32 when the dialect indicates no native bf16 math.
  • crates/cubecl-cpp/src/shared/instruction.rs: adjusts remainder/floor formatting to route bf16 through f32 when required.
  • crates/cubecl-cpp/src/shared/dialect.rs: introduces bf16_has_native_math_functions() on DialectInstructions (default true).
  • crates/cubecl-cpp/src/metal/extension.rs: fixes Metal safe_tanh_scalar generation for bf16 by casting through float.
  • crates/cubecl-cpp/src/metal/dialect.rs: overrides the bf16 native-math capability to false and fixes log1p emission for bf16.


ElemType::Int(IntKind::I16),
ElemType::Int(IntKind::I32),
ElemType::Int(IntKind::I64),
ElemType::Float(FloatKind::BF16),

Copilot AI Feb 28, 2026

FloatKind::BF16 is registered unconditionally for Metal. Since MSL bfloat support is not guaranteed on all Metal devices/OS versions, this can make DeviceProperties claim bf16 is available and lead to runtime shader compilation failures on unsupported hardware. Consider gating bf16 registration behind an explicit capability check (e.g., a wgpu/Metal feature or adapter/hal query for bfloat support), and only register BF16 when supported.

Suggested change
ElemType::Float(FloatKind::BF16),

Member

I guess it would be important to validate whether the device has bf16 support. I think older Macs with Intel or AMD GPUs don't support bf16.

Author

I am checking right now.

@holg (author) commented Mar 2, 2026

The review feedback about unconditional bf16 registration is addressed in follow-up PR #1207:

  • bf16 type registration is now gated behind #[cfg(apple_silicon)] (target_os = "macos" + target_arch = "aarch64"), so Intel/AMD Macs won't advertise bf16 support.
  • Also fixes bf16 zero-literal initialization (bfloat(0.0) instead of bare 0.0) and adds MSL hypot/rhypot overrides.

bf16 unary tests go from 20/22 → 22/22 with those fixes applied on top.
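The gating described in the follow-up can be sketched as a cfg-conditional type list. The cfg predicate (target_os = "macos" plus target_arch = "aarch64") is stated in the comment above; `supported_float_kinds` and the string kind names are hypothetical stand-ins for the actual registration code in the Metal backend.

```rust
/// Return the float kinds advertised by a hypothetical Metal backend.
/// BF16 is only included on Apple Silicon, since Intel/AMD Mac GPUs
/// lack MSL bfloat support.
fn supported_float_kinds() -> Vec<&'static str> {
    let mut kinds = vec!["F16", "F32"];
    // Equivalent to the #[cfg(apple_silicon)] alias from PR #1207:
    // only advertise bf16 on macOS + aarch64 targets.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    kinds.push("BF16");
    kinds
}

fn main() {
    let kinds = supported_float_kinds();
    // F16 and F32 are always available; BF16 depends on the target.
    assert!(kinds.contains(&"F16"));
    assert!(kinds.contains(&"F32"));
    println!("{kinds:?}");
}
```

Because the check happens at compile time, a build for Intel Macs simply never registers the BF16 variant, so DeviceProperties cannot claim bf16 support there.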



Development

Successfully merging this pull request may close these issues.

Metal MSL codegen: bf16 transcendental functions generate invalid shader code

3 participants