From 8559237d4e0051d666d901c9e96511fddf318c69 Mon Sep 17 00:00:00 2001 From: "ashishiva3@gmail.com" Date: Thu, 13 Feb 2020 17:32:35 +0530 Subject: [PATCH] ONNX: upsample subgraph fusion added --- modules/dnn/src/graph_simplifier.cpp | 12 +- modules/dnn/src/graph_simplifier.hpp | 4 +- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 197 +++++++++++++++++- .../dnn/src/onnx/onnx_graph_simplifier.hpp | 13 ++ modules/dnn/src/onnx/onnx_importer.cpp | 77 ------- .../src/tensorflow/tf_graph_simplifier.cpp | 10 +- modules/dnn/test/test_onnx_importer.cpp | 7 + 7 files changed, 232 insertions(+), 88 deletions(-) diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp index 62651053fb..c5073d8a01 100644 --- a/modules/dnn/src/graph_simplifier.cpp +++ b/modules/dnn/src/graph_simplifier.cpp @@ -69,8 +69,12 @@ int Subgraph::getInputNodeId(const Ptr& net, const int numNodes = net->getNumNodes(); for (int i = 0; i < numNodes; ++i) { - if (net->getNodeName(i) == name) - return i; + const int numOutputs = net->getNumOutputs(i); + for (int j = 0; j < numOutputs; j++) + { + if (net->getOutputName(i, j) == name) + return i; + } } CV_Error(Error::StsParseError, "Input node with name " + name + " not found"); } @@ -111,12 +115,12 @@ bool Subgraph::match(const Ptr& net, int nodeId, continue; nodeId = getInputNodeId(net, node, j); const Ptr inpNode = net->getNode(nodeId); - if (inpNode->getType() != "Const") + if (inpNode->getType() != "Const" && inpNode->getType() != "Constant") { nodesToMatch.push(nodeId); targetNodes.push(inputNodes[j]); } - else if (nodes[inputNodes[j]] != "Const") + else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant") return false; } matchedNodesIds.push_back(nodeToMatch); diff --git a/modules/dnn/src/graph_simplifier.hpp b/modules/dnn/src/graph_simplifier.hpp index 8f3958ba52..39d6262c1b 100644 --- a/modules/dnn/src/graph_simplifier.hpp +++ b/modules/dnn/src/graph_simplifier.hpp @@ -39,7 +39,9 @@ public: virtual int getNumNodes() const = 0; - virtual std::string getNodeName(int idx) const = 0; + virtual int getNumOutputs(int nodeId) const = 0; + + virtual std::string getOutputName(int nodeId, int outId) const = 0; virtual void removeNode(int idx) = 0; }; diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 6ce119d765..41a768d23c 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -76,12 +76,21 @@ public: return numInputs + net.node_size(); } - virtual std::string getNodeName(int idx) const CV_OVERRIDE + virtual int getNumOutputs(int nodeId) const CV_OVERRIDE { - if (idx < numInputs) - return net.input(idx).name(); + if (nodeId < numInputs) + return 1; else - return net.node(idx - numInputs).output(0); + return net.node(nodeId - numInputs).output_size(); + } + + virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE + { + CV_Assert(outId < getNumOutputs(nodeId)); + if (nodeId < numInputs) + return net.input(nodeId).name(); + else + return net.node(nodeId - numInputs).output(outId); } virtual void removeNode(int idx) CV_OVERRIDE @@ -145,13 +154,193 @@ private: int axis; }; +class ExtractScalesSubgraph : public Subgraph +{ +public: + ExtractScalesSubgraph() + { + input = addNodeToMatch(""); + + int indexH = addNodeToMatch("Constant"); + int shape1 = addNodeToMatch("Shape", input); + int gather1 = addNodeToMatch("Gather", shape1, indexH); + int castG1 = addNodeToMatch("Cast", gather1); + scaleHNode = addNodeToMatch("Constant"); + int mul1 = addNodeToMatch("Mul", castG1, scaleHNode); + int castM1 = addNodeToMatch("Cast", mul1); + int floor1 = addNodeToMatch("Floor", castM1); + + int indexW = addNodeToMatch("Constant"); + int shape2 = addNodeToMatch("Shape", input); + int gather2 = addNodeToMatch("Gather", shape2, indexW); + int castG2 = addNodeToMatch("Cast", gather2); + scaleWNode = addNodeToMatch("Constant"); + int mul2 = addNodeToMatch("Mul", castG2, scaleWNode); + int castM2 = addNodeToMatch("Cast", mul2); + int floor2 = addNodeToMatch("Floor", castM2); + + int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1); + int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2); + concatId = addNodeToMatch("Concat", unsqueeze1, unsqueeze2); + } + + void finalize(const Ptr& net, + const Ptr& fusedNode, + std::vector >& inputs) CV_OVERRIDE + { + opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast()->node; + opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t(); + float scaleW = getMatFromTensor(tensor_proto).at(0); + + constant_node = inputs[2].dynamicCast()->node; + tensor_proto = constant_node->attribute(0).t(); + float scaleH = getMatFromTensor(tensor_proto).at(0); + + opencv_onnx::NodeProto* node = fusedNode.dynamicCast()->node; + opencv_onnx::AttributeProto* attrH = node->add_attribute(); + attrH->set_name("height_scale"); + attrH->set_i(scaleH); + opencv_onnx::AttributeProto* attrW = node->add_attribute(); + attrW->set_name("width_scale"); + attrW->set_i(scaleW); + + node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs + } + +protected: + int input, concatId; + int scaleHNode, scaleWNode; +}; + +class UpsampleSubgraph : public ExtractScalesSubgraph +{ +public: + UpsampleSubgraph() : ExtractScalesSubgraph() + { + int shape = addNodeToMatch("Shape", input); + int slice = addNodeToMatch("Slice", shape); + + int castConcat = addNodeToMatch("Cast", concatId); + int castSlice = addNodeToMatch("Cast", slice); + int divide = addNodeToMatch("Div", castConcat, castSlice); + + int constant = addNodeToMatch("Constant"); + int concat = addNodeToMatch("Concat", constant, divide); + + addNodeToMatch("Upsample", input, concat); + setFusedNode("Upsample", input, scaleWNode, scaleHNode); + } +}; + +class ResizeSubgraph1 : public ExtractScalesSubgraph +{ +public: + ResizeSubgraph1() : ExtractScalesSubgraph() + { + int shape = addNodeToMatch("Shape", input); + int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant")); + + int castConcat = addNodeToMatch("Cast", concatId); + int concat = addNodeToMatch("Concat", slice, castConcat); + int constant = addNodeToMatch("Constant"); + + addNodeToMatch("Resize", input, constant, constant, concat); + setFusedNode("Upsample", input, scaleWNode, scaleHNode); + } +}; + +class ResizeSubgraph2 : public ExtractScalesSubgraph +{ +public: + ResizeSubgraph2() : ExtractScalesSubgraph() + { + int constantConcat = addNodeToMatch("Constant"); + int castConcat = addNodeToMatch("Cast", concatId); + int concat = addNodeToMatch("Concat", constantConcat, castConcat); + int constant = addNodeToMatch("Constant"); + + addNodeToMatch("Resize", input, constant, constant, concat); + setFusedNode("Upsample", input, scaleWNode, scaleHNode); + } +}; + void simplifySubgraphs(opencv_onnx::GraphProto& net) { std::vector > subgraphs; + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } +Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto) +{ + if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() && + tensor_proto.double_data().empty() && tensor_proto.int64_data().empty()) + return Mat(); + + opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type(); + Mat blob; + std::vector sizes; + for (int i = 0; i < tensor_proto.dims_size(); i++) { + sizes.push_back(tensor_proto.dims(i)); + } + if (sizes.empty()) + sizes.assign(1, 1); + if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) { + + if (!tensor_proto.float_data().empty()) { + const ::google::protobuf::RepeatedField field = tensor_proto.float_data(); + Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob); + } + else { + char* val = const_cast(tensor_proto.raw_data().c_str()); + Mat(sizes, CV_32FC1, val).copyTo(blob); + } + } + else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE) + { + const ::google::protobuf::RepeatedField field = tensor_proto.double_data(); + CV_Assert(!field.empty()); + Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1); + } + else if (datatype == opencv_onnx::TensorProto_DataType_INT64) + { + blob.create(sizes, CV_32SC1); + int32_t* dst = reinterpret_cast(blob.data); + + if (!tensor_proto.int64_data().empty()) { + ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data(); + convertInt64ToInt32(src, dst, blob.total()); + } + else + { + const char* val = tensor_proto.raw_data().c_str(); +#if CV_STRONG_ALIGNMENT + // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373 + // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t; + AutoBuffer aligned_val; + if (!isAligned(val)) + { + size_t sz = tensor_proto.raw_data().size(); + aligned_val.allocate(divUp(sz, sizeof(int64_t))); + memcpy(aligned_val.data(), val, sz); + val = (const char*)aligned_val.data(); + } +#endif + const int64_t* src = reinterpret_cast(val); + convertInt64ToInt32(src, dst, blob.total()); + } + } + else + CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " + + opencv_onnx::TensorProto_DataType_Name(datatype)); + if (tensor_proto.dims_size() == 0) + blob.dims = 1; // To force 1-dimensional cv::Mat for scalars. + return blob; +} + CV__DNN_EXPERIMENTAL_NS_END }} // namespace cv::dnn diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.hpp b/modules/dnn/src/onnx/onnx_graph_simplifier.hpp index 52b4e5ecc0..34924164e5 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.hpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.hpp @@ -24,6 +24,19 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN void simplifySubgraphs(opencv_onnx::GraphProto& net); +template +void convertInt64ToInt32(const T1& src, T2& dst, int size) +{ + for (int i = 0; i < size; i++) { + if (src[i] < std::numeric_limits::min() || src[i] > std::numeric_limits::max()) { + CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range"); + } + dst[i] = saturate_cast(src[i]); + } +} + +Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto); + CV__DNN_EXPERIMENTAL_NS_END }} // namespace dnn, namespace cv diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index cd3693b6e8..6f3ac0409d 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -95,83 +95,6 @@ void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto) } } -template -void convertInt64ToInt32(const T1& src, T2& dst, int size) -{ - for (int i = 0; i < size; i++) { - if (src[i] < std::numeric_limits::min() || src[i] > std::numeric_limits::max()) { - CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range"); - } - dst[i] = saturate_cast(src[i]); - } -} - -Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto) -{ - CV_Assert(!tensor_proto.raw_data().empty() || !tensor_proto.float_data().empty() - || !tensor_proto.double_data().empty() || !tensor_proto.int64_data().empty()); - - opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type(); - Mat blob; - std::vector sizes; - for (int i = 0; i < tensor_proto.dims_size(); i++) { - sizes.push_back(tensor_proto.dims(i)); - } - if (sizes.empty()) - sizes.assign(1, 1); - if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) { - - if (!tensor_proto.float_data().empty()) { - const ::google::protobuf::RepeatedField field = tensor_proto.float_data(); - Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob); - } - else { - char* val = const_cast(tensor_proto.raw_data().c_str()); - Mat(sizes, CV_32FC1, val).copyTo(blob); - } - } - else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE) - { - const ::google::protobuf::RepeatedField field = tensor_proto.double_data(); - CV_Assert(!field.empty()); - Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1); - } - else if (datatype == opencv_onnx::TensorProto_DataType_INT64) - { - blob.create(sizes, CV_32SC1); - int32_t* dst = reinterpret_cast(blob.data); - - if (!tensor_proto.int64_data().empty()) { - ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data(); - convertInt64ToInt32(src, dst, blob.total()); - } - else - { - const char* val = tensor_proto.raw_data().c_str(); -#if CV_STRONG_ALIGNMENT - // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373 - // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t; - AutoBuffer aligned_val; - if (!isAligned(val)) - { - size_t sz = tensor_proto.raw_data().size(); - aligned_val.allocate(divUp(sz, sizeof(int64_t))); - memcpy(aligned_val.data(), val, sz); - val = (const char*)aligned_val.data(); - } -#endif - const int64_t* src = reinterpret_cast(val); - convertInt64ToInt32(src, dst, blob.total()); - } - } - else - CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " + - opencv_onnx::TensorProto_DataType_Name(datatype)); - if (tensor_proto.dims_size() == 0) - blob.dims = 1; // To force 1-dimensional cv::Mat for scalars. - return blob; -} - void runLayer(LayerParams& params, const std::vector& inputs, std::vector& outputs) { diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 565a64fc9d..2e7bb574e9 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -69,9 +69,15 @@ public: return net.node_size(); } - virtual std::string getNodeName(int idx) const CV_OVERRIDE + virtual int getNumOutputs(int nodeId) const CV_OVERRIDE { - return net.node(idx).name(); + return 1; + } + + virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE + { + CV_Assert(outId == 0); + return net.node(nodeId).name(); } virtual void removeNode(int idx) CV_OVERRIDE diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 35f00ef503..3dc6c7685a 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -316,6 +316,13 @@ TEST_P(Test_ONNX_layers, Resize) testONNXModels("resize_bilinear"); } +TEST_P(Test_ONNX_layers, ResizeUnfused) +{ + testONNXModels("upsample_unfused_opset9_torch1.4"); + testONNXModels("resize_nearest_unfused_opset11_torch1.4"); + testONNXModels("resize_nearest_unfused_opset11_torch1.3"); +} + TEST_P(Test_ONNX_layers, MultyInputs) { const String model = _tf("models/multy_inputs.onnx");