From 4743803944d6f6333bf93b89a62d8083d4466710 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Tue, 25 Jun 2024 10:04:23 +0800 Subject: [PATCH] [WebNN EP] Support more Normalization ops for TFLite backend (#21151) Following Normalization ops have been supported in Chromium for TFLite backend: - batchNormalization: https://chromium-review.googlesource.com/c/chromium/src/+/5532745 - layerNormalization: https://chromium-review.googlesource.com/c/chromium/src/+/5573326 - instanceNormalization: https://chromium-review.googlesource.com/c/chromium/src/+/5532750 --- js/web/docs/webnn-operators.md | 6 +++--- onnxruntime/core/providers/webnn/builders/helper.h | 6 +++--- .../webnn/builders/impl/normalization_op_builder.cc | 7 ++----- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index a49759b9a93c5..987e063485846 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -16,7 +16,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | | ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 | -| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✗ | ✓ | Only supports 'training_mode' value is 0, one output | +| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✓ | ✓ | Only supports 'training_mode' value is 0, one output | | Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✓ | ✓ | WebNN CPU backend doesn't support casting to uint64 data type | | Ceil | ai.onnx(7-12, 13+) | ceil | ✓ | ✓ | | | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | @@ -43,8 +43,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | HardSigmoid | ai.onnx(7+) | hardSigmoid | ✓ | ✓ | | | HardSwish | ai.onnx(14+) | hardSwish | ✓ | ✓ | | | Identity | ai.onnx(7-13, 14-15, 16-18, 19-20, 21+) | identity | ✓ | ✓ | | -| InstanceNormalization | ai.onnx(7+) | instanceNormalization | ✗ | ✓ | | -| LayerNormalization | ai.onnx(7-16, 17+) | layerNormalization | ✗ | ✓ | | +| InstanceNormalization | ai.onnx(7+) | instanceNormalization | ✓ | ✓ | | +| LayerNormalization | ai.onnx(7-16, 17+) | layerNormalization | ✓ | ✓ | | | LeakyRelu | ai.onnx(7-15, 16+) | leakyRelu | ✓ | ✓ | | | Less | ai.onnx(7-8, 9-12, 13+) | lesser | ✓ | ✓ | | | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 401d2eaa09129..395a0b40e5bbb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -160,7 +160,7 @@ static const InlinedHashMap op_map = { {"ArgMax", {"argMax", true}}, {"ArgMin", {"argMin", true}}, {"AveragePool", {"averagePool2d", true}}, - {"BatchNormalization", {"batchNormalization", false}}, + {"BatchNormalization", {"batchNormalization", true}}, {"Cast", {"cast", true}}, {"Ceil", {"ceil", true}}, {"Clip", {"clamp", true}}, @@ -190,8 +190,8 @@ static const InlinedHashMap op_map = { {"HardSigmoid", {"hardSigmoid", true}}, {"HardSwish", {"hardSwish", true}}, {"Identity", {"identity", true}}, - {"InstanceNormalization", {"instanceNormalization", false}}, - {"LayerNormalization", {"layerNormalization", false}}, + {"InstanceNormalization", {"instanceNormalization", true}}, + {"LayerNormalization", {"layerNormalization", true}}, {"LeakyRelu", {"leakyRelu", true}}, {"Less", {"lesser", true}}, {"LessOrEqual", {"lesserOrEqual", true}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 90ad9b48d5866..a2aa0df5586e3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -87,11 +87,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); - if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) { - std::iota(axes.begin(), axes.end(), axis - 1); - } else { - std::iota(axes.begin(), axes.end(), axis); - } + std::iota(axes.begin(), axes.end(), axis); + options.set("axes", emscripten::val::array(axes)); output = model_builder.GetBuilder().call("layerNormalization", input, options); } else if (op_type == "InstanceNormalization") {