Skip to content

Commit

Permalink
[WebNN EP] Support more Normalization ops for TFLite backend (#21151)
Browse files Browse the repository at this point in the history
Following Normalization ops have been supported in Chromium for TFLite
backend:
- batchNormalization:
https://chromium-review.googlesource.com/c/chromium/src/+/5532745
- layerNormalization:
https://chromium-review.googlesource.com/c/chromium/src/+/5573326
- instanceNormalization:
https://chromium-review.googlesource.com/c/chromium/src/+/5532750
  • Loading branch information
Honry authored Jun 25, 2024
1 parent f81c0ec commit 4743803
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
6 changes: 3 additions & 3 deletions js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax ||| WebNN CPU backend only supports 'select_last_index' value is 0 |
| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin ||| WebNN CPU backend only supports 'select_last_index' value is 0 |
| AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d ||| Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 |
| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | || Only supports 'training_mode' value is 0, one output |
| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | || Only supports 'training_mode' value is 0, one output |
| Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast ||| WebNN CPU backend doesn't support casting to uint64 data type |
| Ceil | ai.onnx(7-12, 13+) | ceil ||| |
| Clip | ai.onnx(7-10, 11, 12, 13+) | clamp ||| WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) |
Expand All @@ -43,8 +43,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| HardSigmoid | ai.onnx(7+) | hardSigmoid ||| |
| HardSwish | ai.onnx(14+) | hardSwish ||| |
| Identity | ai.onnx(7-13, 14-15, 16-18, 19-20, 21+) | identity ||| |
| InstanceNormalization | ai.onnx(7+) | instanceNormalization | || |
| LayerNormalization | ai.onnx(7-16, 17+) | layerNormalization | || |
| InstanceNormalization | ai.onnx(7+) | instanceNormalization | || |
| LayerNormalization | ai.onnx(7-16, 17+) | layerNormalization | || |
| LeakyRelu | ai.onnx(7-15, 16+) | leakyRelu ||| |
| Less | ai.onnx(7-8, 9-12, 13+) | lesser ||| |
| LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual ||| |
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"ArgMax", {"argMax", true}},
{"ArgMin", {"argMin", true}},
{"AveragePool", {"averagePool2d", true}},
{"BatchNormalization", {"batchNormalization", false}},
{"BatchNormalization", {"batchNormalization", true}},
{"Cast", {"cast", true}},
{"Ceil", {"ceil", true}},
{"Clip", {"clamp", true}},
Expand Down Expand Up @@ -190,8 +190,8 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"HardSigmoid", {"hardSigmoid", true}},
{"HardSwish", {"hardSwish", true}},
{"Identity", {"identity", true}},
{"InstanceNormalization", {"instanceNormalization", false}},
{"LayerNormalization", {"layerNormalization", false}},
{"InstanceNormalization", {"instanceNormalization", true}},
{"LayerNormalization", {"layerNormalization", true}},
{"LeakyRelu", {"leakyRelu", true}},
{"Less", {"lesser", true}},
{"LessOrEqual", {"lesserOrEqual", true}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
int64_t axis = helper.Get("axis", -1);
axis = HandleNegativeAxis(axis, rank);
std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) {
std::iota(axes.begin(), axes.end(), axis - 1);
} else {
std::iota(axes.begin(), axes.end(), axis);
}
std::iota(axes.begin(), axes.end(), axis);

options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
} else if (op_type == "InstanceNormalization") {
Expand Down

0 comments on commit 4743803

Please sign in to comment.