diff --git a/src/layer/vulkan/matmul_vulkan.cpp b/src/layer/vulkan/matmul_vulkan.cpp
new file mode 100644
index 000000000000..e0f43b2a4627
--- /dev/null
+++ b/src/layer/vulkan/matmul_vulkan.cpp
@@ -0,0 +1,411 @@
+// Copyright 2025 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "matmul_vulkan.h"
+
+#include "layer_shader_type.h"
+
+namespace ncnn {
+
+MatMul_vulkan::MatMul_vulkan()
+{
+    support_vulkan = true;
+    support_vulkan_packing = true;
+    support_vulkan_any_packing = true;
+
+    pipeline_matmul = 0;
+}
+
+int MatMul_vulkan::create_pipeline(const Option& opt)
+{
+    std::vector<vk_specialization_type> specializations(1);
+    specializations[0].i = transB;
+
+    Mat local_size_xyz;
+
+    pipeline_matmul = new Pipeline(vkdev);
+    pipeline_matmul->set_optimal_local_size_xyz(local_size_xyz);
+
+    if (opt.use_shader_local_memory)
+    {
+        pipeline_matmul->set_local_size_xyz(8, 8, 1);
+    }
+
+    pipeline_matmul->create(LayerShaderType::matmul, opt, specializations);
+
+    return 0;
+}
+
+int MatMul_vulkan::destroy_pipeline(const Option& /*opt*/)
+{
+    delete pipeline_matmul;
+    pipeline_matmul = 0;
+
+    return 0;
+}
+
+int MatMul_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    const VkMat& A0 = bottom_blobs[0];
+    const VkMat& B0 = bottom_blobs[1];
+
+    VkMat A;
+    VkMat B;
+    vkdev->convert_packing(A0, A, 1, cmd, opt);
+    vkdev->convert_packing(B0, B, 1, cmd, opt);
+
+    const int Adims = A.dims;
+    const int Bdims = B.dims;
+    const int max_ABdims = std::max(Adims, Bdims);
+    const size_t elemsize = A.elemsize;
+
+    int mode = 0;
+
+    int M = 1;
+    int N = 1;
+    int K = 1;
+
+    int a_w = 1, a_h = 1, a_d = 1, a_c = 1;
+    int b_w = 1, b_h = 1, b_d = 1, b_c = 1;
+
+    int out_dims = 0;
+    int out_w = 1, out_h = 1, out_d = 1, out_c = 1;
+
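+    // The dims dispatch below mirrors the CPU MatMul reshape/broadcast rules.
+    // mode 0: matrix x matrix, (M x K) * (K x N) -> (M x N); also used for batched matmul
+    // mode 1: vector x matrix, (1 x K) * (K x N) -> (1 x N)
+    // mode 2: matrix x vector, (M x K) * (K x 1) -> (M x 1)
+    // With transB == 1 the reduce dimension of B is B.w, so N is taken from B.h and the
+    // shader reads B[n * b_w + k] instead of B[k * b_w + n].
+    // Worked example: A dims=2 (w=4, h=3), B dims=2 (w=5, h=4), transB=0
+    // -> mode 0, M=3, K=4, N=5, output dims=2 with w=N=5, h=M=3.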
+    if (Adims == 1 && Bdims == 1)
+    {
+        mode = 0;
+        K = A.w;
+        M = 1;
+        N = 1;
+
+        a_w = K;
+        a_h = 1;
+        a_d = 1;
+        a_c = 1;
+
+        if (transB == 0)
+        {
+            b_w = 1;
+            b_h = K;
+        }
+        else
+        {
+            b_w = K;
+            b_h = 1;
+        }
+        b_d = 1;
+        b_c = 1;
+
+        out_dims = 1;
+        out_w = 1;
+        out_h = 1;
+    }
+    else if (Adims == 2 && Bdims == 2)
+    {
+        mode = 0;
+        M = A.h;
+        K = A.w;
+        N = transB == 0 ? B.w : B.h;
+
+        a_w = A.w;
+        a_h = A.h;
+        a_d = 1;
+        a_c = 1;
+
+        b_w = B.w;
+        b_h = B.h;
+        b_d = 1;
+        b_c = 1;
+
+        out_dims = 2;
+        out_w = N;
+        out_h = M;
+    }
+    else if (Adims == 1 && Bdims == 2)
+    {
+        mode = 1;
+        K = A.w;
+        M = 1;
+        N = transB == 0 ? B.w : B.h;
+
+        a_w = K;
+        a_h = 1;
+        a_d = 1;
+        a_c = 1;
+
+        b_w = B.w;
+        b_h = B.h;
+        b_d = 1;
+        b_c = 1;
+
+        out_dims = 1;
+        out_w = N;
+        out_h = 1;
+    }
+    else if (Adims == 2 && Bdims == 1)
+    {
+        mode = 2;
+        M = A.h;
+        K = A.w;
+        N = 1;
+
+        a_w = A.w;
+        a_h = A.h;
+        a_d = 1;
+        a_c = 1;
+
+        if (transB == 0)
+        {
+            b_w = 1;
+            b_h = K;
+        }
+        else
+        {
+            b_w = K;
+            b_h = 1;
+        }
+
+        b_d = 1;
+        b_c = 1;
+
+        out_dims = 1;
+        out_w = M;
+        out_h = 1;
+    }
+    else if (Adims == 1 && Bdims > 2)
+    {
+        // Vector @ batched-matrix -> reduce one dimension.
+        mode = 1;
+        K = A.w;
+        N = transB == 0 ? B.w : B.h;
+
+        a_w = K;
+        a_h = 1;
+        a_d = 1;
+        a_c = 1;
+
+        b_w = B.w;
+        b_h = B.h;
+
+        b_d = B.dims == 4 ? B.d : 1;
+        b_c = B.dims >= 3 ? B.c : 1;
+
+        if (Bdims == 3)
+        {
+            out_dims = 2;
+            out_w = N;
+            out_h = B.d * B.c;
+            out_d = 1;
+            out_c = 1;
+        }
+        else
+        {
+            out_dims = 3;
+            out_w = N;
+            out_h = B.d;
+            out_c = B.c;
+            out_d = 1;
+        }
+    }
+    else if (Adims > 2 && Bdims == 1)
+    {
+        // Batched-matrix @ vector -> reduce one dimension.
+        mode = 2;
+        M = A.h;
+        K = A.w;
+        N = 1;
+
+        a_w = A.w;
+        a_h = A.h;
+        a_d = A.dims == 4 ? A.d : 1;
+        a_c = A.dims >= 3 ? A.c : 1;
+
+        if (transB == 0)
+        {
+            b_w = 1;
+            b_h = K;
+        }
+        else
+        {
+            b_w = K;
+            b_h = 1;
+        }
+        b_d = 1;
+        b_c = 1;
+
+        if (Adims == 3)
+        {
+            out_dims = 2;
+            out_w = M;
+            out_h = A.d * A.c;
+            out_d = 1;
+            out_c = 1;
+        }
+        else
+        {
+            out_dims = 3;
+            out_w = M;
+            out_h = A.d;
+            out_c = A.c;
+            out_d = 1;
+        }
+    }
+    else
+    {
+        // Batched matmul follows CPU reshape/broadcast rules.
+        mode = 0;
+        M = A.h;
+        K = A.w;
+        N = transB == 0 ? B.w : B.h;
+
+        a_w = A.w;
+        a_h = A.h;
+        b_w = B.w;
+        b_h = B.h;
+
+        if (max_ABdims == 3)
+        {
+            a_d = 1;
+            b_d = 1;
+            a_c = (Adims == 3) ? A.c : 1;
+            b_c = (Bdims == 3) ? B.c : 1;
+
+            out_dims = 3;
+            out_w = N;
+            out_h = M;
+            out_c = std::max(a_c, b_c);
+            out_d = 1;
+        }
+        else
+        {
+            // dims3 -> reshape(w,h,d=orig_c,c=1), dims4 stays (w,h,d,c)
+            if (Adims == 4)
+            {
+                a_d = A.d;
+                a_c = A.c;
+            }
+            else if (Adims == 3)
+            {
+                a_d = A.c;
+                a_c = 1;
+            }
+            else
+            {
+                a_d = 1;
+                a_c = 1;
+            }
+
+            if (Bdims == 4)
+            {
+                b_d = B.d;
+                b_c = B.c;
+            }
+            else if (Bdims == 3)
+            {
+                b_d = B.c;
+                b_c = 1;
+            }
+            else
+            {
+                b_d = 1;
+                b_c = 1;
+            }
+
+            out_dims = 4;
+            out_w = N;
+            out_h = M;
+            out_d = std::max(a_d, b_d);
+            out_c = std::max(a_c, b_c);
+        }
+    }
+
+    VkMat& top_blob = top_blobs[0];
+
+    if (out_dims == 1)
+        top_blob.create(out_w, elemsize, opt.blob_vkallocator);
+    else if (out_dims == 2)
+        top_blob.create(out_w, out_h, elemsize, opt.blob_vkallocator);
+    else if (out_dims == 3)
+        top_blob.create(out_w, out_h, out_c, elemsize, opt.blob_vkallocator);
+    else
+        top_blob.create(out_w, out_h, out_d, out_c, elemsize, opt.blob_vkallocator);
+
+    if (top_blob.empty())
+        return -100;
+
+    std::vector<VkMat> bindings(3);
+    bindings[0] = top_blob;
+    bindings[1] = A;
+    bindings[2] = B;
+
+    const int out_cstep = (top_blob.dims >= 3) ? (int)top_blob.cstep : (top_blob.w * top_blob.h);
+    const int a_cstep_real = (A.dims >= 3) ? (int)A.cstep : (A.w * A.h);
+    const int b_cstep_real = (B.dims >= 3) ? (int)B.cstep : (B.w * B.h);
+
+    const int out_dstep = out_w * out_h;
+
+    int a_dstep_real = a_w * a_h;
+    int b_dstep_real = b_w * b_h;
+
+    if (max_ABdims == 4)
+    {
+        if (A.dims == 3) a_dstep_real = (int)A.cstep;
+        if (B.dims == 3) b_dstep_real = (int)B.cstep;
+    }
+
+    // must match the push_constant parameter block in matmul.comp (23 ints)
+    std::vector<vk_constant_type> constants(23);
+    constants[0].i = M;
+    constants[1].i = N;
+    constants[2].i = K;
+    constants[3].i = mode;
+
+    constants[4].i = out_dims;
+    constants[5].i = out_w;
+    constants[6].i = out_h;
+    constants[7].i = out_d;
+    constants[8].i = out_c;
+    constants[9].i = out_cstep;
+    constants[10].i = out_dstep;
+
+    constants[11].i = a_w;
+    constants[12].i = a_h;
+    constants[13].i = a_d;
+    constants[14].i = a_c;
+    constants[15].i = a_cstep_real;
+    constants[16].i = a_dstep_real;
+
+    constants[17].i = b_w;
+    constants[18].i = b_h;
+    constants[19].i = b_d;
+    constants[20].i = b_c;
+    constants[21].i = b_cstep_real;
+    constants[22].i = b_dstep_real;
+
+    const Pipeline* pipeline = pipeline_matmul;
+
+    VkMat dispatcher;
+
+    // mode 0 with a 2D+ output computes a 2x2 block per invocation, so halve the xy dispatch
+    if (mode == 0 && out_dims >= 2)
+    {
+        dispatcher.w = (out_w + 1) / 2;
+        dispatcher.h = (out_h + 1) / 2;
+    }
+    else
+    {
+        dispatcher.w = out_w;
+        dispatcher.h = out_h;
+    }
+
+    if (out_dims == 4)
+        dispatcher.c = out_c * out_d;
+    else if (out_dims == 3)
+        dispatcher.c = out_c;
+    else
+        dispatcher.c = 1;
+
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+
+    return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/vulkan/matmul_vulkan.h b/src/layer/vulkan/matmul_vulkan.h
new file mode 100644
index 000000000000..962af387a228
--- /dev/null
+++ b/src/layer/vulkan/matmul_vulkan.h
@@ -0,0 +1,28 @@
+// Copyright 2025 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_MATMUL_VULKAN_H
+#define LAYER_MATMUL_VULKAN_H
+
+#include "matmul.h"
+
+namespace ncnn {
+
+class MatMul_vulkan : public MatMul
+{
+public:
+    MatMul_vulkan();
+
+    virtual int create_pipeline(const Option& opt);
+    virtual int destroy_pipeline(const Option& opt);
+
+    using MatMul::forward;
+    virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
+
+public:
+    Pipeline* pipeline_matmul;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_MATMUL_VULKAN_H
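For reference, the broadcast and flat addressing that the shader applies to the pack-1 A/B buffers is equivalent to the host-side sketch below. It is illustrative only and not used by the layer; the parameter names mirror the push constants defined in matmul.comp, and a batch axis of extent 1 is broadcast by clamping its index to 0.

// Illustrative host-side reference for the shader's indexing (not part of the patch).
// Parameter names mirror the push constants; transB and the broadcast rule match matmul.comp.
static float matmul_ref_at(const float* A, const float* B,
                           int m, int n, int bc, int bd,
                           int K, int transB,
                           int a_w, int a_c, int a_d, int a_cstep, int a_dstep,
                           int b_w, int b_c, int b_d, int b_cstep, int b_dstep)
{
    // Broadcast: a batch axis of extent 1 always maps to slice 0.
    const int ac = (a_c == 1) ? 0 : bc;
    const int ad = (a_d == 1) ? 0 : bd;
    const int bcb = (b_c == 1) ? 0 : bc;
    const int bdb = (b_d == 1) ? 0 : bd;

    const int a_base = ac * a_cstep + ad * a_dstep;
    const int b_base = bcb * b_cstep + bdb * b_dstep;

    float sum = 0.f;
    for (int k = 0; k < K; k++)
    {
        const float av = A[a_base + m * a_w + k];
        const float bv = (transB == 0) ? B[b_base + k * b_w + n]
                                       : B[b_base + n * b_w + k];
        sum += av * bv;
    }
    return sum;
}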
diff --git a/src/layer/vulkan/shader/matmul.comp b/src/layer/vulkan/shader/matmul.comp
new file mode 100644
index 000000000000..6858fd45ea97
--- /dev/null
+++ b/src/layer/vulkan/shader/matmul.comp
@@ -0,0 +1,364 @@
+// Copyright 2025 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+#define LOCAL_MEMORY_UNROLL_INCH 8
+
+layout(constant_id = 0) const int transB = 0;
+
+layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
+layout(binding = 1) readonly buffer a_blob { sfp a_blob_data[]; };
+layout(binding = 2) readonly buffer b_blob { sfp b_blob_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int K;
+    int mode;
+
+    int out_dims;
+    int out_w;
+    int out_h;
+    int out_d;
+    int out_c;
+    int out_cstep;
+    int out_dstep;
+
+    int a_w;
+    int a_h;
+    int a_d;
+    int a_c;
+    int a_cstep;
+    int a_dstep;
+
+    int b_w;
+    int b_h;
+    int b_d;
+    int b_c;
+    int b_cstep;
+    int b_dstep;
+} p;
+
+#if NCNN_shader_local_memory
+shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
+shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
+#endif
+
+void main()
+{
+    const int gx = int(gl_GlobalInvocationID.x);
+    const int gy = int(gl_GlobalInvocationID.y);
+    const int gz = int(gl_GlobalInvocationID.z);
+
+    if (p.mode == 0 && p.out_dims >= 2)
+    {
+        if (p.out_dims == 3)
+        {
+            if (gz >= p.out_c) return;
+        }
+        else if (p.out_dims == 4)
+        {
+            if (gz >= p.out_c * p.out_d) return;
+        }
+
+        const int lx = int(gl_LocalInvocationID.x);
+        const int ly = int(gl_LocalInvocationID.y);
+
+        const int n0 = gx * 2;
+        const int m0 = gy * 2;
+
+        int bc = 0;
+        int bd = 0;
+
+        if (p.out_dims == 4)
+        {
+            bc = gz / p.out_d;
+            bd = gz - bc * p.out_d;
+        }
+        else if (p.out_dims == 3)
+        {
+            bc = gz;
+            bd = 0;
+        }
+
+        const int ac = (p.a_c == 1) ? 0 : bc;
+        const int ad = (p.a_d == 1) ? 0 : bd;
+        const int bc0 = (p.b_c == 1) ? 0 : bc;
+        const int bd0 = (p.b_d == 1) ? 0 : bd;
+
+        const int a_base = ac * p.a_cstep + ad * p.a_dstep;
+        const int b_base = bc0 * p.b_cstep + bd0 * p.b_dstep;
+
+        bool in_m0 = (m0 < p.out_h);
+        bool in_m1 = (m0 + 1 < p.out_h);
+        bool in_n0 = (n0 < p.out_w);
+        bool in_n1 = (n0 + 1 < p.out_w);
+
+        afp sum00 = afp(0.0);
+        afp sum01 = afp(0.0);
+        afp sum10 = afp(0.0);
+        afp sum11 = afp(0.0);
+
+        for (int k0 = 0; k0 < p.K; k0 += LOCAL_MEMORY_UNROLL_INCH)
+        {
+#if NCNN_shader_local_memory
+            lfp a0 = lfp(0.0);
+            lfp a1 = lfp(0.0);
+            const int ak = k0 + lx;
+            if (ak < p.K)
+            {
+                if (in_m0)
+                    a0 = lfp(buffer_ld1(a_blob_data, a_base + m0 * p.a_w + ak));
+                if (in_m1)
+                    a1 = lfp(buffer_ld1(a_blob_data, a_base + (m0 + 1) * p.a_w + ak));
+            }
+            tmp_a[ly][lx][0] = a0;
+            tmp_a[ly][lx][1] = a1;
+
+            lfp b0 = lfp(0.0);
+            lfp b1 = lfp(0.0);
+            const int bk = k0 + ly;
+            if (bk < p.K)
+            {
+                if (in_n0)
+                {
+                    int b_idx0;
+                    if (transB == 0)
+                        b_idx0 = b_base + bk * p.b_w + n0;
+                    else
+                        b_idx0 = b_base + n0 * p.b_w + bk;
+                    b0 = lfp(buffer_ld1(b_blob_data, b_idx0));
+                }
+
+                if (in_n1)
+                {
+                    int b_idx1;
+                    if (transB == 0)
+                        b_idx1 = b_base + bk * p.b_w + (n0 + 1);
+                    else
+                        b_idx1 = b_base + (n0 + 1) * p.b_w + bk;
+                    b1 = lfp(buffer_ld1(b_blob_data, b_idx1));
+                }
+            }
+            tmp_b[lx][ly][0] = b0;
+            tmp_b[lx][ly][1] = b1;
+
+            barrier();
+
+            // Compute 2x2 block.
+            for (int kk = 0; kk < LOCAL_MEMORY_UNROLL_INCH; kk++)
+            {
+                const afp A0 = afp(tmp_a[ly][kk][0]);
+                const afp A1 = afp(tmp_a[ly][kk][1]);
+                const afp B0 = afp(tmp_b[lx][kk][0]);
+                const afp B1 = afp(tmp_b[lx][kk][1]);
+
+                sum00 = sum00 + A0 * B0;
+                sum01 = sum01 + A0 * B1;
+                sum10 = sum10 + A1 * B0;
+                sum11 = sum11 + A1 * B1;
+            }
+
+            barrier();
+#else
+            for (int kk = 0; kk < LOCAL_MEMORY_UNROLL_INCH; kk++)
+            {
+                const int k = k0 + kk;
+                if (k >= p.K)
+                    break;
+
+                afp A0 = afp(0.0);
+                afp A1 = afp(0.0);
+                if (in_m0) A0 = buffer_ld1(a_blob_data, a_base + m0 * p.a_w + k);
+                if (in_m1) A1 = buffer_ld1(a_blob_data, a_base + (m0 + 1) * p.a_w + k);
+
+                afp B0 = afp(0.0);
+                afp B1 = afp(0.0);
+
+                if (in_n0)
+                {
+                    int b_idx0;
+                    if (transB == 0)
+                        b_idx0 = b_base + k * p.b_w + n0;
+                    else
+                        b_idx0 = b_base + n0 * p.b_w + k;
+                    B0 = buffer_ld1(b_blob_data, b_idx0);
+                }
+                if (in_n1)
+                {
+                    int b_idx1;
+                    if (transB == 0)
+                        b_idx1 = b_base + k * p.b_w + (n0 + 1);
+                    else
+                        b_idx1 = b_base + (n0 + 1) * p.b_w + k;
+                    B1 = buffer_ld1(b_blob_data, b_idx1);
+                }
+
+                sum00 = sum00 + A0 * B0;
+                sum01 = sum01 + A0 * B1;
+                sum10 = sum10 + A1 * B0;
+                sum11 = sum11 + A1 * B1;
+            }
+#endif
+        }
+
+        if (in_m0 && in_n0)
+        {
+            int out_index;
+            if (p.out_dims == 2)
+                out_index = m0 * p.out_w + n0;
+            else if (p.out_dims == 3)
+                out_index = gz * p.out_cstep + m0 * p.out_w + n0;
+            else
+                out_index = bc * p.out_cstep + bd * p.out_dstep + m0 * p.out_w + n0;
+
+            buffer_st1(top_blob_data, out_index, sum00);
+        }
+        if (in_m0 && in_n1)
+        {
+            int out_index;
+            if (p.out_dims == 2)
+                out_index = m0 * p.out_w + (n0 + 1);
+            else if (p.out_dims == 3)
+                out_index = gz * p.out_cstep + m0 * p.out_w + (n0 + 1);
+            else
+                out_index = bc * p.out_cstep + bd * p.out_dstep + m0 * p.out_w + (n0 + 1);
+
+            buffer_st1(top_blob_data, out_index, sum01);
+        }
+        if (in_m1 && in_n0)
+        {
+            int out_index;
+            if (p.out_dims == 2)
+                out_index = (m0 + 1) * p.out_w + n0;
+            else if (p.out_dims == 3)
+                out_index = gz * p.out_cstep + (m0 + 1) * p.out_w + n0;
+            else
+                out_index = bc * p.out_cstep + bd * p.out_dstep + (m0 + 1) * p.out_w + n0;
+
+            buffer_st1(top_blob_data, out_index, sum10);
+        }
+        if (in_m1 && in_n1)
+        {
+            int out_index;
+            if (p.out_dims == 2)
+                out_index = (m0 + 1) * p.out_w + (n0 + 1);
+            else if (p.out_dims == 3)
+                out_index = gz * p.out_cstep + (m0 + 1) * p.out_w + (n0 + 1);
+            else
+                out_index = bc * p.out_cstep + bd * p.out_dstep + (m0 + 1) * p.out_w + (n0 + 1);
+
+            buffer_st1(top_blob_data, out_index, sum11);
+        }
+
+        return;
+    }
+
+    const int x = gx;
+    const int y = gy;
+    const int z = gz;
+
+    if (x >= p.out_w || y >= p.out_h)
+        return;
+
+    int bc = 0;
+    int bd = 0;
+
+    int m = 0;
+    int n = 0;
+
+    if (p.mode == 0)
+    {
+        n = x;
+        m = y;
+
+        if (p.out_dims == 4)
+        {
+            const int batch_total = p.out_c * p.out_d;
+            if (z >= batch_total) return;
+
+            bc = z / p.out_d;
+            bd = z - bc * p.out_d;
+        }
+        else if (p.out_dims == 3)
+        {
+            if (z >= p.out_c) return;
+
+            bc = z;
+            bd = 0;
+        }
+    }
+    else if (p.mode == 1)
+    {
+        n = x;
+        m = 0;
+
+        if (p.out_dims == 3)
+        {
+            if (z >= p.out_c) return;
+            bc = z;
+            bd = y;
+        }
+        else if (p.out_dims == 2)
+        {
+            bc = y;
+            bd = 0;
+        }
+    }
+    else
+    {
+        n = 0;
+        m = x;
+
+        if (p.out_dims == 3)
+        {
+            if (z >= p.out_c) return;
+            bc = z;
+            bd = y;
+        }
+        else if (p.out_dims == 2)
+        {
+            bc = y;
+            bd = 0;
+        }
+    }
+
+    const int ac = (p.a_c == 1) ? 0 : bc;
+    const int ad = (p.a_d == 1) ? 0 : bd;
+    const int bc0 = (p.b_c == 1) ? 0 : bc;
+    const int bd0 = (p.b_d == 1) ? 0 : bd;
+
+    const int a_base = ac * p.a_cstep + ad * p.a_dstep;
+    const int b_base = bc0 * p.b_cstep + bd0 * p.b_dstep;
+
+    afp sum = afp(0.0);
+
+    for (int k = 0; k < p.K; k++)
+    {
+        const int a_idx = a_base + m * p.a_w + k;
+
+        int b_idx;
+        if (transB == 0)
+            b_idx = b_base + k * p.b_w + n;
+        else
+            b_idx = b_base + n * p.b_w + k;
+
+        const afp av = buffer_ld1(a_blob_data, a_idx);
+        const afp bv = buffer_ld1(b_blob_data, b_idx);
+        sum = sum + av * bv;
+    }
+
+    int out_index = 0;
+    if (p.out_dims == 1)
+        out_index = x;
+    else if (p.out_dims == 2)
+        out_index = y * p.out_w + x;
+    else if (p.out_dims == 3)
+        out_index = z * p.out_cstep + y * p.out_w + x;
+    else
+        out_index = bc * p.out_cstep + bd * p.out_dstep + y * p.out_w + x;
+
+    buffer_st1(top_blob_data, out_index, sum);
+}
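A possible way to exercise the new Vulkan path is through ncnn's existing per-layer test harness. The sketch below is an outline only, not part of this patch: it assumes the string-based test_layer helper and RandomMat from tests/testutil.h, and that MatMul reads transB from param id 0 as the CPU implementation does.

// Hypothetical test sketch (assumes tests/testutil.h and transB at param id 0).
#include "testutil.h"

static int test_matmul(const ncnn::Mat& a, const ncnn::Mat& b, int transB)
{
    ncnn::ParamDict pd;
    pd.set(0, transB); // transB

    std::vector<ncnn::Mat> weights(0);

    std::vector<ncnn::Mat> inputs(2);
    inputs[0] = a;
    inputs[1] = b;

    // test_layer compares the Vulkan result against the CPU reference.
    int ret = test_layer("MatMul", pd, weights, inputs);
    if (ret != 0)
    {
        fprintf(stderr, "test_matmul failed a.dims=%d b.dims=%d transB=%d\n", a.dims, b.dims, transB);
    }
    return ret;
}

// e.g. test_matmul(RandomMat(14, 12), RandomMat(15, 14), 0)
// checks A (M=12, K=14) against B (K=14, N=15), producing a 12x15 output.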