diff --git a/src/layer/vulkan/lstm_vulkan.cpp b/src/layer/vulkan/lstm_vulkan.cpp
new file mode 100644
index 000000000000..b0763835cd31
--- /dev/null
+++ b/src/layer/vulkan/lstm_vulkan.cpp
@@ -0,0 +1,513 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "lstm_vulkan.h"
+
+#include <utility>
+
+#include "layer_shader_type.h"
+
+namespace ncnn {
+
+LSTM_vulkan::LSTM_vulkan()
+{
+    support_vulkan = true;
+    support_vulkan_packing = false;
+    support_vulkan_any_packing = false;
+
+    pipeline_lstm_copy = 0;
+    pipeline_lstm_step = 0;
+    pipeline_lstm_step_h = 0;
+    pipeline_lstm_proj = 0;
+}
+
+int LSTM_vulkan::load_param(const ParamDict& pd)
+{
+    int ret = LSTM::load_param(pd);
+
+    if (int8_scale_term)
+    {
+        support_vulkan = false;
+    }
+
+    return ret;
+}
+
+int LSTM_vulkan::create_pipeline(const Option& opt)
+{
+    if (!support_vulkan)
+        return 0;
+
+    {
+        pipeline_lstm_copy = new Pipeline(vkdev);
+        pipeline_lstm_copy->set_local_size_xyz(64, 1, 1);
+
+        std::vector<vk_specialization_type> specializations;
+        pipeline_lstm_copy->create(LayerShaderType::lstm_copy, opt, specializations);
+    }
+
+    {
+        pipeline_lstm_step = new Pipeline(vkdev);
+        pipeline_lstm_step->set_local_size_xyz(64, 1, 1);
+
+        std::vector<vk_specialization_type> specializations;
+        pipeline_lstm_step->create(LayerShaderType::lstm_step, opt, specializations);
+    }
+
+    if (num_output != hidden_size)
+    {
+        pipeline_lstm_step_h = new Pipeline(vkdev);
+        pipeline_lstm_step_h->set_local_size_xyz(64, 1, 1);
+
+        std::vector<vk_specialization_type> specializations_h;
+        pipeline_lstm_step_h->create(LayerShaderType::lstm_step_h, opt, specializations_h);
+
+        pipeline_lstm_proj = new Pipeline(vkdev);
+        pipeline_lstm_proj->set_local_size_xyz(64, 1, 1);
+
+        std::vector<vk_specialization_type> specializations_p;
+        pipeline_lstm_proj->create(LayerShaderType::lstm_proj, opt, specializations_p);
+    }
+
+    return 0;
+}
+
+int LSTM_vulkan::destroy_pipeline(const Option& /*opt*/)
+{
+    delete pipeline_lstm_copy;
+    pipeline_lstm_copy = 0;
+
+    delete pipeline_lstm_step;
+    pipeline_lstm_step = 0;
+
+    delete pipeline_lstm_step_h;
+    pipeline_lstm_step_h = 0;
+
+    delete pipeline_lstm_proj;
+    pipeline_lstm_proj = 0;
+
+    return 0;
+}
+
+int LSTM_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
+{
+    if (!support_vulkan)
+        return 0;
+
+    cmd.record_upload(weight_xc_data, weight_xc_data_gpu, opt);
+    cmd.record_upload(bias_c_data, bias_c_data_gpu, opt);
+    cmd.record_upload(weight_hc_data, weight_hc_data_gpu, opt);
+
+    if (num_output != hidden_size)
+    {
+        cmd.record_upload(weight_hr_data, weight_hr_data_gpu, opt);
+    }
+
+    if (opt.lightmode)
+    {
+        weight_xc_data.release();
+        bias_c_data.release();
+        weight_hc_data.release();
+        weight_hr_data.release();
+    }
+
+    return 0;
+}
+
+static inline void record_lstm_copy(const Pipeline* pipeline,
+                                    VkCompute& cmd,
+                                    const VkMat& src,
+                                    VkMat& dst,
+                                    int len,
+                                    int src_offset,
+                                    int dst_offset,
+                                    int mode)
+{
+    std::vector<VkMat> bindings(2);
+    bindings[0] = src;
+    bindings[1] = dst;
+
+    std::vector<vk_constant_type> constants(4);
+    constants[0].i = len;
+    constants[1].i = src_offset;
+    constants[2].i = dst_offset;
+    constants[3].i = mode;
+
+    VkMat dispatcher;
+    dispatcher.w = len;
+    dispatcher.h = 1;
+    dispatcher.c = 1;
+
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+}
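+
+// The helpers below each record one compute dispatch. Weight and bias blobs
+// store both directions back to back, so the shaders locate a direction's
+// slice with dir * dir_stride.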
+static inline void record_lstm_step(const Pipeline* pipeline,
+                                    VkCompute& cmd,
+                                    const VkMat& bottom_blob,
+                                    const VkMat& weight_xc,
+                                    const VkMat& bias_c,
+                                    const VkMat& weight_hc,
+                                    const VkMat& hidden_prev,
+                                    const VkMat& cell_prev,
+                                    VkMat& hidden_next,
+                                    VkMat& cell_next,
+                                    VkMat& top_blob,
+                                    int size,
+                                    int num_output,
+                                    int hidden_size,
+                                    int ti,
+                                    int outw,
+                                    int out_offset,
+                                    int dir,
+                                    int wxc_dir_stride,
+                                    int whc_dir_stride,
+                                    int bias_dir_stride,
+                                    int bottom_step)
+{
+    std::vector<VkMat> bindings(9);
+    bindings[0] = bottom_blob;
+    bindings[1] = weight_xc;
+    bindings[2] = bias_c;
+    bindings[3] = weight_hc;
+    bindings[4] = hidden_prev;
+    bindings[5] = cell_prev;
+    bindings[6] = hidden_next;
+    bindings[7] = cell_next;
+    bindings[8] = top_blob;
+
+    std::vector<vk_constant_type> constants(11);
+    constants[0].i = size;
+    constants[1].i = num_output;
+    constants[2].i = hidden_size;
+    constants[3].i = ti;
+    constants[4].i = outw;
+    constants[5].i = out_offset;
+    constants[6].i = dir;
+    constants[7].i = wxc_dir_stride;
+    constants[8].i = whc_dir_stride;
+    constants[9].i = bias_dir_stride;
+    constants[10].i = bottom_step;
+
+    VkMat dispatcher;
+    dispatcher.w = hidden_size;
+    dispatcher.h = 1;
+    dispatcher.c = 1;
+
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+}
+
+static inline void record_lstm_step_h(const Pipeline* pipeline,
+                                      VkCompute& cmd,
+                                      const VkMat& bottom_blob,
+                                      const VkMat& weight_xc,
+                                      const VkMat& bias_c,
+                                      const VkMat& weight_hc,
+                                      const VkMat& hidden_prev,
+                                      const VkMat& cell_prev,
+                                      VkMat& hidden_h_next,
+                                      VkMat& cell_next,
+                                      int size,
+                                      int num_output,
+                                      int hidden_size,
+                                      int ti,
+                                      int dir,
+                                      int wxc_dir_stride,
+                                      int whc_dir_stride,
+                                      int bias_dir_stride,
+                                      int bottom_step)
+{
+    std::vector<VkMat> bindings(8);
+    bindings[0] = bottom_blob;
+    bindings[1] = weight_xc;
+    bindings[2] = bias_c;
+    bindings[3] = weight_hc;
+    bindings[4] = hidden_prev;
+    bindings[5] = cell_prev;
+    bindings[6] = hidden_h_next;
+    bindings[7] = cell_next;
+
+    std::vector<vk_constant_type> constants(9);
+    constants[0].i = size;
+    constants[1].i = num_output;
+    constants[2].i = hidden_size;
+    constants[3].i = ti;
+    constants[4].i = dir;
+    constants[5].i = wxc_dir_stride;
+    constants[6].i = whc_dir_stride;
+    constants[7].i = bias_dir_stride;
+    constants[8].i = bottom_step;
+
+    VkMat dispatcher;
+    dispatcher.w = hidden_size;
+    dispatcher.h = 1;
+    dispatcher.c = 1;
+
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+}
+
+static inline void record_lstm_proj(const Pipeline* pipeline,
+                                    VkCompute& cmd,
+                                    const VkMat& hidden_h,
+                                    const VkMat& weight_hr,
+                                    VkMat& hidden_next,
+                                    VkMat& top_blob,
+                                    int hidden_size,
+                                    int num_output,
+                                    int ti,
+                                    int outw,
+                                    int out_offset,
+                                    int dir,
+                                    int hr_dir_stride)
+{
+    std::vector<VkMat> bindings(4);
+    bindings[0] = hidden_h;
+    bindings[1] = weight_hr;
+    bindings[2] = hidden_next;
+    bindings[3] = top_blob;
+
+    std::vector<vk_constant_type> constants(7);
+    constants[0].i = hidden_size;
+    constants[1].i = num_output;
+    constants[2].i = ti;
+    constants[3].i = outw;
+    constants[4].i = out_offset;
+    constants[5].i = dir;
+    constants[6].i = hr_dir_stride;
+
+    VkMat dispatcher;
+    dispatcher.w = num_output;
+    dispatcher.h = 1;
+    dispatcher.c = 1;
+
+    cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
+}
+
+int LSTM_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    if (!support_vulkan)
+        return -1;
+
+    VkMat bottom_blob = bottom_blobs[0];
+
+    if (bottom_blob.dims != 2)
+        return -1;
+
+    const int size = bottom_blob.w;
+    const int T = bottom_blob.h;
+
+    const int num_directions = direction == 2 ? 2 : 1;
+
+    VkMat& top_blob = top_blobs[0];
+    top_blob.create(num_output * num_directions, T, bottom_blob.elemsize, 1, opt.blob_vkallocator);
+    if (top_blob.empty())
+        return -100;
+
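+    // hidden/cell states are real outputs when the graph requests them
+    // (three top blobs), so allocate them on the blob allocator; otherwise
+    // they are per-forward scratch memory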
+    VkAllocator* state_vkallocator = top_blobs.size() == 3 ? opt.blob_vkallocator : opt.workspace_vkallocator;
+
+    VkMat hidden0;
+    VkMat hidden0_next;
+    VkMat cell0;
+    VkMat cell0_next;
+
+    hidden0.create(num_output, 1, bottom_blob.elemsize, 1, state_vkallocator);
+    hidden0_next.create(num_output, 1, bottom_blob.elemsize, 1, state_vkallocator);
+    cell0.create(hidden_size, 1, bottom_blob.elemsize, 1, state_vkallocator);
+    cell0_next.create(hidden_size, 1, bottom_blob.elemsize, 1, state_vkallocator);
+    if (hidden0.empty() || hidden0_next.empty() || cell0.empty() || cell0_next.empty())
+        return -100;
+
+    VkMat hidden1;
+    VkMat hidden1_next;
+    VkMat cell1;
+    VkMat cell1_next;
+
+    if (num_directions == 2)
+    {
+        hidden1.create(num_output, 1, bottom_blob.elemsize, 1, state_vkallocator);
+        hidden1_next.create(num_output, 1, bottom_blob.elemsize, 1, state_vkallocator);
+        cell1.create(hidden_size, 1, bottom_blob.elemsize, 1, state_vkallocator);
+        cell1_next.create(hidden_size, 1, bottom_blob.elemsize, 1, state_vkallocator);
+        if (hidden1.empty() || hidden1_next.empty() || cell1.empty() || cell1_next.empty())
+            return -100;
+    }
+
+    if (bottom_blobs.size() == 3)
+    {
+        VkMat hidden_in;
+        VkMat cell_in;
+        vkdev->convert_packing(bottom_blobs[1], hidden_in, 1, cmd, opt);
+        vkdev->convert_packing(bottom_blobs[2], cell_in, 1, cmd, opt);
+
+        record_lstm_copy(pipeline_lstm_copy, cmd, hidden_in, hidden0, num_output, 0, 0, 1);
+        record_lstm_copy(pipeline_lstm_copy, cmd, cell_in, cell0, hidden_size, 0, 0, 1);
+
+        if (num_directions == 2)
+        {
+            record_lstm_copy(pipeline_lstm_copy, cmd, hidden_in, hidden1, num_output, num_output, 0, 1);
+            record_lstm_copy(pipeline_lstm_copy, cmd, cell_in, cell1, hidden_size, hidden_size, 0, 1);
+        }
+    }
+    else
+    {
+        record_lstm_copy(pipeline_lstm_copy, cmd, bottom_blob, hidden0, num_output, 0, 0, 0);
+        record_lstm_copy(pipeline_lstm_copy, cmd, bottom_blob, cell0, hidden_size, 0, 0, 0);
+
+        if (num_directions == 2)
+        {
+            record_lstm_copy(pipeline_lstm_copy, cmd, bottom_blob, hidden1, num_output, 0, 0, 0);
+            record_lstm_copy(pipeline_lstm_copy, cmd, bottom_blob, cell1, hidden_size, 0, 0, 0);
+        }
+    }
+
+    const int wxc_dir_stride = size * (hidden_size * 4);
+    const int whc_dir_stride = num_output * (hidden_size * 4);
+    const int bias_dir_stride = hidden_size * 4;
+    const int hr_dir_stride = hidden_size * num_output;
+    const int bottom_step = size;
+
+    const bool has_projection = (num_output != hidden_size);
+
+    VkMat hiddenh0;
+    VkMat hiddenh1;
+    if (has_projection)
+    {
+        hiddenh0.create(hidden_size, 1, bottom_blob.elemsize, 1, opt.workspace_vkallocator);
+        if (hiddenh0.empty())
+            return -100;
+
+        if (num_directions == 2)
+        {
+            hiddenh1.create(hidden_size, 1, bottom_blob.elemsize, 1, opt.workspace_vkallocator);
+            if (hiddenh1.empty())
+                return -100;
+        }
+    }
+
+    auto run_sequence = [&](int dir_index, int out_offset, int reverse,
+                            VkMat& hprev, VkMat& hnext,
+                            VkMat& cprev, VkMat& cnext,
+                            VkMat& htmp) {
+        for (int t = 0; t < T; t++)
+        {
+            const int ti = reverse ? (T - 1 - t) : t;
+
+            if (!has_projection)
+            {
+                record_lstm_step(pipeline_lstm_step,
+                                 cmd,
+                                 bottom_blob,
+                                 weight_xc_data_gpu,
+                                 bias_c_data_gpu,
+                                 weight_hc_data_gpu,
+                                 hprev,
+                                 cprev,
+                                 hnext,
+                                 cnext,
+                                 top_blob,
+                                 size,
+                                 num_output,
+                                 hidden_size,
+                                 ti,
+                                 top_blob.w,
+                                 out_offset,
+                                 dir_index,
+                                 wxc_dir_stride,
+                                 whc_dir_stride,
+                                 bias_dir_stride,
+                                 bottom_step);
+
+                std::swap(hprev, hnext);
+                std::swap(cprev, cnext);
+            }
+            else
+            {
+                record_lstm_step_h(pipeline_lstm_step_h,
+                                   cmd,
+                                   bottom_blob,
+                                   weight_xc_data_gpu,
+                                   bias_c_data_gpu,
+                                   weight_hc_data_gpu,
+                                   hprev,
+                                   cprev,
+                                   htmp,
+                                   cnext,
+                                   size,
+                                   num_output,
+                                   hidden_size,
+                                   ti,
+                                   dir_index,
+                                   wxc_dir_stride,
+                                   whc_dir_stride,
+                                   bias_dir_stride,
+                                   bottom_step);
+
+                record_lstm_proj(pipeline_lstm_proj,
+                                 cmd,
+                                 htmp,
+                                 weight_hr_data_gpu,
+                                 hnext,
+                                 top_blob,
+                                 hidden_size,
+                                 num_output,
+                                 ti,
+                                 top_blob.w,
+                                 out_offset,
+                                 dir_index,
+                                 hr_dir_stride);
+
+                std::swap(hprev, hnext);
+                std::swap(cprev, cnext);
+            }
+        }
+    };
+
+    if (direction == 0 || direction == 1)
+    {
+        run_sequence(0, 0, direction, hidden0, hidden0_next, cell0, cell0_next, hiddenh0);
+    }
+    else
+    {
+        run_sequence(0, 0, 0, hidden0, hidden0_next, cell0, cell0_next, hiddenh0);
+        run_sequence(1, num_output, 1, hidden1, hidden1_next, cell1, cell1_next, hiddenh1);
+    }
+
+    if (top_blobs.size() == 3)
+    {
+        if (num_directions == 1)
+        {
+            top_blobs[1] = hidden0;
+            top_blobs[2] = cell0;
+        }
+        else
+        {
+            VkMat& hidden_out = top_blobs[1];
+            VkMat& cell_out = top_blobs[2];
+
+            hidden_out.create(num_output, 2, bottom_blob.elemsize, 1, opt.blob_vkallocator);
+            cell_out.create(hidden_size, 2, bottom_blob.elemsize, 1, opt.blob_vkallocator);
+            if (hidden_out.empty() || cell_out.empty())
+                return -100;
+
+            record_lstm_copy(pipeline_lstm_copy, cmd, hidden0, hidden_out, num_output, 0, 0, 1);
+            record_lstm_copy(pipeline_lstm_copy, cmd, hidden1, hidden_out, num_output, 0, num_output, 1);
+
+            record_lstm_copy(pipeline_lstm_copy, cmd, cell0, cell_out, hidden_size, 0, 0, 1);
+            record_lstm_copy(pipeline_lstm_copy, cmd, cell1, cell_out, hidden_size, 0, hidden_size, 1);
+        }
+    }
+
+    return 0;
+}
+
+int LSTM_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const
+{
+    std::vector<VkMat> bottom_blobs(1);
+    std::vector<VkMat> top_blobs(1);
+    bottom_blobs[0] = bottom_blob;
+
+    int ret = forward(bottom_blobs, top_blobs, cmd, opt);
+    top_blob = top_blobs[0];
+    return ret;
+}
+
+} // namespace ncnn
diff --git a/src/layer/vulkan/lstm_vulkan.h b/src/layer/vulkan/lstm_vulkan.h
new file mode 100644
index 000000000000..2948440b3a46
--- /dev/null
+++ b/src/layer/vulkan/lstm_vulkan.h
@@ -0,0 +1,41 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_LSTM_VULKAN_H
+#define LAYER_LSTM_VULKAN_H
+
+#include "lstm.h"
+
+namespace ncnn {
+
+class LSTM_vulkan : public LSTM
+{
+public:
+    LSTM_vulkan();
+
+    virtual int load_param(const ParamDict& pd);
+
+    virtual int create_pipeline(const Option& opt);
+    virtual int destroy_pipeline(const Option& opt);
+
+    virtual int upload_model(VkTransfer& cmd, const Option& opt);
+
+    using LSTM::forward;
+    virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
+    virtual int forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const;
+
+public:
+    VkMat weight_xc_data_gpu;
+    VkMat bias_c_data_gpu;
+    VkMat weight_hc_data_gpu;
+    VkMat weight_hr_data_gpu;
+
+    Pipeline* pipeline_lstm_copy;
+    Pipeline* pipeline_lstm_step;
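+    // created only when num_output != hidden_size (hidden state projection)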
+    Pipeline* pipeline_lstm_step_h;
+    Pipeline* pipeline_lstm_proj;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_LSTM_VULKAN_H
diff --git a/src/layer/vulkan/shader/lstm_copy.comp b/src/layer/vulkan/shader/lstm_copy.comp
new file mode 100644
index 000000000000..50c27d11ac27
--- /dev/null
+++ b/src/layer/vulkan/shader/lstm_copy.comp
@@ -0,0 +1,36 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(binding = 0) readonly buffer src_blob { sfp src_data[]; };
+layout(binding = 1) writeonly buffer dst_blob { sfp dst_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int len;
+    int src_offset;
+    int dst_offset;
+    int mode;
+} p;
+
+void main()
+{
+    int gx = int(gl_GlobalInvocationID.x);
+
+    if (gx >= p.len)
+        return;
+
+    int di = p.dst_offset + gx;
+
+    if (p.mode == 0)
+    {
+        buffer_st1(dst_data, di, afp(0.f));
+    }
+    else
+    {
+        int si = p.src_offset + gx;
+        afp v = buffer_ld1(src_data, si);
+        buffer_st1(dst_data, di, v);
+    }
+}
diff --git a/src/layer/vulkan/shader/lstm_proj.comp b/src/layer/vulkan/shader/lstm_proj.comp
new file mode 100644
index 000000000000..d4c2bc3104bf
--- /dev/null
+++ b/src/layer/vulkan/shader/lstm_proj.comp
@@ -0,0 +1,44 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(binding = 0) readonly buffer hidden_h_blob { sfp hidden_h_data[]; };
+layout(binding = 1) readonly buffer weight_hr_blob { sfp weight_hr_data[]; };
+layout(binding = 2) writeonly buffer hidden_next_blob { sfp hidden_next_data[]; };
+layout(binding = 3) writeonly buffer top_blob { sfp top_blob_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int hidden_size;
+    int num_output;
+    int ti;
+    int outw;
+    int out_offset;
+    int dir;
+    int hr_dir_stride;
+} p;
+
+void main()
+{
+    int q = int(gl_GlobalInvocationID.x);
+
+    if (q >= p.num_output)
+        return;
+
+    int hr_base = p.dir * p.hr_dir_stride;
+    int hr_row = hr_base + q * p.hidden_size;
+
+    afp sum = afp(0.f);
+    for (int i = 0; i < p.hidden_size; i++)
+    {
+        afp hi = buffer_ld1(hidden_h_data, i);
+        afp wi = buffer_ld1(weight_hr_data, hr_row + i);
+        sum += hi * wi;
+    }
+
+    buffer_st1(hidden_next_data, q, sum);
+
+    int out_index = p.ti * p.outw + p.out_offset + q;
+    buffer_st1(top_blob_data, out_index, sum);
+}
diff --git a/src/layer/vulkan/shader/lstm_step.comp b/src/layer/vulkan/shader/lstm_step.comp
new file mode 100644
index 000000000000..770eb6ea7ca8
--- /dev/null
+++ b/src/layer/vulkan/shader/lstm_step.comp
@@ -0,0 +1,98 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; };
+layout(binding = 1) readonly buffer weight_xc_blob { sfp weight_xc_data[]; };
+layout(binding = 2) readonly buffer bias_c_blob { sfp bias_c_data[]; };
+layout(binding = 3) readonly buffer weight_hc_blob { sfp weight_hc_data[]; };
+layout(binding = 4) readonly buffer hidden_prev_blob { sfp hidden_prev_data[]; };
+layout(binding = 5) readonly buffer cell_prev_blob { sfp cell_prev_data[]; };
+layout(binding = 6) writeonly buffer hidden_next_blob { sfp hidden_next_data[]; };
+layout(binding = 7) writeonly buffer cell_next_blob { sfp cell_next_data[]; };
+layout(binding = 8) writeonly buffer top_blob { sfp top_blob_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int size;
+    int num_output;
+    int hidden_size;
+    int ti;
+    int outw;
+    int out_offset;
+    int dir;
+    int wxc_dir_stride;
+    int whc_dir_stride;
+    int bias_dir_stride;
+    int bottom_step;
+} p;
+
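+// one invocation per cell index q: accumulate the four gates I, F, O, G,
+// update the cell state, and write the new hidden value to the output row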
+void main()
+{
+    int q = int(gl_GlobalInvocationID.x);
+
+    if (q >= p.hidden_size)
+        return;
+
+    int x_offset = p.ti * p.bottom_step;
+
+    int wxc_base = p.dir * p.wxc_dir_stride;
+    int whc_base = p.dir * p.whc_dir_stride;
+    int bias_base = p.dir * p.bias_dir_stride;
+
+    int bias_I = bias_base + 0 * p.hidden_size + q;
+    int bias_F = bias_base + 1 * p.hidden_size + q;
+    int bias_O = bias_base + 2 * p.hidden_size + q;
+    int bias_G = bias_base + 3 * p.hidden_size + q;
+
+    int wxc_I = wxc_base + (0 * p.hidden_size + q) * p.size;
+    int wxc_F = wxc_base + (1 * p.hidden_size + q) * p.size;
+    int wxc_O = wxc_base + (2 * p.hidden_size + q) * p.size;
+    int wxc_G = wxc_base + (3 * p.hidden_size + q) * p.size;
+
+    int whc_I = whc_base + (0 * p.hidden_size + q) * p.num_output;
+    int whc_F = whc_base + (1 * p.hidden_size + q) * p.num_output;
+    int whc_O = whc_base + (2 * p.hidden_size + q) * p.num_output;
+    int whc_G = whc_base + (3 * p.hidden_size + q) * p.num_output;
+
+    afp I = buffer_ld1(bias_c_data, bias_I);
+    afp F = buffer_ld1(bias_c_data, bias_F);
+    afp O = buffer_ld1(bias_c_data, bias_O);
+    afp G = buffer_ld1(bias_c_data, bias_G);
+
+    for (int i = 0; i < p.size; i++)
+    {
+        afp xi = buffer_ld1(bottom_blob_data, x_offset + i);
+
+        I += buffer_ld1(weight_xc_data, wxc_I + i) * xi;
+        F += buffer_ld1(weight_xc_data, wxc_F + i) * xi;
+        O += buffer_ld1(weight_xc_data, wxc_O + i) * xi;
+        G += buffer_ld1(weight_xc_data, wxc_G + i) * xi;
+    }
+
+    for (int i = 0; i < p.num_output; i++)
+    {
+        afp hi = buffer_ld1(hidden_prev_data, i);
+
+        I += buffer_ld1(weight_hc_data, whc_I + i) * hi;
+        F += buffer_ld1(weight_hc_data, whc_F + i) * hi;
+        O += buffer_ld1(weight_hc_data, whc_O + i) * hi;
+        G += buffer_ld1(weight_hc_data, whc_G + i) * hi;
+    }
+
+    I = afp(1.f) / (afp(1.f) + exp(-I));
+    F = afp(1.f) / (afp(1.f) + exp(-F));
+    O = afp(1.f) / (afp(1.f) + exp(-O));
+    G = tanh(G);
+
+    afp cprev = buffer_ld1(cell_prev_data, q);
+    afp cnext = F * cprev + I * G;
+    afp H = O * tanh(cnext);
+
+    buffer_st1(cell_next_data, q, cnext);
+    buffer_st1(hidden_next_data, q, H);
+
+    int out_index = p.ti * p.outw + p.out_offset + q;
+    buffer_st1(top_blob_data, out_index, H);
+}
diff --git a/src/layer/vulkan/shader/lstm_step_h.comp b/src/layer/vulkan/shader/lstm_step_h.comp
new file mode 100644
index 000000000000..42af4268b669
--- /dev/null
+++ b/src/layer/vulkan/shader/lstm_step_h.comp
@@ -0,0 +1,92 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; };
+layout(binding = 1) readonly buffer weight_xc_blob { sfp weight_xc_data[]; };
+layout(binding = 2) readonly buffer bias_c_blob { sfp bias_c_data[]; };
+layout(binding = 3) readonly buffer weight_hc_blob { sfp weight_hc_data[]; };
+layout(binding = 4) readonly buffer hidden_prev_blob { sfp hidden_prev_data[]; };
+layout(binding = 5) readonly buffer cell_prev_blob { sfp cell_prev_data[]; };
+layout(binding = 6) writeonly buffer hidden_h_next_blob { sfp hidden_h_next_data[]; };
+layout(binding = 7) writeonly buffer cell_next_blob { sfp cell_next_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int size;
+    int num_output;
+    int hidden_size;
+    int ti;
+    int dir;
+    int wxc_dir_stride;
+    int whc_dir_stride;
+    int bias_dir_stride;
+    int bottom_step;
+} p;
+
+void main()
+{
+    int q = int(gl_GlobalInvocationID.x);
+
+    if (q >= p.hidden_size)
+        return;
+
+    int x_offset = p.ti * p.bottom_step;
+
+    int wxc_base = p.dir * p.wxc_dir_stride;
+    int whc_base = p.dir * p.whc_dir_stride;
+    int bias_base = p.dir * p.bias_dir_stride;
+
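+    // gate rows are ordered I, F, O, G; weight_xc rows span p.size columns,
+    // weight_hc rows span p.num_output columns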
+    int bias_I = bias_base + 0 * p.hidden_size + q;
+    int bias_F = bias_base + 1 * p.hidden_size + q;
+    int bias_O = bias_base + 2 * p.hidden_size + q;
+    int bias_G = bias_base + 3 * p.hidden_size + q;
+
+    int wxc_I = wxc_base + (0 * p.hidden_size + q) * p.size;
+    int wxc_F = wxc_base + (1 * p.hidden_size + q) * p.size;
+    int wxc_O = wxc_base + (2 * p.hidden_size + q) * p.size;
+    int wxc_G = wxc_base + (3 * p.hidden_size + q) * p.size;
+
+    int whc_I = whc_base + (0 * p.hidden_size + q) * p.num_output;
+    int whc_F = whc_base + (1 * p.hidden_size + q) * p.num_output;
+    int whc_O = whc_base + (2 * p.hidden_size + q) * p.num_output;
+    int whc_G = whc_base + (3 * p.hidden_size + q) * p.num_output;
+
+    afp I = buffer_ld1(bias_c_data, bias_I);
+    afp F = buffer_ld1(bias_c_data, bias_F);
+    afp O = buffer_ld1(bias_c_data, bias_O);
+    afp G = buffer_ld1(bias_c_data, bias_G);
+
+    for (int i = 0; i < p.size; i++)
+    {
+        afp xi = buffer_ld1(bottom_blob_data, x_offset + i);
+
+        I += buffer_ld1(weight_xc_data, wxc_I + i) * xi;
+        F += buffer_ld1(weight_xc_data, wxc_F + i) * xi;
+        O += buffer_ld1(weight_xc_data, wxc_O + i) * xi;
+        G += buffer_ld1(weight_xc_data, wxc_G + i) * xi;
+    }
+
+    for (int i = 0; i < p.num_output; i++)
+    {
+        afp hi = buffer_ld1(hidden_prev_data, i);
+
+        I += buffer_ld1(weight_hc_data, whc_I + i) * hi;
+        F += buffer_ld1(weight_hc_data, whc_F + i) * hi;
+        O += buffer_ld1(weight_hc_data, whc_O + i) * hi;
+        G += buffer_ld1(weight_hc_data, whc_G + i) * hi;
+    }
+
+    I = afp(1.f) / (afp(1.f) + exp(-I));
+    F = afp(1.f) / (afp(1.f) + exp(-F));
+    O = afp(1.f) / (afp(1.f) + exp(-O));
+    G = tanh(G);
+
+    afp cprev = buffer_ld1(cell_prev_data, q);
+    afp cnext = F * cprev + I * G;
+    afp Ht = O * tanh(cnext);
+
+    buffer_st1(cell_next_data, q, cnext);
+    buffer_st1(hidden_h_next_data, q, Ht);
+}