From ac4b26a56180ccea77b1f4be9020d586ae508941 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 8 Dec 2023 23:29:52 +0300 Subject: [PATCH] Replace Slice optional inputs removal to adjustment --- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 58 ++++++++----------- modules/dnn/src/onnx/onnx_importer.cpp | 4 +- 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 7b9ec4082a..e1fa80c165 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -182,12 +182,12 @@ static Mat extractConstant(const Ptr& net, int node_id, int } /* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have - Slice with optional inputs of default values, some of them don't. This Subgraph removes - all optional inputs of Slice if values are default. + Slice with optional inputs of default values, some of them don't. This Subgraph adjusts + all optional inputs of Slice up to 5. */ -class RemoveSliceAllOptionalInputsSubgraph : public Subgraph { +class AdjustSliceAllOptionalInputsSubgraph : public Subgraph { public: - RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) { + AdjustSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) { num_inputs_ = num_inputs; int input = addNodeToMatch(""); @@ -200,35 +200,17 @@ class RemoveSliceAllOptionalInputsSubgraph : public Subgraph { slice_id = addNodeToMatch("Slice", inputs); - setFusedNode("Slice", std::vector{input, starts, ends}); + setFusedNode("Slice", inputs); } - virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds)) { - if (num_inputs_ >= 4) { // with axes - // Check if axes are -1 or last axis - auto onnx_net = net.dynamicCast(); - int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0); - - auto axes = extractConstant(net, matchedNodesIds[slice_id], 3); - for (size_t i = 0; i < axes.total(); i++) { - const int axis = *(axes.ptr() + i); - if (axis != -1 && axis != shape_size - 1) { - return false; - } - } - } - if (num_inputs_ == 5) { - // Check if steps are 1 - auto steps = extractConstant(net, matchedNodesIds[slice_id], 4); - if (countNonZero(steps != 1)) { - return false; - } - } - return true; + virtual void finalize(const Ptr&, + const Ptr& fusedNode, + std::vector >&) CV_OVERRIDE + { + opencv_onnx::NodeProto* node = fusedNode.dynamicCast()->node; + for (int i = num_inputs_; i < 5; ++i) { + node->add_input(""); } - return false; } private: @@ -1119,7 +1101,11 @@ public: ResizeSubgraph1() : ExtractScalesSubgraph() { int shape = addNodeToMatch("Shape", input); - int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant")); + int slice = addNodeToMatch("Slice", {shape, + addNodeToMatch(""), + addNodeToMatch(""), + addNodeToMatch(""), + addNodeToMatch("")}); int castConcat = addNodeToMatch("Cast", concatId); int concat = addNodeToMatch("Concat", slice, castConcat); @@ -1163,7 +1149,11 @@ public: int cast = addNodeToMatch("Cast", concat1); int shape2 = addNodeToMatch("Shape", input); - int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant")); + int slice = addNodeToMatch("Slice", {shape2, + addNodeToMatch(""), + addNodeToMatch(""), + addNodeToMatch(""), + addNodeToMatch("")}); int concat2 = addNodeToMatch("Concat", slice, cast); addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2); @@ -1235,8 +1225,8 @@ public: void simplifySubgraphs(opencv_onnx::GraphProto& net) { std::vector > subgraphs; - subgraphs.push_back(makePtr(4)); - subgraphs.push_back(makePtr(5)); + subgraphs.push_back(makePtr(3)); + subgraphs.push_back(makePtr(4)); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 3d1b4d8b82..f52a161f08 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1235,7 +1235,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP starts_ = DictValue::arrayInt(start_blob.begin(), start_blob.total()); ends_ = DictValue::arrayInt(end_blob.begin(), end_blob.total()); - if (inp_size > 3) + if (inp_size > 3 && !getBlob(node_proto, 3).empty()) { Mat axes_blob = getBlob(node_proto, 3); CV_Assert(axes_blob.total() == start_blob.total()); @@ -1244,7 +1244,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP has_axes = true; } - if (inp_size == 5) + if (inp_size == 5 && !getBlob(node_proto, 4).empty()) { Mat step_blob = getBlob(node_proto, 4); CV_Assert(step_blob.total() == start_blob.total());