Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
770 changes: 272 additions & 498 deletions src/layer/vulkan/pooling_vulkan.cpp

Large diffs are not rendered by default.

20 changes: 6 additions & 14 deletions src/layer/vulkan/pooling_vulkan.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019 Tencent
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_POOLING_VULKAN_H
Expand All @@ -16,26 +16,18 @@ class Pooling_vulkan : public Pooling
virtual int create_pipeline(const Option& opt);
virtual int destroy_pipeline(const Option& opt);

virtual int upload_model(VkTransfer& cmd, const Option& opt);

using Pooling::forward;
virtual int forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const;

public:
ncnn::Layer* padding;

Pipeline* pipeline_pooling;
Pipeline* pipeline_pooling_pack4;
Pipeline* pipeline_pooling_tile;

Pipeline* pipeline_pooling_global;
Pipeline* pipeline_pooling_global_stage1;
Pipeline* pipeline_pooling_global_stage2;

Pipeline* pipeline_pooling_adaptive;
Pipeline* pipeline_pooling_adaptive_pack4;

Pipeline* pipeline_pooling_global_reduce_first;
Pipeline* pipeline_pooling_global_reduce_first_pack4;
Pipeline* pipeline_pooling_global_reduce;
Pipeline* pipeline_pooling_global_reduce_pack4;
Pipeline* pipeline_pooling_global_reduce_last;
Pipeline* pipeline_pooling_global_reduce_last_pack4;
};

} // namespace ncnn
Expand Down
192 changes: 117 additions & 75 deletions src/layer/vulkan/shader/pooling.comp
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
// Copyright 2018 Tencent
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#version 450

#define FLT_MAX 3.402823466e+38

layout(constant_id = 0) const int pooling_type = 0;
layout(constant_id = 1) const int kernel_w = 1;
layout(constant_id = 2) const int kernel_h = 1;
layout(constant_id = 1) const int kernel_w = 0;
layout(constant_id = 2) const int kernel_h = 0;
layout(constant_id = 3) const int stride_w = 1;
layout(constant_id = 4) const int stride_h = 1;
layout(constant_id = 5) const int pad_left = 0;
layout(constant_id = 6) const int pad_right = 0;
layout(constant_id = 7) const int pad_top = 0;
layout(constant_id = 8) const int pad_bottom = 0;
layout(constant_id = 9) const int global_pooling = 0;
layout(constant_id = 10) const int pad_mode = 0;
layout(constant_id = 11) const int avgpool_count_include_pad = 0;
layout(constant_id = 9) const int pad_mode = 0;
layout(constant_id = 10) const int avgpool_count_include_pad = 0;

#define shape_constant_id_offset 12
#define shape_constant_id_offset 11
layout(constant_id = shape_constant_id_offset + 0) const int dims = 0;
layout(constant_id = shape_constant_id_offset + 1) const int w = 0;
layout(constant_id = shape_constant_id_offset + 2) const int h = 0;
layout(constant_id = shape_constant_id_offset + 3) const int c = 0;
layout(constant_id = shape_constant_id_offset + 4) const int cstep = 0;
layout(constant_id = shape_constant_id_offset + 3) const int d = 0;
layout(constant_id = shape_constant_id_offset + 4) const int c = 0;
layout(constant_id = shape_constant_id_offset + 5) const int cstep = 0;

layout(constant_id = shape_constant_id_offset + 5) const int outdims = 0;
layout(constant_id = shape_constant_id_offset + 6) const int outw = 0;
layout(constant_id = shape_constant_id_offset + 7) const int outh = 0;
layout(constant_id = shape_constant_id_offset + 8) const int outc = 0;
layout(constant_id = shape_constant_id_offset + 9) const int outcstep = 0;
layout(constant_id = shape_constant_id_offset + 6) const int outdims = 0;
layout(constant_id = shape_constant_id_offset + 7) const int outw = 0;
layout(constant_id = shape_constant_id_offset + 8) const int outh = 0;
layout(constant_id = shape_constant_id_offset + 9) const int outd = 0;
layout(constant_id = shape_constant_id_offset + 10) const int outc = 0;
layout(constant_id = shape_constant_id_offset + 11) const int outcstep = 0;

layout(binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; };
layout(binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; };
Expand All @@ -39,107 +38,150 @@ layout(push_constant) uniform parameter
int dims;
int w;
int h;
int d;
int c;
int cstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;

int wtailpad;
int htailpad;
} p;

void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);
int ox = int(gl_GlobalInvocationID.x);
int oy = int(gl_GlobalInvocationID.y);
int oz = int(gl_GlobalInvocationID.z);

if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
if (ox >= psc(outw) || oy >= psc(outh) || oz >= psc(outc))
return;

afp res;
int pl;
int pr;
int pt;
int pb;

if (pooling_type == 0)
if (pad_mode == 0 || pad_mode == 1)
{
res = afp(-FLT_MAX);
pl = pad_left;
pr = pad_right;
pt = pad_top;
pb = pad_bottom;

int v_offset = gz * psc(cstep) + gy * stride_h * psc(w) + gx * stride_w;

for (int y = 0; y < kernel_h; y++)
if (pad_mode == 0)
{
for (int x = 0; x < kernel_w; x++)
{
afp v = buffer_ld1(bottom_blob_data, v_offset + x);
res = max(res, v);
}

v_offset += psc(w);
int wtail = (psc(w) + pl + pr - kernel_w) % stride_w;
int htail = (psc(h) + pt + pb - kernel_h) % stride_h;
if (wtail != 0) pr += stride_w - wtail;
if (htail != 0) pb += stride_h - htail;
}
}
if (pooling_type == 1 && avgpool_count_include_pad == 0)
else
{
res = afp(0.f);
int area = 0;
int wpad = kernel_w + (psc(w) - 1) / stride_w * stride_w - psc(w);
int hpad = kernel_h + (psc(h) - 1) / stride_h * stride_h - psc(h);
if (wpad < 0) wpad = 0;
if (hpad < 0) hpad = 0;

int sx = gx * stride_w;
int sy = gy * stride_h;
if (pad_mode == 2)
{
pl = wpad / 2;
pr = wpad - pl;
pt = hpad / 2;
pb = hpad - pt;
}
else
{
pl = wpad - wpad / 2;
pr = wpad / 2;
pt = hpad - hpad / 2;
pb = hpad / 2;
}
}

int inx0 = ox * stride_w - pl;
int iny0 = oy * stride_h - pt;

int v_offset = gz * psc(cstep) + sy * psc(w) + sx;
if (pooling_type == 0)
{
afp mv = afp(-3.402823466e38);

for (int y = 0; y < kernel_h; y++)
for (int ky = 0; ky < kernel_h; ky++)
{
if (sy + y < pad_top)
int iy = iny0 + ky;
for (int kx = 0; kx < kernel_w; kx++)
{
v_offset += psc(w);
continue;
}

if (sy + y >= psc(h) - pad_bottom - p.htailpad)
break;
int ix = inx0 + kx;

for (int x = 0; x < kernel_w; x++)
{
if (sx + x < pad_left)
{
if (ix < 0 || ix >= psc(w) || iy < 0 || iy >= psc(h))
continue;
}

if (sx + x >= psc(w) - pad_right - p.wtailpad)
break;

res += buffer_ld1(bottom_blob_data, v_offset + x);
area += 1;
int si = oz * psc(cstep) + iy * psc(w) + ix;
afp v = buffer_ld1(bottom_blob_data, si);
mv = max(mv, v);
}

v_offset += psc(w);
}

res /= afp(area);
int gi = oz * psc(outcstep) + oy * psc(outw) + ox;
buffer_st1(top_blob_data, gi, mv);
}
if (pooling_type == 1 && avgpool_count_include_pad == 1)
else
{
res = afp(0.f);
afp sum = afp(0.f);

int v_offset = gz * psc(cstep) + gy * stride_h * psc(w) + gx * stride_w;

for (int y = 0; y < kernel_h; y++)
if (avgpool_count_include_pad == 1)
{
for (int x = 0; x < kernel_w; x++)
for (int ky = 0; ky < kernel_h; ky++)
{
res += buffer_ld1(bottom_blob_data, v_offset + x);
int iy = iny0 + ky;
for (int kx = 0; kx < kernel_w; kx++)
{
int ix = inx0 + kx;

if (ix < 0 || ix >= psc(w) || iy < 0 || iy >= psc(h))
continue;

int si = oz * psc(cstep) + iy * psc(w) + ix;
sum += buffer_ld1(bottom_blob_data, si);
}
}

v_offset += psc(w);
sum *= afp(1.f / float(kernel_w * kernel_h));
int gi = oz * psc(outcstep) + oy * psc(outw) + ox;
buffer_st1(top_blob_data, gi, sum);
}
else
{
int vx0 = max(0, -inx0);
int vy0 = max(0, -iny0);
int vx1 = min(kernel_w, psc(w) - inx0);
int vy1 = min(kernel_h, psc(h) - iny0);

res /= afp(kernel_w * kernel_h);
}
int area = (vx1 - vx0) * (vy1 - vy0);
if (area <= 0)
{
int gi = oz * psc(outcstep) + oy * psc(outw) + ox;
buffer_st1(top_blob_data, gi, afp(0.f));
return;
}

const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
for (int ky = vy0; ky < vy1; ky++)
{
int iy = iny0 + ky;
for (int kx = vx0; kx < vx1; kx++)
{
int ix = inx0 + kx;
int si = oz * psc(cstep) + iy * psc(w) + ix;
sum += buffer_ld1(bottom_blob_data, si);
}
}

buffer_st1(top_blob_data, gi, res);
sum *= afp(1.f / float(area));
int gi = oz * psc(outcstep) + oy * psc(outw) + ox;
buffer_st1(top_blob_data, gi, sum);
}
}
}
Loading