Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ jobs:
# work around for colors
# see: https://github.com/rust-lang/rustfmt/issues/3385
TERM: xterm-256color
run: cargo xtask check format
run: cargo xtask check --ci format
# --------------------------------------------------------------------------------
- name: Lint
run: cargo xtask check lint
run: cargo xtask check --ci lint
# --------------------------------------------------------------------------------
- name: Typos
uses: tracel-ai/github-actions/check-typos@v8
Expand All @@ -97,10 +97,10 @@ jobs:
cache-key: ${{ matrix.rust }}-linux
# --------------------------------------------------------------------------------
- name: Documentation Build
run: cargo xtask doc build
run: cargo xtask doc --ci build
# --------------------------------------------------------------------------------
- name: Documentation Tests
run: cargo xtask doc tests
run: cargo xtask doc --ci tests

linux-std-tests:
runs-on: [
Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-cpp/src/metal/address_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ impl<D: Dialect> From<&Variable<D>> for AddressSpace {
}
}
Variable::SharedArray(..) => AddressSpace::ThreadGroup,
Variable::Tmp { is_ptr: true, .. } => AddressSpace::Device,
Variable::LocalMut { item, .. } | Variable::LocalConst { item, .. }
if matches!(item.elem, crate::shared::Elem::Atomic(_)) =>
{
AddressSpace::Device
}
_ => AddressSpace::Thread,
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cpp/src/metal/arch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl MetalArchitecture {

impl Architecture for MetalArchitecture {
fn warp_size(&self) -> u32 {
64
32
}

fn is_wmma_capable(&self) -> bool {
Expand Down
50 changes: 45 additions & 5 deletions crates/cubecl-cpp/src/metal/dialect.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
AddressSpace, Extension,
arch::MetalArchitecture,
extension::{format_ffs, format_mulhi},
extension::{format_ffs, format_hypot, format_mulhi, format_rhypot},
format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
};
use crate::{
Expand Down Expand Up @@ -157,6 +157,8 @@ using namespace metal;
Extension::Ffs(elem) => format_ffs(f, elem)?,
Extension::MulHi(elem) => format_mulhi(f, elem)?,
Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
Extension::Hypot(elem) => format_hypot::<Self>(f, elem)?,
Extension::Rhypot(elem) => format_rhypot::<Self>(f, elem)?,
Extension::NoExtension => {}
}
}
Expand Down Expand Up @@ -205,6 +207,22 @@ using namespace metal;
shared::Instruction::<Self>::Tanh(instruction) => {
register_extension(Extension::SafeTanh(instruction.input.item()));
}
shared::Instruction::<Self>::Hypot(instruction) => {
// For half types, the Binary impl casts to float, so we need float hypot
let elem = match instruction.out.elem() {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => Elem::F32,
other => other,
};
register_extension(Extension::Hypot(elem));
}
shared::Instruction::<Self>::Rhypot(instruction) => {
// For half types, the Binary impl casts to float, so we need float rhypot
let elem = match instruction.out.elem() {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => Elem::F32,
other => other,
};
register_extension(Extension::Rhypot(elem));
}
_ => {}
}
}
Expand Down Expand Up @@ -390,7 +408,7 @@ void {kernel_name}("
if flags.static_meta_length > 0 {
let binding = Binding {
id: 0,
item: Item::scalar(Elem::<Self>::U32, true),
item: flags.address_type,
location: Location::Storage,
size: None,
vis: Visibility::Read,
Expand Down Expand Up @@ -686,11 +704,15 @@ impl DialectInstructions<Self> for MslDialect {
val: &Variable<Self>,
out: &Variable<Self>,
) -> std::fmt::Result {
let out = out.fmt_left();
let expected_name = format!("{out}_expected");
let out_item = out.item();
writeln!(f, "{out_item} {expected_name} = {cmp};")?;
writeln!(
f,
"{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
)
"atomic_compare_exchange_weak_explicit({input}, &{expected_name}, {val}, memory_order_relaxed, memory_order_relaxed);"
)?;
let out = out.fmt_left();
writeln!(f, "{out} = {expected_name};")
}

fn compile_atomic_load(
Expand Down Expand Up @@ -938,6 +960,24 @@ impl DialectInstructions<Self> for MslDialect {
write!(f, "pow({lhs}, {elem}({rhs}))")
}

fn compile_instruction_hypot(
f: &mut std::fmt::Formatter<'_>,
lhs: &str,
rhs: &str,
_elem: Elem<Self>,
) -> std::fmt::Result {
write!(f, "hypot({lhs}, {rhs})")
}

fn compile_instruction_rhypot(
f: &mut std::fmt::Formatter<'_>,
lhs: &str,
rhs: &str,
_elem: Elem<Self>,
) -> std::fmt::Result {
write!(f, "rhypot({lhs}, {rhs})")
}

fn compile_instruction_half_function_name_prefix() -> &'static str {
""
}
Expand Down
45 changes: 41 additions & 4 deletions crates/cubecl-cpp/src/metal/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use crate::{
};

#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Default, PartialEq)]
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum Extension<D: Dialect> {
Erf(Elem<D>, Elem<D>),
Ffs(Elem<D>),
MulHi(Elem<D>),
SafeTanh(Item<D>),
Hypot(Elem<D>),
Rhypot(Elem<D>),
#[default]
NoExtension,
}
Expand Down Expand Up @@ -150,15 +152,16 @@ pub fn format_safe_tanh<D: Dialect>(
item: &Item<D>,
) -> core::fmt::Result {
let elem = item.elem();
// Note: For bfloat, tanh() returns float, so we need explicit casts
write!(
f,
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
inline {elem} safe_tanh_scalar({elem} x) {{
if (x > 43.0) {{
return 1.0;
if (x > {elem}(43.0)) {{
return {elem}(1.0);
}} else {{
return tanh(x);
return {elem}(tanh(float(x)));
}}
}}
"
Expand All @@ -181,3 +184,37 @@ inline {elem} safe_tanh_scalar({elem} x) {{
}
writeln!(f, "}}")
}

pub fn format_hypot<D: Dialect>(
f: &mut core::fmt::Formatter<'_>,
elem: &Elem<D>,
) -> core::fmt::Result {
// Note: For half/bfloat types, the Binary impl already casts to float,
// so this function is only called with float or double
write!(
f,
"
// MSL doesn't have hypot built-in, implement it as sqrt(x*x + y*y)
inline {elem} hypot({elem} x, {elem} y) {{
return sqrt(x * x + y * y);
}}
"
)
}

pub fn format_rhypot<D: Dialect>(
f: &mut core::fmt::Formatter<'_>,
elem: &Elem<D>,
) -> core::fmt::Result {
// Note: For half/bfloat types, the Binary impl already casts to float,
// so this function is only called with float or double
write!(
f,
"
// MSL doesn't have rhypot built-in, implement it as rsqrt(x*x + y*y)
inline {elem} rhypot({elem} x, {elem} y) {{
return rsqrt(x * x + y * y);
}}
"
)
}
8 changes: 6 additions & 2 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1305,10 +1305,14 @@ impl<D: Dialect> CppCompiler<D> {
instructions.push(Instruction::Powi(self.compile_binary(op, out)))
}
gpu::Arithmetic::Hypot(op) => {
instructions.push(Instruction::Hypot(self.compile_binary(op, out)))
let instruction = Instruction::Hypot(self.compile_binary(op, out));
D::register_instruction_extension(&mut self.extensions, &instruction);
instructions.push(instruction)
}
gpu::Arithmetic::Rhypot(op) => {
instructions.push(Instruction::Rhypot(self.compile_binary(op, out)))
let instruction = Instruction::Rhypot(self.compile_binary(op, out));
D::register_instruction_extension(&mut self.extensions, &instruction);
instructions.push(instruction)
}
gpu::Arithmetic::Sqrt(op) => {
let op = self.compile_unary(op, out);
Expand Down
6 changes: 4 additions & 2 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,8 @@ impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {

let mag = format!("{out}_mag");

writeln!(f, "{} {mag} = 0.0;", out.item())?;
// Use elem cast for the literal to support bfloat
writeln!(f, "{} {mag} = {}(0.0);", out.item(), out.item())?;

for i in 0..num {
let input_i = input.index(i);
Expand Down Expand Up @@ -1029,7 +1030,8 @@ impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {

let out_item = out.item();
let out = out.fmt_left();
writeln!(f, "{elem} {norm} = 0.0;")?;
// Use elem cast for the literal to support bfloat
writeln!(f, "{elem} {norm} = {elem}(0.0);")?;

for i in 0..num {
let input_i = input.index(i);
Expand Down
9 changes: 8 additions & 1 deletion crates/cubecl-cpp/src/shared/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,14 @@ pub trait FunctionFmt<D: Dialect> {
elem: Elem<D>,
) -> std::fmt::Result {
if Self::half_support() {
write!(f, "{}({input})", Self::function_name(elem))
// Note: Metal's math functions support half but NOT bfloat directly.
// For bfloat, functions like sin/cos/sqrt return float, so we need to cast back.
match elem {
Elem::BF16 | Elem::BF16x2 => {
write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
}
_ => write!(f, "{}({input})", Self::function_name(elem)),
}
} else {
match elem {
Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
Expand Down
24 changes: 20 additions & 4 deletions crates/cubecl-cpp/src/shared/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,24 @@ impl<D: Dialect> FmtLeft for Variable<D> {
match self {
Self::LocalConst { item, .. } => match item.elem {
Elem::Atomic(_) => {
format!("{item}* {self}")
// Atomic pointers need address space (for Metal)
let addr_space = D::address_space_for_variable(self);
format!("{addr_space}{item}* {self}")
}
_ => {
format!("const {item} {self}")
}
},
Self::LocalMut { item, .. } => match item.elem {
Elem::Atomic(_) => {
// Atomic pointers need address space (for Metal)
let addr_space = D::address_space_for_variable(self);
format!("{addr_space}{item}* {self}")
}
_ => {
format!("{self}")
}
},
Variable::Tmp {
item,
is_declared,
Expand All @@ -584,10 +596,12 @@ impl<D: Dialect> FmtLeft for Variable<D> {
return format!("{self}");
}
if *is_ptr {
// Pointer types need address space (for Metal)
let addr_space = D::address_space_for_variable(self);
if *is_const {
return format!("const {item} *{self}");
return format!("const {addr_space}{item}* {self}");
}
return format!("{item} *{self}");
return format!("{addr_space}{item}* {self}");
}

format!("{item} {self}")
Expand Down Expand Up @@ -677,7 +691,9 @@ impl<D: Dialect> FmtLeft for IndexedVariable<D> {
Variable::LocalConst { item, .. } => format!("const {item} {name}"),
Variable::Tmp { item, is_ptr, .. } => {
if *is_ptr {
format!("{item} *{name}")
// For pointer types, include the address space (required for Metal)
let addr_space = D::address_space_for_variable(var);
format!("{addr_space}{item}* {name}")
} else {
format!("{item} {name}")
}
Expand Down
50 changes: 50 additions & 0 deletions crates/cubecl-metal/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
[package]
authors = ["dcvz <david@dcvz.io>"]
categories = ["science"]
description = "Metal runtime for CubeCL"
edition.workspace = true
keywords = ["gpu", "metal"]
license.workspace = true
name = "cubecl-metal"
readme.workspace = true
repository = "https://github.com/tracel-ai/cubecl/tree/main/crates/cubecl-metal"
version.workspace = true

[features]
default = ["std"]
std = [
"cubecl-runtime/std",
"cubecl-common/std",
"cubecl-core/std",
]

[dependencies]
cubecl-common = { path = "../cubecl-common", version = "=0.9.0", default-features = false, features = ["cache"] }
cubecl-core = { path = "../cubecl-core", version = "=0.9.0", default-features = false }
cubecl-ir = { path = "../cubecl-ir", version = "=0.9.0", default-features = false }
cubecl-runtime = { path = "../cubecl-runtime", version = "=0.9.0", default-features = false }
cubecl-cpp = { path = "../cubecl-cpp", version = "=0.9.0", features = ["metal"] }

# Metal bindings
objc2 = "0.6"
objc2-metal = { version = "0.3", features = ["block2", "MTLLibrary"] }
objc2-foundation = "0.3"
block2 = "0.6"

# Utilities
log = { workspace = true }
derive-new = { workspace = true }
hashbrown = { workspace = true }
bytemuck = { workspace = true }
half = { workspace = true }
serde = { workspace = true }
async-channel = { workspace = true }

[dev-dependencies]
cubecl-core = { path = "../cubecl-core", version = "=0.9.0", features = ["export_tests"] }
cubecl-std = { path = "../cubecl-std", version = "=0.9.0", features = ["export_tests"] }
test-log = { workspace = true, features = ["trace"] }
paste = { workspace = true }

[lints]
workspace = true
Loading
Loading