Merge pull request #24808 from fengyuentau:fix_layernorm

dnn: no layer norm fusion if axes.back() is not the axis of last dimension #24808

Merge with https://github.com/opencv/opencv_extra/pull/1137

Resolves https://github.com/opencv/opencv/issues/24797

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake
pull/24843/head
Yuantao Feng 10 months ago committed by GitHub
parent 75dc334d39
commit 7fb336322d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 56
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 4
      modules/dnn/test/test_graph_simplifier.cpp
  3. 4
      modules/dnn/test/test_onnx_importer.cpp

@ -86,6 +86,7 @@ public:
int getTensorShapeSize(int node_id, int node_input_id) { int getTensorShapeSize(int node_id, int node_input_id) {
const auto node = getNode(node_id); const auto node = getNode(node_id);
const auto &input_name = node->getInputName(node_input_id); const auto &input_name = node->getInputName(node_input_id);
// try to get from value_info
for (int i = 0; i < net.value_info_size(); i++) { for (int i = 0; i < net.value_info_size(); i++) {
const auto value_info = net.value_info(i); const auto value_info = net.value_info(i);
if (value_info.name() == input_name) { if (value_info.name() == input_name) {
@ -97,6 +98,18 @@ public:
} }
} }
} }
// try to get from input
for (int i = 0; i < net.input_size(); i++) {
const auto input = net.input(i);
if (input.name() == input_name) {
if (input.has_type() && input.type().has_tensor_type() &&
input.type().tensor_type().has_shape()) {
return input.type().tensor_type().shape().dim_size();
} else {
return -1;
}
}
}
return -1; return -1;
} }
@ -660,6 +673,10 @@ private:
[Input] -> LayerNorm -> [Output] [Input] -> LayerNorm -> [Output]
\ \
[weight], [bias] [weight], [bias]
Note: axes of ReduceMean must be:
- last element is the axis of last dimension (-1 or (input_ndims - 1))
- a list of adjacent axes, e.g. [1, 2, 3, ..., input_ndims - 1]
*/ */
class LayerNormSubGraph : public Subgraph class LayerNormSubGraph : public Subgraph
{ {
@ -683,19 +700,22 @@ public:
setFusedNode("LayerNormalization", input); setFusedNode("LayerNormalization", input);
} }
static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id) static std::vector<int64_t> extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
{ {
// TODO: consider ReduceMean-18 which has axes as one of the inputs instead of attributes
Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id); Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
opencv_onnx::NodeProto* mean_node = mean_ptr.dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* mean_node = mean_ptr.dynamicCast<ONNXNodeWrapper>()->node;
int axis_ = -1; std::vector<int64_t> axes;
for (int i = 0; i < mean_node->attribute_size(); i++) for (int i = 0; i < mean_node->attribute_size(); i++)
{ {
opencv_onnx::AttributeProto attr = mean_node->attribute(i); opencv_onnx::AttributeProto attr = mean_node->attribute(i);
if (attr.name() != "axes") if (attr.name() != "axes")
continue; continue;
axis_ = static_cast<int>(attr.ints(0)); for (int j = 0; j < attr.ints_size(); j++) {
axes.push_back(attr.ints(j));
}
} }
return axis_; return axes;
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
@ -707,11 +727,31 @@ public:
if (pow_exp - 2 > 1e-5) // not pow(2) if (pow_exp - 2 > 1e-5) // not pow(2)
return false; return false;
int axis_mean1 = extractAxis(net, matchedNodesIds[mean]); std::vector<int64_t> axes = extractAxis(net, matchedNodesIds[mean]);
int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]); // check whether it is -1 or last_axis or [axis, ..., last_axis]
if (axis_mean1 != axis_mean2) int64_t input_ndims = static_cast<int64_t>(net.dynamicCast<ONNXGraphWrapper>()->getTensorShapeSize(matchedNodesIds[mean], 0));
if (input_ndims == -1) {
return false; // input shape unknown
}
// assume that axes are sorted in ascending order, e.g. [0, 1, 2, 3] or [-3, -2, -1]
if (axes.back() != -1 && axes.back() != (input_ndims - 1)) {
return false; return false;
axis = axis_mean1; }
for (size_t i = 0; i < axes.size() - 1; i++) {
if (axes[i] - axes[i + 1] != -1) {
return false;
}
}
std::vector<int64_t> axes1 = extractAxis(net, matchedNodesIds[mean1]);
if (axes.size() != axes1.size())
return false;
for (size_t i = 0; i < axes.size(); i++) {
if (((axes[i] + input_ndims) % input_ndims) != ((axes1[i] + input_ndims) % input_ndims)) {
return false;
}
}
axis = axes[0];
epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0); epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);

@ -47,6 +47,10 @@ TEST_F(Test_Graph_Simplifier, LayerNormSubGraph) {
test("layer_norm_expanded_with_initializers", "LayerNormalization"); test("layer_norm_expanded_with_initializers", "LayerNormalization");
} }
TEST_F(Test_Graph_Simplifier, LayerNormNoFusionSubGraph) {
test("layer_norm_no_fusion", std::vector<std::string>{"NaryEltwise", "Reduce", "Sqrt"});
}
TEST_F(Test_Graph_Simplifier, ResizeSubgraph) { TEST_F(Test_Graph_Simplifier, ResizeSubgraph) {
/* Test for 6 subgraphs: /* Test for 6 subgraphs:
- GatherCastSubgraph - GatherCastSubgraph

@ -3024,6 +3024,10 @@ TEST_P(Test_ONNX_nets, VitTrack) {
normAssert(ref_output3, outputs[2], "VitTrack output3"); normAssert(ref_output3, outputs[2], "VitTrack output3");
} }
TEST_P(Test_ONNX_layers, LayerNormNoFusion) {
testONNXModels("layer_norm_no_fusion");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
}} // namespace }} // namespace

Loading…
Cancel
Save