Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions src/layer/vulkan/shader/softmax.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// softmax.comp
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#version 450

layout(constant_id = 0) const int axis = 0;

#define shape_constant_id_offset 1
layout(constant_id = shape_constant_id_offset + 0) const int dims = 0;
layout(constant_id = shape_constant_id_offset + 1) const int w = 0;
layout(constant_id = shape_constant_id_offset + 2) const int h = 0;
layout(constant_id = shape_constant_id_offset + 3) const int d = 0;
layout(constant_id = shape_constant_id_offset + 4) const int c = 0;
layout(constant_id = shape_constant_id_offset + 5) const int cstep = 0;
layout(constant_id = shape_constant_id_offset + 6) const int outcstep = 0;

layout(binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; };
layout(binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; };

layout(push_constant) uniform parameter
{
int dims;
int w;
int h;
int d;
int c;
int cstep;
int outcstep;
} p;

shared lfp smaxv[256];
shared lfp ssumv[256];

void reduce_maxv(int lid, int lsize)
{
for (int off = lsize / 2; off > 0; off >>= 1)
{
if (lid < off)
{
afp a = lfp2afp(smaxv[lid]);
afp b = lfp2afp(smaxv[lid + off]);
smaxv[lid] = max(a, b);
}
barrier();
}
}

void reduce_sumv(int lid, int lsize)
{
for (int off = lsize / 2; off > 0; off >>= 1)
{
if (lid < off)
{
afp a = lfp2afp(ssumv[lid]);
afp b = lfp2afp(ssumv[lid + off]);
ssumv[lid] = a + b;
}
barrier();
}
}

void main()
{
int slice = int(gl_WorkGroupID.x);
int lid = int(gl_LocalInvocationID.x);
int lsize = int(gl_WorkGroupSize.x);

int dims_ = psc(dims);
int w_ = psc(w);
int h_ = psc(h);
int d_ = psc(d);
int c_ = psc(c);
int cstep_ = psc(cstep);
int outcstep_ = psc(outcstep);

int pa = axis < 0 ? dims_ + axis : axis;

int base_in = 0;
int base_out = 0;
int size = 0;
int stride_in = 0;
int stride_out = 0;

if (dims_ == 1)
{
base_in = 0;
base_out = 0;
size = w_;
stride_in = 1;
stride_out = 1;
}
else if (dims_ == 2)
{
if (pa == 0)
{
int x = slice;
base_in = x;
base_out = x;
size = h_;
stride_in = w_;
stride_out = w_;
}
else
{
int y = slice;
base_in = y * w_;
base_out = y * w_;
size = w_;
stride_in = 1;
stride_out = 1;
}
}
else if (dims_ == 3)
{
if (pa == 0)
{
int xy = slice;
base_in = xy;
base_out = xy;
size = c_;
stride_in = cstep_;
stride_out = outcstep_;
}
else if (pa == 1)
{
int q = slice / w_;
int x = slice - q * w_;
base_in = q * cstep_ + x;
base_out = q * outcstep_ + x;
size = h_;
stride_in = w_;
stride_out = w_;
}
else
{
int q = slice / h_;
int y = slice - q * h_;
base_in = q * cstep_ + y * w_;
base_out = q * outcstep_ + y * w_;
size = w_;
stride_in = 1;
stride_out = 1;
}
}
else
{
int plane = w_ * h_;

if (pa == 0)
{
int xyd = slice;
base_in = xyd;
base_out = xyd;
size = c_;
stride_in = cstep_;
stride_out = outcstep_;
}
else if (pa == 1)
{
int q = slice / plane;
int xy = slice - q * plane;
base_in = q * cstep_ + xy;
base_out = q * outcstep_ + xy;
size = d_;
stride_in = plane;
stride_out = plane;
}
else if (pa == 2)
{
int t = d_ * w_;
int q = slice / t;
int rem = slice - q * t;
int z = rem / w_;
int x = rem - z * w_;
base_in = q * cstep_ + z * plane + x;
base_out = q * outcstep_ + z * plane + x;
size = h_;
stride_in = w_;
stride_out = w_;
}
else
{
int t = d_ * h_;
int q = slice / t;
int rem = slice - q * t;
int z = rem / h_;
int y = rem - z * h_;
base_in = q * cstep_ + z * plane + y * w_;
base_out = q * outcstep_ + z * plane + y * w_;
size = w_;
stride_in = 1;
stride_out = 1;
}
}

afp lmax = afp(-3.402823466e38);
for (int i = lid; i < size; i += lsize)
{
afp v = buffer_ld1(bottom_blob_data, base_in + i * stride_in);
lmax = max(lmax, v);
}

smaxv[lid] = lmax;
barrier();
reduce_maxv(lid, lsize);
afp maxv = lfp2afp(smaxv[0]);

afp lsum = afp(0.f);
for (int i = lid; i < size; i += lsize)
{
afp v = buffer_ld1(bottom_blob_data, base_in + i * stride_in);
lsum += exp(v - maxv);
}

ssumv[lid] = lsum;
barrier();
reduce_sumv(lid, lsize);
afp invsum = afp(1.f) / lfp2afp(ssumv[0]);

for (int i = lid; i < size; i += lsize)
{
afp v = buffer_ld1(bottom_blob_data, base_in + i * stride_in);
afp e = exp(v - maxv) * invsum;
buffer_st1(top_blob_data, base_out + i * stride_out, e);
}
}
112 changes: 0 additions & 112 deletions src/layer/vulkan/shader/softmax_div_sum.comp

This file was deleted.

Loading