Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion crates/cubecl-cpp/src/metal/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,11 @@ impl DialectInstructions<Self> for MslDialect {
input: T,
) -> std::fmt::Result {
match input.elem() {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
Elem::BF16 | Elem::BF16x2 => {
// bfloat has no native log(); cast through float
write!(f, "bfloat(log(float(1.0f) + float({input})))")
}
Elem::F16 | Elem::F16x2 => {
write!(f, "log(half(1.0f) + {input})")
}
_ => write!(f, "log(1.0f + {input})"),
Expand Down Expand Up @@ -954,6 +958,13 @@ impl DialectInstructions<Self> for MslDialect {
""
}

/// Reports that the Metal dialect has no native math functions for bf16.
///
/// MSL's `bfloat` type lacks built-in transcendentals (exp, sin, cos, ...); those
/// exist only for `half` (f16) and `float` (f32). Returning `false` tells the shared
/// code generation to cast bf16 operands through `float` for these operations.
fn bf16_has_native_math_functions() -> bool {
    false
}

// Warp
fn compile_warp_shuffle(
f: &mut std::fmt::Formatter<'_>,
Expand Down
16 changes: 14 additions & 2 deletions crates/cubecl-cpp/src/metal/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,27 @@ pub fn format_safe_tanh<D: Dialect>(
item: &Item<D>,
) -> core::fmt::Result {
let elem = item.elem();
// bfloat has no native tanh(); cast through float
let is_bf16 = matches!(elem, Elem::BF16);
let (clamp_ret, tanh_expr) = if is_bf16 {
// bfloat has no native tanh(); cast through float.
// Literal 1.0 is float — must cast to bfloat for return type.
(
format!("return {elem}(1.0);"),
format!("return {elem}(tanh(float(x)));"),
)
} else {
("return 1.0;".to_string(), "return tanh(x);".to_string())
};
write!(
f,
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
inline {elem} safe_tanh_scalar({elem} x) {{
if (x > 43.0) {{
return 1.0;
{clamp_ret}
}} else {{
return tanh(x);
{tanh_expr}
}}
}}
"
Expand Down
9 changes: 9 additions & 0 deletions crates/cubecl-cpp/src/shared/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,15 @@ pub trait DialectInstructions<D: Dialect> {
"h2"
}

/// Indicates whether this backend provides native bf16 math functions
/// (transcendentals such as exp, sin, cos, ...).
///
/// The default is `true`. A dialect that returns `false` has its bf16 operands
/// cast through f32 for these operations instead.
///
/// - CUDA/HIP: `true` (bf16 shares the f16 intrinsics via the `h` prefix).
/// - Metal: `false` (MSL `bfloat` has no native transcendental support; only
///   `half`/`float` do).
fn bf16_has_native_math_functions() -> bool {
    true
}

// warp
fn compile_warp_shuffle(
f: &mut std::fmt::Formatter<'_>,
Expand Down
11 changes: 8 additions & 3 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,10 +904,13 @@ impl<D: Dialect> Remainder<D> {
rhs: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
let bf16_native = D::bf16_has_native_math_functions();
let floor = |elem| {
let prefix = match elem {
Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::F16 => D::compile_instruction_half_function_name_prefix(),
Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(),
_ => "",
};
format!("{prefix}floor")
Expand All @@ -917,9 +920,11 @@ impl<D: Dialect> Remainder<D> {
out.elem(),
Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
);
// bf16 without native math needs floor() via f32 cast, same as integers
let bf16_needs_cast = matches!(out.elem(), Elem::BF16 | Elem::BF16x2) && !bf16_native;
let out_elem = out.elem();
let rem_expr = |lhs, rhs, floor: &str| {
if is_int {
if is_int || bf16_needs_cast {
format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
} else {
format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
Expand Down
25 changes: 18 additions & 7 deletions crates/cubecl-cpp/src/shared/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ pub trait Unary<D: Dialect> {
pub trait FunctionFmt<D: Dialect> {
fn base_function_name() -> &'static str;
fn function_name(elem: Elem<D>) -> String {
let bf16_native = D::bf16_has_native_math_functions();
if Self::half_support() {
let prefix = match elem {
Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::F16 => D::compile_instruction_half_function_name_prefix(),
Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(),
_ => "",
};
format!("{prefix}{}", Self::base_function_name())
Expand All @@ -101,15 +104,20 @@ pub trait FunctionFmt<D: Dialect> {
input: Input,
elem: Elem<D>,
) -> std::fmt::Result {
if Self::half_support() {
write!(f, "{}({input})", Self::function_name(elem))
} else {
let bf16_native = D::bf16_has_native_math_functions();
// bf16 without native math support must cast through f32, same as the
// half_support=false path.
let needs_f32_cast =
!Self::half_support() || (matches!(elem, Elem::BF16 | Elem::BF16x2) && !bf16_native);
if needs_f32_cast {
match elem {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
}
_ => write!(f, "{}({input})", Self::function_name(elem)),
}
} else {
write!(f, "{}({input})", Self::function_name(elem))
}
}

Expand Down Expand Up @@ -383,10 +391,13 @@ impl<D: Dialect> Unary<D> for Assign {
}

fn elem_function_name<D: Dialect>(base_name: &'static str, elem: Elem<D>) -> String {
let bf16_native = D::bf16_has_native_math_functions();
// Math functions prefix (no leading underscores)
let prefix = match elem {
Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::F16 => D::compile_instruction_half_function_name_prefix(),
Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(),
_ => "",
};
if prefix.is_empty() {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-wgpu/src/backend/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ fn register_types(props: &mut DeviceProperties) {
ElemType::Int(IntKind::I16),
ElemType::Int(IntKind::I32),
ElemType::Int(IntKind::I64),
ElemType::Float(FloatKind::BF16),
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FloatKind::BF16 is registered unconditionally for Metal. Since MSL bfloat support is not guaranteed on all Metal devices/OS versions, this can make DeviceProperties claim bf16 is available and lead to runtime shader compilation failures on unsupported hardware. Consider gating bf16 registration behind an explicit capability check (e.g., a wgpu/Metal feature or adapter/hal query for bfloat support), and only register BF16 when supported.

Suggested change
ElemType::Float(FloatKind::BF16),

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it would be important to validate whether the device has bf16 support. I think older Macs with Intel or AMD GPUs don't support bf16.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am checking right now.

ElemType::Float(FloatKind::F16),
ElemType::Float(FloatKind::F32),
ElemType::Bool,
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ mod tests_spirv {
#[allow(unexpected_cfgs)]
mod tests_msl {
pub type TestRuntime = crate::WgpuRuntime;
use half::f16;
use half::{bf16, f16};

cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]);
cubecl_core::testgen_all!(f32: [bf16, f16, f32], i32: [i16, i32], u32: [u16, u32]);
cubecl_std::testgen!();
cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]);
cubecl_std::testgen_quantized_view!(f16);
Expand Down
Loading