Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion crates/cubecl-cpp/src/metal/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,11 @@ impl DialectInstructions<Self> for MslDialect {
input: T,
) -> std::fmt::Result {
match input.elem() {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
Elem::BF16 | Elem::BF16x2 => {
// bfloat has no native log(); cast through float
write!(f, "bfloat(log(float(1.0f) + float({input})))")
}
Elem::F16 | Elem::F16x2 => {
write!(f, "log(half(1.0f) + {input})")
}
_ => write!(f, "log(1.0f + {input})"),
Expand Down Expand Up @@ -954,6 +958,33 @@ impl DialectInstructions<Self> for MslDialect {
""
}

/// Emits a hypotenuse computation for MSL.
///
/// Metal Shading Language exposes a single overloaded `hypot()` for all
/// floating types, so no per-type suffix (such as CUDA's `hypotf()`) is
/// required and the element type can be ignored.
fn compile_instruction_hypot(
    f: &mut std::fmt::Formatter<'_>,
    lhs: &str,
    rhs: &str,
    _elem: Elem<Self>,
) -> std::fmt::Result {
    write!(f, "hypot({}, {})", lhs, rhs)
}

/// MSL has no `rhypot` intrinsic; emit `1.0 / hypot(...)` instead.
fn compile_instruction_rhypot(
f: &mut std::fmt::Formatter<'_>,
lhs: &str,
rhs: &str,
_elem: Elem<Self>,
) -> std::fmt::Result {
write!(f, "1.0 / hypot({lhs}, {rhs})")
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compile_instruction_rhypot always emits 1.0 / hypot(...). For f32 call sites this promotes the expression to double (and may rely on double support / implicit narrowing back to float), which is risky in MSL and can cause compilation failures. Use an f32 literal (1.0f) or otherwise respect the _elem parameter to choose the correct literal/cast for F32 vs F64.

Suggested change
_elem: Elem<Self>,
) -> std::fmt::Result {
write!(f, "1.0 / hypot({lhs}, {rhs})")
elem: Elem<Self>,
) -> std::fmt::Result {
let one_literal = match elem {
Elem::F32 => "1.0f",
_ => "1.0",
};
write!(f, "{one_literal} / hypot({lhs}, {rhs})")

Copilot uses AI. Check for mistakes.
}

/// Metal's `bfloat` type has no native transcendental functions (exp, sin, cos, etc.).
/// Only `half` (f16) and `float` (f32) do. The GPU's Special Function Units are wired
/// for f32 and f16 only, so bf16 must be cast through f32.
///
/// Returning `false` here makes the shared codegen paths wrap bf16 math
/// calls in f32 casts instead of using half-prefixed intrinsics.
fn bf16_has_native_math_functions() -> bool {
    false
}

// Warp
fn compile_warp_shuffle(
f: &mut std::fmt::Formatter<'_>,
Expand Down
16 changes: 14 additions & 2 deletions crates/cubecl-cpp/src/metal/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,27 @@ pub fn format_safe_tanh<D: Dialect>(
item: &Item<D>,
) -> core::fmt::Result {
let elem = item.elem();
// bfloat has no native tanh(); cast through float
let is_bf16 = matches!(elem, Elem::BF16);
let (clamp_ret, tanh_expr) = if is_bf16 {
// bfloat has no native tanh(); cast through float.
// Literal 1.0 is float — must cast to bfloat for return type.
(
format!("return {elem}(1.0);"),
format!("return {elem}(tanh(float(x)));"),
)
} else {
("return 1.0;".to_string(), "return tanh(x);".to_string())
};
write!(
f,
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
inline {elem} safe_tanh_scalar({elem} x) {{
if (x > 43.0) {{
return 1.0;
{clamp_ret}
}} else {{
return tanh(x);
{tanh_expr}
}}
}}
"
Expand Down
9 changes: 9 additions & 0 deletions crates/cubecl-cpp/src/shared/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,15 @@ pub trait DialectInstructions<D: Dialect> {
"h2"
}

/// Whether bf16 has native math functions (transcendentals like exp, sin, cos, etc.)
/// on this backend. When false, bf16 operands are cast through f32 for these operations.
///
/// CUDA/HIP: true (bf16 shares f16 intrinsics via the `h` prefix).
/// Metal: false (MSL `bfloat` has no native transcendental support; only `half`/`float` do).
///
/// Defaults to `true`; dialects without native bf16 math override this.
fn bf16_has_native_math_functions() -> bool {
    true
}

// warp
fn compile_warp_shuffle(
f: &mut std::fmt::Formatter<'_>,
Expand Down
16 changes: 11 additions & 5 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,10 +904,13 @@ impl<D: Dialect> Remainder<D> {
rhs: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
let bf16_native = D::bf16_has_native_math_functions();
let floor = |elem| {
let prefix = match elem {
Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::F16 => D::compile_instruction_half_function_name_prefix(),
Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(),
_ => "",
};
format!("{prefix}floor")
Expand All @@ -917,9 +920,11 @@ impl<D: Dialect> Remainder<D> {
out.elem(),
Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
);
// bf16 without native math needs floor() via f32 cast, same as integers
let bf16_needs_cast = matches!(out.elem(), Elem::BF16 | Elem::BF16x2) && !bf16_native;
let out_elem = out.elem();
let rem_expr = |lhs, rhs, floor: &str| {
if is_int {
if is_int || bf16_needs_cast {
format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
} else {
format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
Expand Down Expand Up @@ -1000,7 +1005,8 @@ impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {

let mag = format!("{out}_mag");

writeln!(f, "{} {mag} = 0.0;", out.item())?;
let item = out.item();
writeln!(f, "{item} {mag} = {item}(0.0);")?;

for i in 0..num {
let input_i = input.index(i);
Expand Down Expand Up @@ -1031,7 +1037,7 @@ impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {

let out_item = out.item();
let out = out.fmt_left();
writeln!(f, "{elem} {norm} = 0.0;")?;
writeln!(f, "{elem} {norm} = {elem}(0.0);")?;

for i in 0..num {
let input_i = input.index(i);
Expand Down
25 changes: 18 additions & 7 deletions crates/cubecl-cpp/src/shared/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ pub trait Unary<D: Dialect> {
pub trait FunctionFmt<D: Dialect> {
fn base_function_name() -> &'static str;
fn function_name(elem: Elem<D>) -> String {
let bf16_native = D::bf16_has_native_math_functions();
if Self::half_support() {
let prefix = match elem {
Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::F16 => D::compile_instruction_half_function_name_prefix(),
Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(),
_ => "",
};
format!("{prefix}{}", Self::base_function_name())
Expand All @@ -101,15 +104,20 @@ pub trait FunctionFmt<D: Dialect> {
input: Input,
elem: Elem<D>,
) -> std::fmt::Result {
if Self::half_support() {
write!(f, "{}({input})", Self::function_name(elem))
} else {
let bf16_native = D::bf16_has_native_math_functions();
// bf16 without native math support must cast through f32, same as the
// half_support=false path.
let needs_f32_cast =
!Self::half_support() || (matches!(elem, Elem::BF16 | Elem::BF16x2) && !bf16_native);
if needs_f32_cast {
match elem {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
}
_ => write!(f, "{}({input})", Self::function_name(elem)),
}
} else {
write!(f, "{}({input})", Self::function_name(elem))
}
}

Expand Down Expand Up @@ -383,10 +391,13 @@ impl<D: Dialect> Unary<D> for Assign {
}

fn elem_function_name<D: Dialect>(base_name: &'static str, elem: Elem<D>) -> String {
let bf16_native = D::bf16_has_native_math_functions();
// Math functions prefix (no leading underscores)
let prefix = match elem {
Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::F16 => D::compile_instruction_half_function_name_prefix(),
Elem::BF16 if bf16_native => D::compile_instruction_half_function_name_prefix(),
Elem::F16x2 => D::compile_instruction_half2_function_name_prefix(),
Elem::BF16x2 if bf16_native => D::compile_instruction_half2_function_name_prefix(),
_ => "",
};
if prefix.is_empty() {
Expand Down
8 changes: 8 additions & 0 deletions crates/cubecl-wgpu/src/backend/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ fn register_types(props: &mut DeviceProperties) {
register(ty.into(), TypeUsage::all_scalar());
}

// bf16 (bfloat) requires Apple Silicon (Apple7+ GPU family, i.e. M1 and later).
// Intel Macs (x86_64) do not support the bfloat type in MSL.
#[cfg(apple_silicon)]
register(
ElemType::Float(FloatKind::BF16).into(),
TypeUsage::all_scalar(),
);

for ty in atomic_types {
register(
StorageType::Atomic(ty),
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ mod tests_spirv {
#[allow(unexpected_cfgs)]
mod tests_msl {
pub type TestRuntime = crate::WgpuRuntime;
use half::f16;
use half::{bf16, f16};

cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]);
cubecl_core::testgen_all!(f32: [bf16, f16, f32], i32: [i16, i32], u32: [u16, u32]);
cubecl_std::testgen!();
Comment on lines 60 to 72
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testgen_all! now includes bf16 for the MSL test suite unconditionally. On x86_64 macOS this will still build, but register_types() no longer registers FloatKind::BF16, and the unary tests don’t appear to skip unsupported float types, so bf16 MSL tests are likely to fail at runtime. Consider gating the bf16 type list behind #[cfg(apple_silicon)] (or otherwise skipping bf16 tests when the device properties don’t advertise BF16 support).

Copilot uses AI. Check for mistakes.
cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]);
cubecl_std::testgen_quantized_view!(f16);
Expand Down
Loading