diff --git a/crates/cubecl-cpp/src/metal/dialect.rs b/crates/cubecl-cpp/src/metal/dialect.rs index e4f48f87d..f8415c995 100644 --- a/crates/cubecl-cpp/src/metal/dialect.rs +++ b/crates/cubecl-cpp/src/metal/dialect.rs @@ -832,7 +832,11 @@ impl DialectInstructions for MslDialect { input: T, ) -> std::fmt::Result { match input.elem() { - Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { + Elem::BF16 | Elem::BF16x2 => { + // bfloat has no native log(); cast through float + write!(f, "bfloat(log(float(1.0f) + float({input})))") + } + Elem::F16 | Elem::F16x2 => { write!(f, "log(half(1.0f) + {input})") } _ => write!(f, "log(1.0f + {input})"), @@ -954,6 +958,13 @@ impl DialectInstructions for MslDialect { "" } + /// Metal's `bfloat` type has no native transcendental functions (exp, sin, cos, etc.). + /// Only `half` (f16) and `float` (f32) do. The GPU's Special Function Units are wired + /// for f32 and f16 only, so bf16 must be cast through f32. + fn bf16_has_native_math_functions() -> bool { + false + } + // Warp fn compile_warp_shuffle( f: &mut std::fmt::Formatter<'_>, diff --git a/crates/cubecl-cpp/src/metal/extension.rs b/crates/cubecl-cpp/src/metal/extension.rs index 7364164a8..b7cf2926f 100644 --- a/crates/cubecl-cpp/src/metal/extension.rs +++ b/crates/cubecl-cpp/src/metal/extension.rs @@ -150,15 +150,27 @@ pub fn format_safe_tanh( item: &Item, ) -> core::fmt::Result { let elem = item.elem(); + // bfloat has no native tanh(); cast through float + let is_bf16 = matches!(elem, Elem::BF16); + let (clamp_ret, tanh_expr) = if is_bf16 { + // bfloat has no native tanh(); cast through float. + // Literal 1.0 is float — must cast to bfloat for return type. + ( + format!("return {elem}(1.0);"), + format!("return {elem}(tanh(float(x)));"), + ) + } else { + ("return 1.0;".to_string(), "return tanh(x);".to_string()) + }; write!( f, " /// Metal has a weird numerical behaviour with tanh for inputs over 43.0 inline {elem} safe_tanh_scalar({elem} x) {{ if (x > 43.0) {{ - return 1.0; + {clamp_ret} }} else {{ - return tanh(x); + {tanh_expr} }} }} " diff --git a/crates/cubecl-cpp/src/shared/dialect.rs b/crates/cubecl-cpp/src/shared/dialect.rs index 61dfa099d..e621cec8f 100644 --- a/crates/cubecl-cpp/src/shared/dialect.rs +++ b/crates/cubecl-cpp/src/shared/dialect.rs @@ -693,6 +693,15 @@ pub trait DialectInstructions { "h2" } + /// Whether bf16 has native math functions (transcendentals like exp, sin, cos, etc.) + /// on this backend. When false, bf16 operands are cast through f32 for these operations. + /// + /// CUDA/HIP: true (bf16 shares f16 intrinsics via the `h` prefix). + /// Metal: false (MSL `bfloat` has no native transcendental support; only `half`/`float` do). + fn bf16_has_native_math_functions() -> bool { + true + } + // warp fn compile_warp_shuffle( f: &mut std::fmt::Formatter<'_>, diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 538ce8bb8..e288bfd25 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -904,10 +904,13 @@ impl Remainder { rhs: &Variable, out: &Variable, ) -> core::fmt::Result { + let bf16_native = D::bf16_has_native_math_functions(); let floor = |elem| { let prefix = match elem { - Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(), - Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(), + Elem::F16 => D::compile_instruction_half_function_name_prefix(), + Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(), + Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(), + Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(), _ => "", }; format!("{prefix}floor") @@ -917,9 +920,11 @@ impl Remainder { out.elem(), Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64 ); + // bf16 without native math needs floor() via f32 cast, same as integers + let bf16_needs_cast = matches!(out.elem(), Elem::BF16 | Elem::BF16x2) && !bf16_native; let out_elem = out.elem(); let rem_expr = |lhs, rhs, floor: &str| { - if is_int { + if is_int || bf16_needs_cast { format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})") } else { format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})") diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index 55b02f416..6257203de 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -85,10 +85,13 @@ pub trait Unary { pub trait FunctionFmt { fn base_function_name() -> &'static str; fn function_name(elem: Elem) -> String { + let bf16_native = D::bf16_has_native_math_functions(); if Self::half_support() { let prefix = match elem { - Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(), - Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(), + Elem::F16 => D::compile_instruction_half_function_name_prefix(), + Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(), + Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(), + Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(), _ => "", }; format!("{prefix}{}", Self::base_function_name()) @@ -101,15 +104,20 @@ pub trait FunctionFmt { input: Input, elem: Elem, ) -> std::fmt::Result { - if Self::half_support() { - write!(f, "{}({input})", Self::function_name(elem)) - } else { + let bf16_native = D::bf16_has_native_math_functions(); + // bf16 without native math support must cast through f32, same as the + // half_support=false path. + let needs_f32_cast = + !Self::half_support() || (matches!(elem, Elem::BF16 | Elem::BF16x2) && !bf16_native); + if needs_f32_cast { match elem { Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { write!(f, "{}({}(float({input})))", elem, Self::function_name(elem)) } _ => write!(f, "{}({input})", Self::function_name(elem)), } + } else { + write!(f, "{}({input})", Self::function_name(elem)) } } @@ -383,10 +391,13 @@ impl Unary for Assign { } fn elem_function_name(base_name: &'static str, elem: Elem) -> String { + let bf16_native = D::bf16_has_native_math_functions(); // Math functions prefix (no leading underscores) let prefix = match elem { - Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(), - Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(), + Elem::F16 => D::compile_instruction_half_function_name_prefix(), + Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(), + Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(), + Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(), _ => "", }; if prefix.is_empty() { diff --git a/crates/cubecl-wgpu/src/backend/metal.rs b/crates/cubecl-wgpu/src/backend/metal.rs index cacb6a013..3f8fb3abf 100644 --- a/crates/cubecl-wgpu/src/backend/metal.rs +++ b/crates/cubecl-wgpu/src/backend/metal.rs @@ -105,6 +105,7 @@ fn register_types(props: &mut DeviceProperties) { ElemType::Int(IntKind::I16), ElemType::Int(IntKind::I32), ElemType::Int(IntKind::I64), + ElemType::Float(FloatKind::BF16), ElemType::Float(FloatKind::F16), ElemType::Float(FloatKind::F32), ElemType::Bool, diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 1887e2f13..9e556d5f6 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -58,9 +58,9 @@ mod tests_spirv { #[allow(unexpected_cfgs)] mod tests_msl { pub type TestRuntime = crate::WgpuRuntime; - use half::f16; + use half::{bf16, f16}; - cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]); + cubecl_core::testgen_all!(f32: [bf16, f16, f32], i32: [i16, i32], u32: [u16, u32]); cubecl_std::testgen!(); cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]); cubecl_std::testgen_quantized_view!(f16);