[WebGPU EP] SoftMax Implementation #23538

Open · wants to merge 13 commits into main
238 changes: 238 additions & 0 deletions onnxruntime/core/providers/webgpu/math/softmax.cc
@@ -0,0 +1,238 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <numeric>
#include <string>

#include "core/common/inlined_containers.h"
#include "core/providers/common.h"
#include "core/providers/webgpu/math/softmax.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Softmax,
kOnnxDomain,
1, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Softmax,
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

ONNX_OPERATOR_KERNEL_EX(
Softmax,
kOnnxDomain,
13,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

static std::string MaxVector(const std::string& name, int components) {
switch (components) {
case 1:
return name;
case 2:
return "max(" + name + ".x, " + name + ".y)";
case 3:
return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)";
case 4:
return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

static std::string SumVector(const std::string& x, int components) {
switch (components) {
case 1:
return x;
case 2:
return "(" + x + ".x + " + x + ".y" + ")";
case 4:
return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

static int GetMaxComponents(int64_t size) {
if (size % 4 == 0) {
return 4;
} else if (size % 2 == 0) {
return 2;
}
return 1;
}
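
// Illustrative expansions (editor's sketch, not part of the file): with
// components == 4 the helpers above produce the WGSL fragments
//   MaxVector("v", 4) -> "max(max(v.x, v.y), max(v.z, v.w))"
//   SumVector("v", 4) -> "(v.x + v.y + v.z + v.w)"
// GetMaxComponents picks the widest vector width that evenly divides the
// row length, e.g. cols = 12 -> 4 (vec4), 6 -> 2 (vec2), 5 -> 1 (scalar).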

Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add input and output variables
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
int components = input.NumComponents();

const std::string thread_max_decl = is_fp32_
? "var thread_max = x_value_t(-3.402823e+38f);\n"
: "var thread_max = x_value_t(-65504.0h);\n";

// Define shared memory for row max and row sum
shader.AdditionalImplementation()
<< "var<workgroup> row_max_shared : x_value_t;\n"
<< "var<workgroup> row_sum_shared : x_value_t;\n"
<< "var<workgroup> thread_shared : array<x_value_t, " << wg_ << ">;\n";

// Define helper functions to get and set values
shader.AdditionalImplementation()
<< "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n"
<< " let index = row * row_stride + col;\n"
<< " return x[index];\n"
<< "}\n"
<< "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n"
<< " let index = row * row_stride + col;\n"
<< " result[index] = value;\n"
<< "}\n";

// Main function body
shader.MainFunctionBody()
<< " let gindex = i32(global_idx);\n"
<< " let lindex = i32(local_idx);\n"
<< " const wg = " << wg_ << ";\n"
<< " let row = gindex / wg;\n"
<< " let cols = uniforms.packedCols;\n"
<< " let row_stride : i32 = uniforms.packedCols;\n"

// Find the row's max value
<< thread_max_decl
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = getValue(row, col, row_stride);\n"
<< " thread_max = max(thread_max, value);\n"
<< " }\n"
<< " if (lindex < cols) {\n"
<< " thread_shared[lindex] = thread_max;\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Reduce to find the max value
<< " var reduce_size = min(cols, wg);\n"
<< " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
<< " reduce_size = curr_size + (reduce_size & 1);\n"
<< " if (lindex < curr_size) {\n"
<< " thread_shared[lindex] = max(thread_shared[lindex], thread_shared[lindex + reduce_size]);\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " row_max_shared = x_value_t(" << MaxVector("thread_shared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Find the row's sum of exponentials
<< " var thread_sum = x_value_t(0.0);\n"
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let sub_exp = exp(getValue(row, col, row_stride) - row_max_shared);\n"
<< " thread_sum += sub_exp;\n"
<< " }\n"
<< " thread_shared[lindex] = thread_sum;\n"
<< " workgroupBarrier();\n"

// Reduce to find the sum of exponentials
<< " for (var curr_size = wg >> 1; curr_size > 0; curr_size = curr_size >> 1) {\n"
<< " if (lindex < curr_size) {\n"
<< " thread_shared[lindex] = thread_shared[lindex] + thread_shared[lindex + curr_size];\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " row_sum_shared = x_value_t(" << SumVector("thread_shared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Calculate the final value for each element in the row
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n"
<< " setValue(row, col, row_stride, value);\n"
<< " }\n";

return Status::OK();
}

Status Softmax::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
size_t input_rank = input_shape.NumDimensions();
auto* output_tensor = context.Output(0, input_shape);

// normalize axis
size_t axis = static_cast<size_t>(HandleNegativeAxis(axis_, static_cast<int64_t>(input_rank)));
bool is_transpose_required = axis < input_rank - 1;

TensorShape transposed_input_shape;
Tensor transposed_input_tensor;
Tensor intermediate_output;
InlinedVector<size_t> perm(input_rank);

if (is_transpose_required) {
std::iota(std::begin(perm), std::end(perm), 0);
perm[axis] = input_rank - 1;
perm[input_rank - 1] = axis;

TensorShapeVector transposed_input_dims;
for (auto e : perm) {
transposed_input_dims.push_back(input_shape[e]);
}

transposed_input_shape = TensorShape(transposed_input_dims);
transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor));
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
}

const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
const int64_t rows = input_shape.Size() / cols;
const int64_t components = GetMaxComponents(cols);
const auto packed_cols = cols / components;
uint32_t workgroup_size = rows == 1 ? 256 : 64;
// Check whether the input element type is fp32 (as opposed to fp16).
const bool is_fp32 = input_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;

SoftmaxProgram program{workgroup_size, is_fp32};
if (is_transpose_required) {
program
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
} else {
program
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
}

program
.CacheHint(std::to_string(components), std::to_string(workgroup_size))
.SetWorkgroupSize(workgroup_size)
.SetDispatchGroupSize(static_cast<uint32_t>(rows))
.AddUniformVariables({{static_cast<int32_t>(packed_cols)}});

ORT_RETURN_IF_ERROR(context.RunProgram(program));

// If transpose was required, transpose the result back
if (is_transpose_required) {
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor));
}

return Status::OK();
}
} // namespace webgpu
} // namespace onnxruntime
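
For intuition, here is a minimal host-side sketch (an editor's illustration, not part of the diff) of the sizing math in Softmax::ComputeInternal, using a hypothetical [3, 4, 5] input with axis = 1: the softmax axis is swapped to the innermost position, each row becomes one workgroup, and rows whose length is divisible by 4 are packed into vec4 lanes.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t dims[] = {3, 4, 5};  // hypothetical input shape
  const size_t rank = 3;
  const size_t axis = 1;                         // already normalized
  const bool needs_transpose = axis < rank - 1;  // true: axis is not innermost
  // With perm {0, 2, 1} the transposed shape is [3, 5, 4], so the softmax
  // axis (length 4) becomes the innermost dimension.
  const int64_t cols = needs_transpose ? dims[axis] : dims[rank - 1];    // 4
  const int64_t rows = (dims[0] * dims[1] * dims[2]) / cols;             // 15
  const int64_t components = cols % 4 == 0 ? 4 : cols % 2 == 0 ? 2 : 1;  // 4
  const int64_t packed_cols = cols / components;                         // 1
  const uint32_t workgroup_size = rows == 1 ? 256u : 64u;                // 64
  std::cout << rows << " workgroups x " << workgroup_size << " threads, "
            << packed_cols << " packed column(s) per row\n";
}

One workgroup reduces one row, so the dispatch size is simply rows; the wider 256-thread workgroup is used only for the single-row case, where every thread cooperates on the same reduction.
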
54 changes: 54 additions & 0 deletions onnxruntime/core/providers/webgpu/math/softmax.h
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace webgpu {

class Softmax final : public WebGpuKernel {
public:
Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
int opset = info.node().SinceVersion();
int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);

if (status.IsOK()) {
axis_ = axis;
} else {
if (opset < 13) {
axis_ = 1;  // for opset 12 and earlier, the default axis is 1
} else {
axis_ = -1;  // for opset 13 and later, the default axis is -1
}
}
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int64_t axis_;
};

class SoftmaxProgram final : public Program<SoftmaxProgram> {
public:
SoftmaxProgram(uint32_t wg, bool is_fp32)
: Program{"Softmax"}, wg_{wg}, is_fp32_{is_fp32} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});

private:
uint32_t wg_;
bool is_fp32_;
};

} // namespace webgpu
} // namespace onnxruntime
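
The fallback in the constructor tracks the opset-13 change to Softmax, where the default axis moved from 1 to -1 as the op went from a flattened 2-D reduction to a per-axis one. A standalone restatement as a sketch, with a hypothetical ResolveSoftmaxAxis helper:

#include <cstdint>
#include <optional>

// Hypothetical helper mirroring the Softmax constructor above: when the
// "axis" attribute is absent, the default depends on the opset.
int64_t ResolveSoftmaxAxis(std::optional<int64_t> axis_attr, int opset) {
  if (axis_attr.has_value()) {
    return *axis_attr;
  }
  return opset < 13 ? 1    // opsets 1-12 default to axis = 1
                    : -1;  // opset 13+ defaults to axis = -1 (innermost)
}
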
60 changes: 40 additions & 20 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -47,7 +47,10 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Transpose);

auto SqueezeShape(const gsl::span<const int64_t>& shape, const gsl::span<const size_t>& adjusted_perm, InlinedVector<int64_t>& new_shape, InlinedVector<int64_t>& new_perm) {
auto SqueezeShape(const gsl::span<const int64_t>& shape,
const gsl::span<const size_t>& adjusted_perm,
TensorShapeVector& new_shape,
TensorShapeVector& new_perm) {
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != 1) {
new_shape.push_back(shape[i]);
@@ -97,26 +100,28 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status Transpose::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,
gsl::span<const size_t> permutations,
const Tensor& input, Tensor& output) {
const auto& input_shape = input.Shape();
const auto& input_dims = input_shape.GetDims();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());

TensorShapeVector output_dims(rank);
InlinedVector<size_t> default_perm(rank);
const InlinedVector<size_t>* p_perm = nullptr;
ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm));
TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);

InlinedVector<int64_t> new_shape{};
InlinedVector<int64_t> new_perm{};
SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm);
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
for (int32_t i = 0; i < rank; i++) {
output_dims[i] = input_dims[permutations[i]];
}

TensorShapeVector new_shape{};
TensorShapeVector new_perm{};
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
const bool channels_last = new_perm == TensorShapeVector({2, 3, 1});
const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);

if (use_shared) {
new_input_shape = channels_last
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
@@ -126,16 +131,16 @@ Status Transpose::ComputeInternal(ComputeContext& context) const {
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
}

uint32_t output_size = gsl::narrow_cast<int32_t>(input_tensor->Shape().Size());
TransposeProgram program{*p_perm, use_shared};
uint32_t output_size = gsl::narrow_cast<uint32_t>(input_shape.Size());
TransposeProgram program{permutations, use_shared};

if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
}

program
.CacheHint(absl::StrJoin(*p_perm, "-"))
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.CacheHint(absl::StrJoin(permutations, "-"))
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)))
.AddUniformVariables({
@@ -148,5 +153,20 @@ Status Transpose::ComputeInternal(ComputeContext& context) const {
return context.RunProgram(program);
}

Status Transpose::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());

TensorShapeVector output_dims(rank);
InlinedVector<size_t> default_perm(rank);
const InlinedVector<size_t>* p_perm = nullptr;
ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm));
TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);

return DoTranspose(context, *p_perm, *input_tensor, *output_tensor);
}

} // namespace webgpu
} // namespace onnxruntime
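
The new DoTranspose overload is what allows Softmax to borrow the transpose program for non-innermost axes. A sketch (with a hypothetical SwapAxisToLast helper) of the permutation Softmax::ComputeInternal builds:

#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper mirroring how Softmax::ComputeInternal builds `perm`:
// swap the softmax axis with the innermost dimension.
std::vector<size_t> SwapAxisToLast(size_t rank, size_t axis) {
  std::vector<size_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);  // identity: 0, 1, ..., rank-1
  std::swap(perm[axis], perm[rank - 1]);   // e.g. rank 4, axis 1 -> {0, 3, 2, 1}
  return perm;
}

Because the permutation only swaps two indices, it is its own inverse, which is why ComputeInternal can pass the same perm to DoTranspose both before the row-wise pass and when transposing the result back.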