diff --git a/crates/cubecl-core/src/codegen/compiler.rs b/crates/cubecl-core/src/codegen/compiler.rs index 01a036338..72d92fccc 100644 --- a/crates/cubecl-core/src/codegen/compiler.rs +++ b/crates/cubecl-core/src/codegen/compiler.rs @@ -4,4 +4,5 @@ pub struct WgpuCompilationOptions { pub supports_fp_fast_math: bool, pub supports_u64: bool, pub supports_explicit_smem: bool, + pub supports_long_vectors: bool, } diff --git a/crates/cubecl-spirv/src/item.rs b/crates/cubecl-spirv/src/item.rs index 4002ae2e2..ff42211c7 100644 --- a/crates/cubecl-spirv/src/item.rs +++ b/crates/cubecl-spirv/src/item.rs @@ -27,7 +27,12 @@ impl Item { Item::Scalar(elem) => elem.id(b), Item::Vector(elem, vec) => { let elem = elem.id(b); - b.type_vector(elem, *vec) + if b.compilation_options.supports_long_vectors { + let len = b.const_u32(*vec); + b.type_vector_id_ext(elem, len) + } else { + b.type_vector(elem, *vec) + } } Item::Array(item, len) => { let item = item.id(b); diff --git a/crates/cubecl-spirv/src/target.rs b/crates/cubecl-spirv/src/target.rs index 625f3ca76..c3ad7f6e9 100644 --- a/crates/cubecl-spirv/src/target.rs +++ b/crates/cubecl-spirv/src/target.rs @@ -72,6 +72,11 @@ impl SpirvTarget for GLCompute { b.extension("SPV_KHR_workgroup_memory_explicit_layout"); } + if b.compilation_options.supports_long_vectors { + b.extension("SPV_EXT_long_vector"); + b.capability(Capability::LongVectorEXT); + } + if b.addr_type.size_bits() == 64 { b.extension("SPV_EXT_shader_64bit_indexing"); b.capability(Capability::Shader64BitIndexingEXT); diff --git a/crates/cubecl-wgpu/src/backend/vulkan.rs b/crates/cubecl-wgpu/src/backend/vulkan.rs index 7e1c42ce2..07b915ee9 100644 --- a/crates/cubecl-wgpu/src/backend/vulkan.rs +++ b/crates/cubecl-wgpu/src/backend/vulkan.rs @@ -4,7 +4,7 @@ use cubecl_core::{ prelude::{CompiledKernel, Visibility}, server::{Bindings, ComputeServer}, }; -use cubecl_ir::{DeviceProperties, features::*}; +use cubecl_ir::{DeviceProperties, LineSize, features::*}; use cubecl_runtime::compiler::CompilationError; use cubecl_spirv::{GLCompute, SpirvCompiler, SpirvKernel}; use features::ExtendedFeatures; @@ -174,6 +174,13 @@ fn register_features( comp_options.supports_explicit_smem = true; } + if let Some(long_vector) = &extended_feat.long_vector + && long_vector.long_vector == TRUE + { + comp_options.supports_long_vectors = true; + props.hardware.max_line_size = LineSize::MAX; + } + if extended_feat.cmma.is_some() { register_cmma(ash, adapter, props); } diff --git a/crates/cubecl-wgpu/src/backend/vulkan/features.rs b/crates/cubecl-wgpu/src/backend/vulkan/features.rs index 5aab3b414..eb55fbe3d 100644 --- a/crates/cubecl-wgpu/src/backend/vulkan/features.rs +++ b/crates/cubecl-wgpu/src/backend/vulkan/features.rs @@ -21,6 +21,7 @@ pub struct ExtendedFeatures<'a> { pub wg_explicit_layout: Option>, pub index_64: Option>, + pub long_vector: Option>, pub extensions: Vec<&'static CStr>, } @@ -82,6 +83,11 @@ impl<'a> ExtendedFeatures<'a> { self.extensions.push(EXT_SHADER_64BIT_INDEXING_NAME); self.index_64 = Some(PhysicalDeviceShader64BitIndexingFeaturesEXT::default()); } + + if phys_caps.supports_extension(EXT_SHADER_LONG_VECTOR_NAME) { + self.extensions.push(EXT_SHADER_LONG_VECTOR_NAME); + self.long_vector = Some(PhysicalDeviceShaderLongVectorFeaturesEXT::default()); + } } pub fn add_to_device_create(&'a mut self, info: DeviceCreateInfo<'a>) -> DeviceCreateInfo<'a> { @@ -110,6 +116,7 @@ impl<'a> ExtendedFeatures<'a> { info = push_opt(info, &mut self.float8); info = push_opt(info, &mut self.wg_explicit_layout); info = push_opt(info, &mut self.index_64); + info = push_opt(info, &mut self.long_vector); info } @@ -140,6 +147,7 @@ impl<'a> ExtendedFeatures<'a> { features = push_opt(features, &mut self.float8); features = push_opt(features, &mut self.wg_explicit_layout); features = push_opt(features, &mut self.index_64); + features = push_opt(features, &mut self.long_vector); unsafe { // convert to ash version, they represent the same type so this is safe @@ -183,5 +191,8 @@ impl<'a> ExtendedFeatures<'a> { if let Some(index_64) = &mut self.index_64 { index_64.p_next = null_mut(); } + if let Some(long_vector) = &mut self.long_vector { + long_vector.p_next = null_mut(); + } } }