diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index e73a3b69b5da2..0454dbd7bd03d 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -422,6 +422,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

         set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -434,15 +435,17 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

         set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})

-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)

         if (NOT DOTPROD_ENABLED MATCHES -1)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
         endif()

         if (NOT I8MM_ENABLED MATCHES -1)
@@ -450,9 +453,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()

         if (NOT SME_ENABLED MATCHES -1)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
-            set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
+            set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
         endif()

         set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index aacc2bb5ee0f2..910fd0ee4e743 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -4,16 +4,22 @@
 // KleidiAI micro-kernels
 #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
-#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
-#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
-#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
-#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
+#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+
+#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+
+#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
+#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
+#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+
 #include "kai_common.h"

 #include "kernels.h"
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+    {
+        /* SME GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+        },
+        /* SME GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+        },
+        /* .required_cpu = */ CPU_FEATURE_SME,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_F16,
+        /* .op_type = */ GGML_TYPE_F32,
     },
 #endif
 #if defined(__APPLE__)
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
     },
 #endif
 #if defined(__ARM_FEATURE_MATMUL_INT8)
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
     },
 #endif
 #else
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
     },
 #endif
 #if defined(__ARM_FEATURE_DOTPROD)
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
     },
 #endif
 #endif
 };

-ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
+    ggml_kleidiai_kernels * kernel = nullptr;
+
+    if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
+        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
+            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
+                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
+                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
+                gemm_gemv_kernels[i].op_type == tensor->type) {
+                kernel = &gemm_gemv_kernels[i];
+                break;
+            }
+        }
+    }
+
+    return kernel;
+}
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;

     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
index 2ffe97eb42feb..5ac02bda7c018 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.h
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.h
@@ -4,6 +4,9 @@

 #pragma once

+#include <variant>
+#include "ggml.h"
+
 enum cpu_feature {
     CPU_FEATURE_NONE = 0,
     CPU_FEATURE_DOTPROD = 1,
@@ -26,26 +29,53 @@ struct kernel_info {
     size_t (*get_nr)(void);
     size_t (*get_kr)(void);
     size_t (*get_sr)(void);
-    size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
-    size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
+    std::variant<
+        std::function<size_t(size_t m_idx, size_t k, size_t bl)>,
+        std::function<size_t(size_t m_idx, size_t k)>
+    > get_lhs_offset;
+    std::variant<
+        std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
+        std::function<size_t(size_t n_idx, size_t k)>
+    > get_rhs_packed_offset;
     size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
     size_t (*get_dst_size)(size_t m, size_t n);
-    void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
-                       float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
+    std::variant<
+        std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
+                           float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
+        std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed,
+                           void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)>
+    > run_kernel;
 };

 struct lhs_packing_info {
     size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
-    size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
-    size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
-    void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
-                      size_t lhs_stride, void* lhs_packed);
+    std::variant<
+        std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
+        std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
+    > get_packed_offset;
+    std::variant<
+        std::function<size_t(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
+        std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
+    > packed_size;
+    std::variant<
+        std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
+                           size_t lhs_stride, void* lhs_packed)>,
+        std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs,
+                           size_t lhs_stride, void* lhs_packed)>
+    > pack_func;
 };

 struct rhs_packing_info {
-    size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
-    void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
-                      const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
+    std::variant<
+        std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
+        std::function<size_t(size_t n, size_t k)>
+    > packed_size;
+    std::variant<
+        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
+                           const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
+        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
+                           const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
+    > pack_func;
 };

 struct ggml_kleidiai_kernels {
@@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels {
     rhs_packing_info rhs_info;

     cpu_feature required_cpu;
+    ggml_type lhs_type;
+    ggml_type rhs_type;
+    ggml_type op_type;
 };

-ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 4e89ca0faa2a7..86926277b17d3 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -34,8 +34,9 @@
 #include "ggml-common.h"

 struct ggml_kleidiai_context {
+    cpu_feature features;
     ggml_kleidiai_kernels * kernels;
-} static ctx = { NULL };
+} static ctx = { CPU_FEATURE_NONE, NULL };

 static void init_kleidiai_context(void) {
@@ -47,18 +48,18 @@ static void init_kleidiai_context(void) {

         const char *env_var = getenv("GGML_KLEIDIAI_SME");
         int sme_enabled = 0;
-        cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
-                               (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
-                               (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
+        ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
+                       (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

         if (env_var) {
             sme_enabled = atoi(env_var);
         }

         if (sme_enabled != 0) {
-            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels = ggml_kleidiai_select_kernels(features);
+        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
     }
     ggml_critical_section_end();
 }
@@ -68,95 +69,273 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }

+template<typename Ret, typename Variant, typename... Args>
+static Ret variant_call(const Variant & var, Args&&... args) {
+    return std::visit([&](auto&& func) -> Ret {
+        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
+            return func(std::forward<Args>(args)...);
+        } else {
+            throw std::runtime_error("Invalid function type in variant_call");
+        }
+    }, var);
+}
+
 namespace ggml::cpu::kleidiai {
+
+static size_t round_down(size_t x, size_t y) {
+    return y == 0 ? x : x - (x % y);
+}
+
+static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
+    size_t src_stride = rhs_stride / sizeof(uint16_t);
+    size_t dst_stride = n;
+
+    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
+            uint16_t v = *(src + k_idx + n_idx * src_stride);
+            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
+        }
+    }
+}
+
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
-        GGML_ASSERT(ctx.kernels);
-        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
+        GGML_ASSERT(kernels);
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;

         size_t k = op->src[0]->ne[0];
+        size_t n = op->src[0]->ne[1];
         size_t m = op->src[1]->ne[1];
         size_t mr = kernel->get_mr();
         size_t kr = kernel->get_kr();
         size_t sr = kernel->get_sr();

-        size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
+        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+        } else if (kernels->rhs_type == GGML_TYPE_F16) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
+                   k * n * sizeof(float) + n * sizeof(float);
+        } else {
+            GGML_ASSERT(false);
+        }

         return true;
     }

+
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-            const ggml_tensor * src0 = dst->src[0];
-            const ggml_tensor * src1 = dst->src[1];
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_F16) {
+                return compute_forward_kv_cache(params, dst);
+            }
+        }
+        return false;
+    }

-            GGML_TENSOR_BINARY_OP_LOCALS
+    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];

-            GGML_ASSERT(ctx.kernels);
-            kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
-            lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+        GGML_TENSOR_BINARY_OP_LOCALS

-            GGML_ASSERT(kernel);
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);

-            const int ith = params->ith;
-            const int nth = params->nth;
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        GGML_ASSERT(kernel);

-            const size_t k = ne00;
-            const size_t m = ne11;
-            const size_t n = ne01;
+        const size_t nth = params->nth;
+        const size_t ith = params->ith;

-            const size_t n_step = kernel->get_n_step();
-            const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-            const size_t n_start = ith * num_n_per_thread;
+        const size_t lhs_batch_size0 = ne12;
+        const size_t rhs_batch_size0 = ne02;

-            size_t n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
-            }
+        const size_t r = lhs_batch_size0 / rhs_batch_size0;
+
+        const size_t m = ne11 * r;
+        const size_t n = ne01;
+        const size_t k = ne00;
+        const size_t batch_size = rhs_batch_size0;
+
+        const size_t lhs_stride = src1->nb[1];
+        const size_t rhs_stride = src0->nb[1];
+        const size_t dst_stride = dst->nb[1];
+
+        const size_t mr = kernel->get_mr();
+        const size_t nr = kernel->get_nr();
+        const size_t kr = kernel->get_kr();
+        const size_t sr = kernel->get_sr();
+
+        const size_t m_step = kernel->get_m_step();
+        const size_t n_step = kernel->get_n_step();
+
+        const bool parallelize_on_m = m >= m_step * nth;
+
+        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
+        const size_t kxn_size = k * n * sizeof(float);
+        const size_t bias_size = n * sizeof(float);
+
+        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
+        GGML_ASSERT(wsize_required <= params->wsize);
+
+        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
+        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
+        uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
+        uint8_t * bias = rhs_kxn + kxn_size;
-            const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
-            uint8_t * lhs_packed = (uint8_t*)params->wdata;
-            const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+        for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
+            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
+            uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;

-            size_t mr = kernel->get_mr();
-            size_t kr = kernel->get_kr();
-            size_t sr = kernel->get_sr();
+            // LHS packing
+            {
+                const size_t m_roundup_mr = kai_roundup(m, mr);
+                const size_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);

-            // Calculate number of columns to be processed per thread
-            const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-            const size_t m_start = ith * num_m_per_thread;
-            size_t m_to_process = num_m_per_thread;
-            if ((m_start + m_to_process) > m) {
-                m_to_process = m - m_start;
+                if (ith < num_threads) {
+                    const size_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const size_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
+
+                    const size_t m_start = ith * num_m_per_thread0;
+                    const size_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+
+                    const void * src_ptr = static_cast<const char *>(lhs_batch) + lhs_offset;
+                    void * dst_ptr = static_cast<char *>(lhs_packed) + lhs_packed_offset;
+
+                    variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                }
+
+                // Investigate if this barrier can be removed.
+                if (parallelize_on_m == false || num_threads != nth) {
+                    ggml_barrier(params->threadpool);
+                }
             }

-            if(m_start < m) {
-                // Transform LHS
-                const size_t src_stride = src1->nb[1];
-                const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-                const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
-                void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+            // RHS packing
+            if (ith == 0) {
+                memset(bias, 0, n * sizeof(float));
+                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);

-                lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
+                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
             }

             ggml_barrier(params->threadpool);

-            // Perform the operation
-            const size_t dst_stride = dst->nb[1];
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
-            const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
-            const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
-            const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-            const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
-            float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
-
-            kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
-                               dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
-            return true;
+            // Perform the matmul
+            {
+                const size_t m_to_process = m;
+                const size_t m_start = 0;
+
+                const size_t num_threads = KAI_MIN(n / n_step, nth);
+
+                if (ith < num_threads) {
+                    const size_t num_n_per_thread0 = round_down(n / num_threads, n_step);
+                    const size_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+
+                    const size_t n_start = ith * num_n_per_thread0;
+                    const size_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
+                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
+                    const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
+                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+
+                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                }
+
+                ggml_barrier(params->threadpool);
+            }
         }

-        return false;
+
+        return true;
+    }
+
+    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);
+
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = &kernels->lhs_info;
+
+        GGML_ASSERT(kernel);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed = (uint8_t*)params->wdata;
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = num_n_per_thread;
+        if ((n_start + n_to_process) > n) {
+            n_to_process = n - n_start;
+        }
+
+        // Calculate number of columns to be processed per thread
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if(m_start < m) {
+            // Transform LHS
+            const size_t src_stride = src1->nb[1];
+            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
+            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        // Perform the operation
+        const size_t dst_stride = dst->nb[1];
+        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
+        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
+        const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+        float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                           sizeof(float), -FLT_MAX, FLT_MAX);
+
+        return true;
     }

 public:
@@ -169,13 +348,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t sr = ctx.kernels->gemm.get_sr();

 #ifndef NDEBUG
-        const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
+        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
         GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
 #endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
+        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);

         return 0;
@@ -189,7 +368,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 }
 }  // namespace ggml::cpu::kleidiai

-GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

     GGML_UNUSED(buffer);
@@ -238,12 +417,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if ( op->op == GGML_OP_MUL_MAT &&
-            op->src[0]->type == GGML_TYPE_Q4_0 &&
-            op->src[0]->buffer &&
-            (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
-            ) {
+        if (op->op == GGML_OP_MUL_MAT &&
+            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            op->src[0]->buffer &&
+            (ggml_n_dims(op->src[0]) == 2) &&
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
@@ -260,6 +438,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
+            else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
+                     op->src[0]->op == GGML_OP_VIEW &&
+                     (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
+                     op->src[1]->ne[1] > 1) {
+                if ((op->src[0]->nb[0] != 2) ||
+                    (op->src[1]->nb[0] != 4) ||
+                    (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                    return nullptr;
+                }
+
+                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
+            }
         }
         return nullptr;
     }
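
For readers unfamiliar with the std::variant/std::function dispatch this patch introduces in kernels.h and kleidiai.cpp, the following is a minimal, self-contained C++ sketch of the same pattern. The variant lets one kernel_info/lhs_packing_info field hold either the blocked Q4_0 signature or the new BF16 signature, and the argument list at the call site selects the matching alternative. The variant_call body mirrors the helper added in kleidiai.cpp; packed_size_fn, the lambda, and main are illustrative assumptions, not code from the patch.

    // Standalone sketch (not part of the patch) of the variant_call dispatch idea.
    #include <cstddef>
    #include <functional>
    #include <stdexcept>
    #include <type_traits>
    #include <utility>
    #include <variant>

    // Mirrors the helper added in kleidiai.cpp: visit the variant and invoke
    // whichever std::function alternative accepts the supplied arguments.
    template<typename Ret, typename Variant, typename... Args>
    static Ret variant_call(const Variant & var, Args &&... args) {
        return std::visit([&](auto && func) -> Ret {
            if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
                return func(std::forward<Args>(args)...);
            } else {
                throw std::runtime_error("Invalid function type in variant_call");
            }
        }, var);
    }

    // Hypothetical stand-in for lhs_packing_info::packed_size: one alternative for
    // the blocked Q4_0 path (extra `bl` argument), one for the BF16 path.
    using packed_size_fn = std::variant<
        std::function<size_t(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
        std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>>;

    int main() {
        // Illustrative BF16-style packed-size callback (placeholder arithmetic only).
        const packed_size_fn bf16_packed_size =
            std::function<size_t(size_t, size_t, size_t, size_t, size_t)>(
                [](size_t m, size_t k, size_t mr, size_t kr, size_t sr) -> size_t {
                    return (m + mr - 1) / mr * mr * k * sizeof(float) + kr + sr;
                });

        // The argument count selects the matching alternative; non-matching
        // alternatives compile into the throwing branch and are never invoked.
        const size_t size = variant_call<size_t>(bf16_packed_size, 64, 128, 2, 2, 1);
        return size == 0;
    }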