diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 77dc1c52df..7b8dd483c7 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -86,6 +86,7 @@ public: int getTensorShapeSize(int node_id, int node_input_id) { const auto node = getNode(node_id); const auto &input_name = node->getInputName(node_input_id); + // try to get from value_info for (int i = 0; i < net.value_info_size(); i++) { const auto value_info = net.value_info(i); if (value_info.name() == input_name) { @@ -97,6 +98,18 @@ public: } } } + // try to get from input + for (int i = 0; i < net.input_size(); i++) { + const auto input = net.input(i); + if (input.name() == input_name) { + if (input.has_type() && input.type().has_tensor_type() && + input.type().tensor_type().has_shape()) { + return input.type().tensor_type().shape().dim_size(); + } else { + return -1; + } + } + } return -1; } @@ -660,6 +673,10 @@ private: [Input] -> LayerNorm -> [Output] \ [weight], [bias] + + Note: axes of ReduceMean must be: + - last element is the axis of last dimension (-1 or (input_ndims - 1)) + - a list of adjacent axes, e.g. [1, 2, 3, ..., input_ndims - 1] */ class LayerNormSubGraph : public Subgraph { @@ -683,19 +700,22 @@ public: setFusedNode("LayerNormalization", input); } - static float extractAxis(const Ptr& net, int node_id) + static std::vector extractAxis(const Ptr& net, int node_id) { + // TODO: consider ReduceMean-18 which has axes as one of the inputs instead of attributes Ptr mean_ptr = net->getNode(node_id); opencv_onnx::NodeProto* mean_node = mean_ptr.dynamicCast()->node; - int axis_ = -1; + std::vector axes; for (int i = 0; i < mean_node->attribute_size(); i++) { opencv_onnx::AttributeProto attr = mean_node->attribute(i); if (attr.name() != "axes") continue; - axis_ = static_cast(attr.ints(0)); + for (int j = 0; j < attr.ints_size(); j++) { + axes.push_back(attr.ints(j)); + } } - return axis_; + return axes; } virtual bool match(const Ptr& net, int nodeId, @@ -707,11 +727,31 @@ public: if (pow_exp - 2 > 1e-5) // not pow(2) return false; - int axis_mean1 = extractAxis(net, matchedNodesIds[mean]); - int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]); - if (axis_mean1 != axis_mean2) + std::vector axes = extractAxis(net, matchedNodesIds[mean]); + // check whether it is -1 or last_axis or [axis, ..., last_axis] + int64_t input_ndims = static_cast(net.dynamicCast()->getTensorShapeSize(matchedNodesIds[mean], 0)); + if (input_ndims == -1) { + return false; // input shape unknown + } + // assume that axes are sorted in ascending order, e.g. [0, 1, 2, 3] or [-3, -2, -1] + if (axes.back() != -1 && axes.back() != (input_ndims - 1)) { return false; - axis = axis_mean1; + } + for (size_t i = 0; i < axes.size() - 1; i++) { + if (axes[i] - axes[i + 1] != -1) { + return false; + } + } + + std::vector axes1 = extractAxis(net, matchedNodesIds[mean1]); + if (axes.size() != axes1.size()) + return false; + for (size_t i = 0; i < axes.size(); i++) { + if (((axes[i] + input_ndims) % input_ndims) != ((axes1[i] + input_ndims) % input_ndims)) { + return false; + } + } + axis = axes[0]; epsilon = extractConstant(net, matchedNodesIds[add], 1).at(0); diff --git a/modules/dnn/test/test_graph_simplifier.cpp b/modules/dnn/test/test_graph_simplifier.cpp index e09a68c158..91b4e271f5 100644 --- a/modules/dnn/test/test_graph_simplifier.cpp +++ b/modules/dnn/test/test_graph_simplifier.cpp @@ -47,6 +47,10 @@ TEST_F(Test_Graph_Simplifier, LayerNormSubGraph) { test("layer_norm_expanded_with_initializers", "LayerNormalization"); } +TEST_F(Test_Graph_Simplifier, LayerNormNoFusionSubGraph) { + test("layer_norm_no_fusion", std::vector{"NaryEltwise", "Reduce", "Sqrt"}); +} + TEST_F(Test_Graph_Simplifier, ResizeSubgraph) { /* Test for 6 subgraphs: - GatherCastSubgraph diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 457b151ccf..4d56cb0e17 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -3024,6 +3024,10 @@ TEST_P(Test_ONNX_nets, VitTrack) { normAssert(ref_output3, outputs[2], "VitTrack output3"); } +TEST_P(Test_ONNX_layers, LayerNormNoFusion) { + testONNXModels("layer_norm_no_fusion"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets()); }} // namespace