Skip to content

Commit

Permalink
[WebNN] Add op support validation for decomposed WebNN ops (#23370)
Browse files Browse the repository at this point in the history
- Some ONNX ops are supported by decomposed WebNN ops; define a
`decomposed_op_map` map to specify the decomposed WebNN ops list.
- WebNN ops have various first input names such as 'a', 'input',
'inputs', etc. Defines a `webnn_op_first_input_name_map` map to record
special names other than 'input', and a `GetWebNNOpFirstInputName`
function to retrieve the first input name of a WebNN op.
- Check if the input and output data types are supported by each
decomposed WebNN op.
- Remove the unnecessary `CheckSingleOp` function, because WebNN's
`OpSupportLimits` has already covered op supported check.
  • Loading branch information
Honry authored Feb 11, 2025
1 parent 8c3e34d commit 5fa8bd0
Show file tree
Hide file tree
Showing 25 changed files with 261 additions and 112 deletions.
43 changes: 20 additions & 23 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
std::unordered_set<const Node*> supported_nodes;

for (const auto& node : graph_viewer.Nodes()) {
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
}
const bool supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
Expand All @@ -125,7 +121,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
return supported_nodes;
}

bool AreInputDataTypesSame(const std::string& op_type,
bool AreInputDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger) {
for (size_t i = 1; i < input_types.size(); i++) {
Expand All @@ -145,46 +141,47 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we
if (it == onnx_to_webnn_data_type_map.end())
return false;

std::string webnn_data_type = it->second;
const std::string_view webnn_data_type = it->second;

// Check if WebNN supports the data type.
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
emscripten::val(webnn_data_type));
emscripten::val is_supported =
webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(std::string(webnn_data_type)));
return is_supported.as<bool>();
}

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger) {
std::string webnn_op_type;
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;
const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type);

return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
return !webnn_op_type.empty() &&
IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
webnn_input_output_name, onnx_input_output_name, logger);
}

bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
const std::string_view webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger) {
if (wnn_limits[webnn_op_type].isUndefined()) {
if (wnn_limits[std::string(webnn_op_type)].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}
if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {

if (wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
return false;
}
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
if (!IsSupportedDataType(
onnx_data_type, wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
return false;
Expand Down
84 changes: 52 additions & 32 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,16 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
// we need to check the support of the decomposed ops.
static const InlinedHashMap<std::string, std::string> op_map = {

// Some ONNX ops are supported by decomposed WebNN ops.
// Maps an ONNX op type to the list of WebNN ops used in its decomposition.
// Support for such an ONNX op is validated by checking that the input/output
// data types are supported by each decomposed WebNN op (per the commit notes).
const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
{"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
{"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
{"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
{"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
};
// ONNX op type to WebNN op type mapping.
const std::map<std::string_view, std::string_view> op_map = {
{"Abs", "abs"},
{"Add", "add"},
{"And", "logicalAnd"},
Expand Down Expand Up @@ -247,7 +254,6 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"LSTM", "lstm"},
{"LRN", "averagePool2d"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
Expand Down Expand Up @@ -275,17 +281,14 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"RotaryEmbedding", "gather"},
{"ScatterElements", "scatterElements"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"SimplifiedLayerNormalization", "layerNormalization"},
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"SkipSimplifiedLayerNormalization", "layerNormalization"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
Expand All @@ -302,29 +305,46 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Xor", "logicalXor"},
};

inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
const WebnnDeviceType device_type) {
auto op_map_entry = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map or
// if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
return false;
}
// WebNN op name to its first input name mapping, only record the name that is different from "input".
// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits.
const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = {
// Two-operand ops record "a" as their first input name.
{"add", "a"},
{"concat", "inputs"},  // concat's first input is the list of inputs.
{"div", "a"},
{"equal", "a"},
{"gemm", "a"},
{"greater", "a"},
{"greaterOrEqual", "a"},
{"lesser", "a"},
{"lesserOrEqual", "a"},
{"logicalAnd", "a"},
{"logicalNot", "a"},
{"logicalOr", "a"},
{"logicalXor", "a"},
{"matmul", "a"},
{"max", "a"},
{"min", "a"},
{"mul", "a"},
{"pow", "a"},
{"sub", "a"},
{"where", "condition"},  // where's first input is the condition tensor.
};

return true;
// Look up the name of a WebNN op's first input, used when validating supported
// input data types against OpSupportLimits. Most WebNN ops name their first
// input "input"; the exceptions (e.g. "a", "inputs", "condition") are recorded
// in webnn_op_first_input_name_map.
inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
  const auto entry = webnn_op_first_input_name_map.find(webnn_op_type);
  if (entry == webnn_op_first_input_name_map.end()) {
    return "input";
  }
  return entry->second;
}

inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
inline std::string_view GetWebNNOpType(const std::string_view op_type) {
auto it = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map.
if (it == op_map.end()) {
return false;
}
webnn_op_type = it->second;
return true;
// Return an empty string if the op_type is not listed in the op_map.
return (it != op_map.end()) ? it->second : "";
}

static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
Expand All @@ -338,22 +358,22 @@ static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> o
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};

bool AreInputDataTypesSame(const std::string& op_type,
bool AreInputDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
const std::string_view webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger);

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializ
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
if (webnn_op_type.empty())
return false;

return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits,
webnn_input_name, "input", logger);
}

bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
Expand All @@ -83,7 +88,7 @@ bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const logging::Logger& logger) const {
// We only check the type of output 0 by default, specific op builder can override this.
const auto& output = *node.OutputDefs()[0];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t output_type;
if (!GetType(output, output_type, logger))
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;

Expand Down
24 changes: 0 additions & 24 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ class CastOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -85,25 +80,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}

// Operator support related.
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_type;

if (!GetType(*input_defs[0], input_type, logger))
return false;

if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
return false;

NodeAttrHelper helper(node);
// Check cast to type.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
}

void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<CastOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;

if (!GetType(*input_defs[0], input0_type, logger))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type; // input data type
int32_t input1_type; // weight data type
int32_t input2_type; // bias or x_zero_point data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
bool has_input1 = TensorExists(input_defs, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet&
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* in
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_type;
int32_t indices_type;
if (!GetType(input, input_type, logger) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type; // A data type
int32_t input1_type; // B data type
int32_t input2_type; // C or a_zero_point data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_X_type = 0; // input data type
int32_t input_W_type = 0; // weight data type
int32_t input_R_type = 0; // recurrent weight data type
Expand Down Expand Up @@ -226,7 +226,7 @@ bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& output_defs = node.OutputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t Y_type = 0;
int32_t Y_h_type = 0;
bool has_Y = TensorExists(output_defs, 0);
Expand Down
Loading

0 comments on commit 5fa8bd0

Please sign in to comment.