Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/layer/vulkan/shader/shrink.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#version 450

layout (constant_id = 0) const float const_bias = 0.f;
layout (constant_id = 1) const float const_lambd = 0.f;

#define shape_constant_id_offset 2
layout (constant_id = shape_constant_id_offset + 0) const uint n = 0;

layout (binding = 0) buffer bottom_top_blob { sfpvec4 bottom_top_blob_data[]; };

layout (push_constant) uniform parameter
{
uint n;
} p;

void main()
{
const uint gi = gl_GlobalInvocationID.x;

if (gi >= psc(n))
return;

afpvec4 v = buffer_ld4(bottom_top_blob_data, gi);

afpvec4 vb = afpvec4(const_bias);
afpvec4 vl = afpvec4(const_lambd);

afpvec4 zero = afpvec4(0.f);

v = mix(
mix(v + vb, v - vb, greaterThan(v, vl)),
zero,
lessThanEqual(abs(v), vl)
);

buffer_st4(bottom_top_blob_data, gi, v);
}
84 changes: 84 additions & 0 deletions src/layer/vulkan/shrink_vulkan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#include "shrink_vulkan.h"

#include "layer_shader_type.h"

namespace ncnn {

Shrink_vulkan::Shrink_vulkan()
{
support_vulkan = true;
support_vulkan_packing = true;

pipeline_shrink = 0;
}

int Shrink_vulkan::create_pipeline(const Option& opt)
{
const Mat& shape = top_shapes.empty() ? Mat() : top_shapes[0];

int elempack = 1;
if (shape.dims == 1) elempack = shape.w % 4 == 0 ? 4 : 1;
if (shape.dims == 2) elempack = shape.h % 4 == 0 ? 4 : 1;
if (shape.dims == 3 || shape.dims == 4) elempack = shape.c % 4 == 0 ? 4 : 1;

size_t elemsize;
if (opt.use_fp16_storage || opt.use_fp16_packed)
{
elemsize = elempack * 2u;
}
else
{
elemsize = elempack * 4u;
}

Mat shape_packed;
if (shape.dims == 1) shape_packed = Mat(shape.w / elempack, (void*)0, elemsize, elempack);
if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack);
if (shape.dims == 3) shape_packed = Mat(shape.w, shape.h, shape.c / elempack, (void*)0, elemsize, elempack);
if (shape.dims == 4) shape_packed = Mat(shape.w, shape.h, shape.d, shape.c / elempack, (void*)0, elemsize, elempack);

std::vector<vk_specialization_type> specializations(2 + 1);
specializations[0].f = bias;
specializations[1].f = lambd;
specializations[2 + 0].u32 = shape_packed.total() * elempack / 4;

const int local_size_x = vkdev->info.subgroup_size();

pipeline_shrink = new Pipeline(vkdev);
pipeline_shrink->set_optimal_local_size_xyz(local_size_x, 1, 1);
pipeline_shrink->create(LayerShaderType::shrink, opt, specializations);

return 0;
}

int Shrink_vulkan::destroy_pipeline(const Option& /*opt*/)
{
delete pipeline_shrink;
pipeline_shrink = 0;

return 0;
}

int Shrink_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, const Option& /*opt*/) const
{
const size_t n = bottom_top_blob.total() * bottom_top_blob.elempack / 4;

std::vector<VkMat> bindings(1);
bindings[0] = bottom_top_blob;

std::vector<vk_constant_type> constants(1);
constants[0].u32 = n;

VkMat dispatcher;
dispatcher.w = n;
dispatcher.h = 1;
dispatcher.c = 1;
cmd.record_pipeline(pipeline_shrink, bindings, constants, dispatcher);

return 0;
}

} // namespace ncnn
28 changes: 28 additions & 0 deletions src/layer/vulkan/shrink_vulkan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_SHRINK_VULKAN_H
#define LAYER_SHRINK_VULKAN_H

#include "shrink.h"

namespace ncnn {

class Shrink_vulkan : public Shrink
{
public:
Shrink_vulkan();

virtual int create_pipeline(const Option& opt);
virtual int destroy_pipeline(const Option& opt);

using Shrink::forward_inplace;
virtual int forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, const Option& opt) const;

public:
Pipeline* pipeline_shrink;
};

} // namespace ncnn

#endif // LAYER_SHRINK_VULKAN_H
Loading