diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp new file mode 100644 index 0000000000..62651053fb --- /dev/null +++ b/modules/dnn/src/graph_simplifier.cpp @@ -0,0 +1,207 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2020, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#include "precomp.hpp" + +#include "graph_simplifier.hpp" + +#include + +namespace cv { namespace dnn { + +Subgraph::~Subgraph() {} + +int Subgraph::addNodeToMatch(const std::string& op, int input_0, int input_1, + int input_2, int input_3) +{ + int nodeInputs[] = {input_0, input_1, input_2, input_3}; + int numInputs = 0; + for (int i = 0; i < 4; ++i) + { + numInputs += (int)(nodeInputs[i] != -1); + } + return addNodeToMatch(op, std::vector(&nodeInputs[0], &nodeInputs[0] + numInputs)); +} + +int Subgraph::addNodeToMatch(const std::string& op, const std::vector& inputs_) +{ + for (int i = 0; i < inputs_.size(); ++i) + { + CV_Assert(inputs_[i] < (int)nodes.size()); + } + nodes.push_back(op); + inputs.push_back(inputs_); + return nodes.size() - 1; +} + +void Subgraph::setFusedNode(const std::string& op, int input_0, int input_1, + int input_2, int input_3, int input_4, int input_5) +{ + int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5}; + int numInputs = 0; + for (int i = 0; i < 6; ++i) + { + CV_Assert(nodeInputs[i] < (int)nodes.size()); + numInputs += (int)(nodeInputs[i] != -1); + } + setFusedNode(op, std::vector(&nodeInputs[0], &nodeInputs[0] + numInputs)); +} + +void Subgraph::setFusedNode(const std::string& op, const std::vector& inputs_) +{ + fusedNodeInputs = inputs_; + fusedNodeOp = op; +} + +int Subgraph::getInputNodeId(const Ptr& net, + const Ptr& node, + int inpId) +{ + CV_Assert(inpId < node->getNumInputs()); + std::string name = node->getInputName(inpId); + // If operation produces several tensors, they are specified by index + // after ':' character. In example, "input:0". + name = name.substr(0, name.rfind(':')); + const int numNodes = net->getNumNodes(); + for (int i = 0; i < numNodes; ++i) + { + if (net->getNodeName(i) == name) + return i; + } + CV_Error(Error::StsParseError, "Input node with name " + name + " not found"); +} + +bool Subgraph::match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds, + std::vector& targetNodesIds) +{ + matchedNodesIds.clear(); + targetNodesIds.clear(); + + std::queue nodesToMatch; + std::queue targetNodes; + nodesToMatch.push(nodeId); + targetNodes.push(nodes.size() - 1); + while (!nodesToMatch.empty()) + { + int nodeToMatch = nodesToMatch.front(); + int targetNodeId = targetNodes.front(); + nodesToMatch.pop(); + targetNodes.pop(); + + if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) != + matchedNodesIds.end()) + continue; + + const Ptr node = net->getNode(nodeToMatch); + if (node->getType() != nodes[targetNodeId]) + return false; + + std::vector& inputNodes = inputs[targetNodeId]; + if (inputNodes.size() != node->getNumInputs()) + return false; + + for (int j = 0; j < inputNodes.size(); ++j) + { + if (nodes[inputNodes[j]].empty()) // Unknown input node type. + continue; + nodeId = getInputNodeId(net, node, j); + const Ptr inpNode = net->getNode(nodeId); + if (inpNode->getType() != "Const") + { + nodesToMatch.push(nodeId); + targetNodes.push(inputNodes[j]); + } + else if (nodes[inputNodes[j]] != "Const") + return false; + } + matchedNodesIds.push_back(nodeToMatch); + targetNodesIds.push_back(targetNodeId); + } + + const int n = matchedNodesIds.size(); + std::vector > elements(n); + for (int i = 0; i < n; ++i) + elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]); + std::sort(elements.begin(), elements.end()); + for (int i = 0; i < n; ++i) + { + matchedNodesIds[i] = elements[i].first; + targetNodesIds[i] = elements[i].second; + } + return true; +} + +void Subgraph::replace(const Ptr& net, const std::vector& matchedNodesIds, + const std::vector& targetNodesIds) +{ + // Extract names of input nodes. + std::vector inputsNames(fusedNodeInputs.size()); + for (int i = 0; i < fusedNodeInputs.size(); ++i) + { + std::string inpName; + // Find input node name looking at inputs of fused nodes. + for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j) + { + Ptr node = net->getNode(matchedNodesIds[j]); + std::vector& inpIndices = inputs[targetNodesIds[j]]; + + CV_Assert(node->getNumInputs() == inpIndices.size()); + for (int k = 0; k < inpIndices.size(); ++k) + { + if (inpIndices[k] == fusedNodeInputs[i]) + { + inpName = node->getInputName(k); + break; + } + } + } + CV_Assert(!inpName.empty()); + inputsNames[i] = inpName; + } + + // Remove matched nodes except the last one. Indices in ascending order are expected. + Ptr node = net->getNode(matchedNodesIds.back()); + for (int i = matchedNodesIds.size() - 2; i >= 0; --i) + net->removeNode(matchedNodesIds[i]); + + // Modify the last node to be a fused one. + node->setType(fusedNodeOp); + node->setInputNames(inputsNames); + + std::vector > inputNodes(inputsNames.size()); + for (int i = 0; i < inputsNames.size(); ++i) + { + inputNodes[i] = net->getNode(getInputNodeId(net, node, i)); + } + finalize(net, node, inputNodes); +} + +void Subgraph::finalize(const Ptr& net, + const Ptr& fusedNode, + std::vector >& inputs) {} + +void simplifySubgraphs(const Ptr& net, + const std::vector >& patterns) +{ + int numNodes = net->getNumNodes(); + std::vector matchedNodesIds, targetNodesIds; + for (int i = 0; i < numNodes; ++i) + { + for (int j = 0; j < patterns.size(); ++j) + { + if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) + { + patterns[j]->replace(net, matchedNodesIds, targetNodesIds); + numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added. + break; + } + } + } +} + +}} // namespace cv::dnn diff --git a/modules/dnn/src/graph_simplifier.hpp b/modules/dnn/src/graph_simplifier.hpp new file mode 100644 index 0000000000..8f3958ba52 --- /dev/null +++ b/modules/dnn/src/graph_simplifier.hpp @@ -0,0 +1,100 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2020, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ +#define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ + +#include + +#include + +namespace cv { namespace dnn { + +class ImportNodeWrapper +{ +public: + virtual ~ImportNodeWrapper() {}; + + virtual int getNumInputs() const = 0; + + virtual std::string getInputName(int idx) const = 0; + + virtual std::string getType() const = 0; + + virtual void setType(const std::string& type) = 0; + + virtual void setInputNames(const std::vector& inputs) = 0; +}; + +class ImportGraphWrapper +{ +public: + virtual ~ImportGraphWrapper() {}; + + virtual Ptr getNode(int idx) const = 0; + + virtual int getNumNodes() const = 0; + + virtual std::string getNodeName(int idx) const = 0; + + virtual void removeNode(int idx) = 0; +}; + +class Subgraph // Interface to match and replace subgraphs. +{ +public: + virtual ~Subgraph(); + + // Add a node to be matched in the origin graph. Specify ids of nodes that + // are expected to be inputs. Returns id of a newly added node. + // TODO: Replace inputs to std::vector in C++11 + int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1, + int input_2 = -1, int input_3 = -1); + + int addNodeToMatch(const std::string& op, const std::vector& inputs_); + + // Specify resulting node. All the matched nodes in subgraph excluding + // input nodes will be fused into this single node. + // TODO: Replace inputs to std::vector in C++11 + void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1, + int input_2 = -1, int input_3 = -1, int input_4 = -1, + int input_5 = -1); + + void setFusedNode(const std::string& op, const std::vector& inputs_); + + static int getInputNodeId(const Ptr& net, + const Ptr& node, + int inpId); + + // Match TensorFlow subgraph starting from with a set of nodes to be fused. + // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused. + virtual bool match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds, + std::vector& targetNodesIds); + + // Fuse matched subgraph. + void replace(const Ptr& net, const std::vector& matchedNodesIds, + const std::vector& targetNodesIds); + + virtual void finalize(const Ptr& net, + const Ptr& fusedNode, + std::vector >& inputs); + +private: + std::vector nodes; // Nodes to be matched in the origin graph. + std::vector > inputs; // Connections of an every node to it's inputs. + + std::string fusedNodeOp; // Operation name of resulting fused node. + std::vector fusedNodeInputs; // Inputs of fused node. +}; + +void simplifySubgraphs(const Ptr& net, + const std::vector >& patterns); + +}} // namespace dnn, namespace cv + +#endif // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp new file mode 100644 index 0000000000..f9f9194a22 --- /dev/null +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -0,0 +1,157 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2020, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#include "../precomp.hpp" + +#include "../graph_simplifier.hpp" +#include "onnx_graph_simplifier.hpp" + +#include + +namespace cv { namespace dnn { +CV__DNN_EXPERIMENTAL_NS_BEGIN + +// This wrapper can behave differently for fake input nodes and real graph nodes. +class ONNXNodeWrapper : public ImportNodeWrapper +{ +public: + ONNXNodeWrapper(opencv_onnx::NodeProto* _node = 0) : node(_node) {} + + virtual int getNumInputs() const CV_OVERRIDE + { + return node ? node->input_size() : 0; + } + + virtual std::string getInputName(int idx) const CV_OVERRIDE + { + CV_Assert_N(node, idx < node->input_size()); + return node->input(idx); + } + + virtual std::string getType() const CV_OVERRIDE + { + return node ? node->op_type() : ""; + } + + virtual void setType(const std::string& type) CV_OVERRIDE + { + CV_Assert(node); + node->set_op_type(type); + } + + virtual void setInputNames(const std::vector& inputs) CV_OVERRIDE + { + CV_Assert(node); + node->clear_input(); + for (int i = 0; i < inputs.size(); ++i) + node->add_input(inputs[i]); + } + + opencv_onnx::NodeProto* node; +}; + +// ONNX graph's inputs are separate from nodes so we index them before the rest of nodes. +class ONNXGraphWrapper : public ImportGraphWrapper +{ +public: + ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net) + { + numInputs = net.input_size(); + } + + virtual Ptr getNode(int idx) const CV_OVERRIDE + { + opencv_onnx::NodeProto* node = 0; + if (idx >= numInputs) + node = net.mutable_node(idx - numInputs); + return makePtr(node); + } + + virtual int getNumNodes() const CV_OVERRIDE + { + return numInputs + net.node_size(); + } + + virtual std::string getNodeName(int idx) const CV_OVERRIDE + { + if (idx < numInputs) + return net.input(idx).name(); + else + return net.node(idx - numInputs).output(0); + } + + virtual void removeNode(int idx) CV_OVERRIDE + { + CV_Assert(idx >= numInputs); + net.mutable_node()->DeleteSubrange(idx - numInputs, 1); + } + +private: + int numInputs; + opencv_onnx::GraphProto& net; +}; + +class SoftMaxSubgraph : public Subgraph +{ +public: + SoftMaxSubgraph() + { + int input = addNodeToMatch(""); + int inpExp = addNodeToMatch("Exp", input); + int sum = addNodeToMatch("ReduceSum", inpExp); + addNodeToMatch("Div", inpExp, sum); + setFusedNode("Softmax", input); + } + + virtual bool match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds, + std::vector& targetNodesIds) CV_OVERRIDE + { + if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + { + Ptr sum = net->getNode(matchedNodesIds[1]); + opencv_onnx::NodeProto* node = sum.dynamicCast()->node; + + for (int i = 0; i < node->attribute_size(); i++) + { + opencv_onnx::AttributeProto attr = node->attribute(i); + if (attr.name() != "axes") + continue; + if (attr.ints_size() != 1) + CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size())); + axis = attr.ints(0); + return true; + } + CV_Error(Error::StsNotImplemented, "Missed axes attribute"); + } + return false; + } + + virtual void finalize(const Ptr&, + const Ptr& fusedNode, + std::vector >&) CV_OVERRIDE + { + opencv_onnx::NodeProto* node = fusedNode.dynamicCast()->node; + opencv_onnx::AttributeProto* attr = node->add_attribute(); + attr->set_name("axis"); + attr->set_i(axis); + } + +private: + int axis; +}; + +void simplifySubgraphs(opencv_onnx::GraphProto& net) +{ + std::vector > subgraphs; + subgraphs.push_back(makePtr()); + + simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); +} + +CV__DNN_EXPERIMENTAL_NS_END +}} // namespace cv::dnn diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.hpp b/modules/dnn/src/onnx/onnx_graph_simplifier.hpp new file mode 100644 index 0000000000..52b4e5ecc0 --- /dev/null +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.hpp @@ -0,0 +1,30 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2020, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#ifndef __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__ +#define __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__ + +#include "../precomp.hpp" + +#if defined(__GNUC__) && __GNUC__ >= 5 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsuggest-override" +#endif +#include "opencv-onnx.pb.h" +#if defined(__GNUC__) && __GNUC__ >= 5 +#pragma GCC diagnostic pop +#endif + +namespace cv { namespace dnn { +CV__DNN_EXPERIMENTAL_NS_BEGIN + +void simplifySubgraphs(opencv_onnx::GraphProto& net); + +CV__DNN_EXPERIMENTAL_NS_END +}} // namespace dnn, namespace cv + +#endif // __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__ diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index f08b6bd740..36945f5317 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -26,6 +26,8 @@ #pragma GCC diagnostic pop #endif +#include "onnx_graph_simplifier.hpp" + namespace cv { namespace dnn { CV__DNN_EXPERIMENTAL_NS_BEGIN @@ -326,6 +328,9 @@ void ONNXImporter::populateNet(Net dstNet) { CV_Assert(model_proto.has_graph()); opencv_onnx::GraphProto graph_proto = model_proto.graph(); + + simplifySubgraphs(graph_proto); + std::map constBlobs = getGraphTensors(graph_proto); // List of internal blobs shapes. std::map outShapes; diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 0f9670e8a4..0d53be5a58 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -9,6 +9,7 @@ #ifdef HAVE_PROTOBUF +#include "../graph_simplifier.hpp" #include "tf_graph_simplifier.hpp" #include @@ -18,203 +19,87 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN using ::google::protobuf::RepeatedField; using ::google::protobuf::MapPair; -class Subgraph // Interface to match and replace TensorFlow subgraphs. +class TFNodeWrapper : public ImportNodeWrapper { public: - virtual ~Subgraph() {} + TFNodeWrapper(tensorflow::NodeDef* _node) : node(_node) {} - // Add a node to be matched in the origin graph. Specify ids of nodes that - // are expected to be inputs. Returns id of a newly added node. - // TODO: Replace inputs to std::vector in C++11 - int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1, - int input_2 = -1, int input_3 = -1) + virtual int getNumInputs() const CV_OVERRIDE { - int nodeInputs[] = {input_0, input_1, input_2, input_3}; - int numInputs = 0; - for (int i = 0; i < 4; ++i) - { - numInputs += (int)(nodeInputs[i] != -1); - } - return addNodeToMatch(op, std::vector(&nodeInputs[0], &nodeInputs[0] + numInputs)); + return node->input_size(); } - int addNodeToMatch(const std::string& op, const std::vector& inputs_) + virtual std::string getInputName(int idx) const CV_OVERRIDE { - for (int i = 0; i < inputs_.size(); ++i) - { - CV_Assert(inputs_[i] < (int)nodes.size()); - } - nodes.push_back(op); - inputs.push_back(inputs_); - return nodes.size() - 1; + return node->input(idx); } - // Specify resulting node. All the matched nodes in subgraph excluding - // input nodes will be fused into this single node. - // TODO: Replace inputs to std::vector in C++11 - void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1, - int input_2 = -1, int input_3 = -1, int input_4 = -1, - int input_5 = -1) + virtual std::string getType() const CV_OVERRIDE { - int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5}; - int numInputs = 0; - for (int i = 0; i < 6; ++i) - { - CV_Assert(nodeInputs[i] < (int)nodes.size()); - numInputs += (int)(nodeInputs[i] != -1); - } - setFusedNode(op, std::vector(&nodeInputs[0], &nodeInputs[0] + numInputs)); + return node->op(); } - void setFusedNode(const std::string& op, const std::vector& inputs_) + virtual void setType(const std::string& type) CV_OVERRIDE { - fusedNodeInputs = inputs_; - fusedNodeOp = op; + node->set_op(type); } - static int getInputNodeId(const tensorflow::GraphDef& net, - const tensorflow::NodeDef& node, - int inpId) + virtual void setInputNames(const std::vector& inputs) CV_OVERRIDE { - CV_Assert(inpId < node.input_size()); - std::string name = node.input(inpId); - // If operation produces several tensors, they are specified by index - // after ':' character. In example, "input:0". - name = name.substr(0, name.rfind(':')); - const int numNodes = net.node_size(); - for (int i = 0; i < numNodes; ++i) - { - if (net.node(i).name() == name) - return i; - } - CV_Error(Error::StsParseError, "Input node with name " + name + " not found"); + node->clear_input(); + for (int i = 0; i < inputs.size(); ++i) + node->add_input(inputs[i]); } - // Match TensorFlow subgraph starting from with a set of nodes to be fused. - // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused. - virtual bool match(const tensorflow::GraphDef& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) - { - matchedNodesIds.clear(); - targetNodesIds.clear(); - - std::queue nodesToMatch; - std::queue targetNodes; - nodesToMatch.push(nodeId); - targetNodes.push(nodes.size() - 1); - while (!nodesToMatch.empty()) - { - int nodeToMatch = nodesToMatch.front(); - int targetNodeId = targetNodes.front(); - nodesToMatch.pop(); - targetNodes.pop(); - - if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) != - matchedNodesIds.end()) - continue; - - const tensorflow::NodeDef& node = net.node(nodeToMatch); - if (node.op() != nodes[targetNodeId]) - return false; - - std::vector& inputNodes = inputs[targetNodeId]; - if (inputNodes.size() != node.input_size()) - return false; + tensorflow::NodeDef* node; +}; - for (int j = 0; j < inputNodes.size(); ++j) - { - if (nodes[inputNodes[j]].empty()) // Unknown input node type. - continue; - nodeId = getInputNodeId(net, node, j); - const tensorflow::NodeDef& inpNode = net.node(nodeId); - if (inpNode.op() != "Const") - { - nodesToMatch.push(nodeId); - targetNodes.push(inputNodes[j]); - } - else if (nodes[inputNodes[j]] != "Const") - return false; - } - matchedNodesIds.push_back(nodeToMatch); - targetNodesIds.push_back(targetNodeId); - } +class TFGraphWrapper : public ImportGraphWrapper +{ +public: + TFGraphWrapper(tensorflow::GraphDef& _net) : net(_net) {} - const int n = matchedNodesIds.size(); - std::vector > elements(n); - for (int i = 0; i < n; ++i) - elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]); - std::sort(elements.begin(), elements.end()); - for (int i = 0; i < n; ++i) - { - matchedNodesIds[i] = elements[i].first; - targetNodesIds[i] = elements[i].second; - } - return true; + virtual Ptr getNode(int idx) const CV_OVERRIDE + { + return makePtr(net.mutable_node(idx)); } - // Fuse matched subgraph. - void replace(tensorflow::GraphDef& net, const std::vector& matchedNodesIds, - const std::vector& targetNodesIds) + virtual int getNumNodes() const CV_OVERRIDE { - // Extract names of input nodes. - std::vector inputsNames(fusedNodeInputs.size()); - for (int i = 0; i < fusedNodeInputs.size(); ++i) - { - std::string inpName; - // Find input node name looking at inputs of fused nodes. - for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j) - { - const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]); - std::vector& inpIndices = inputs[targetNodesIds[j]]; - - CV_Assert(node.input_size() == inpIndices.size()); - for (int k = 0; k < inpIndices.size(); ++k) - { - if (inpIndices[k] == fusedNodeInputs[i]) - { - inpName = node.input(k); - break; - } - } - } - CV_Assert(!inpName.empty()); - inputsNames[i] = inpName; - } - - // Remove matched nodes except the last one. Indices in ascending order are expected. - tensorflow::NodeDef* node = net.mutable_node(matchedNodesIds.back()); - for (int i = matchedNodesIds.size() - 2; i >= 0; --i) - net.mutable_node()->DeleteSubrange(matchedNodesIds[i], 1); + return net.node_size(); + } - // Modify the last node to be a fused one. - node->set_op(fusedNodeOp); - node->clear_input(); - for (int i = 0; i < inputsNames.size(); ++i) - { - node->add_input(inputsNames[i]); - } + virtual std::string getNodeName(int idx) const CV_OVERRIDE + { + return net.node(idx).name(); + } - std::vector inputNodes(inputsNames.size()); - for (int i = 0; i < inputsNames.size(); ++i) - { - inputNodes[i] = net.mutable_node(getInputNodeId(net, *node, i)); - } - finalize(net, node, inputNodes); + virtual void removeNode(int idx) CV_OVERRIDE + { + net.mutable_node()->DeleteSubrange(idx, 1); } - virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*, - std::vector&) {} + tensorflow::GraphDef& net; +}; -private: - std::vector nodes; // Nodes to be matched in the origin graph. - std::vector > inputs; // Connections of an every node to it's inputs. +class TFSubgraph : public Subgraph +{ + virtual void finalize(const Ptr& netWrapper, + const Ptr& fusedNodeWrapper, + std::vector >& inputs) CV_OVERRIDE + { + std::vector inputNodes(inputs.size()); + for (int i = 0; i < inputs.size(); ++i) + inputNodes[i] = inputs[i].dynamicCast()->node; + finalize(netWrapper.dynamicCast()->net, + fusedNodeWrapper.dynamicCast()->node, inputNodes); + } - std::string fusedNodeOp; // Operation name of resulting fused node. - std::vector fusedNodeInputs; // Inputs of fused node. + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, + std::vector& inputNodes) {} }; -class BatchNormSubgraph : public Subgraph +class BatchNormSubgraph : public TFSubgraph { public: BatchNormSubgraph() @@ -250,7 +135,7 @@ public: } }; -class BatchNormNoGammaSubgraph : public Subgraph +class BatchNormNoGammaSubgraph : public TFSubgraph { public: BatchNormNoGammaSubgraph() @@ -366,20 +251,21 @@ public: setFusedNode("Relu6", input); } - virtual bool match(const tensorflow::GraphDef& net, int nodeId, + virtual bool match(const Ptr& net, int nodeId, std::vector& matchedNodesIds, std::vector& targetNodesIds) CV_OVERRIDE { if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) return false; - Mat maxValue = getTensorContent(net.node(matchedNodesIds.front() + 1).attr().at("value").tensor()); + tensorflow::NodeDef* node = net->getNode(matchedNodesIds.front() + 1).dynamicCast()->node; + Mat maxValue = getTensorContent(node->attr().at("value").tensor()); return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at(0) == 6; } }; // Keras' reshape stores output shape in separate Const nodes by one value. // Need to merge them into a single Const node. -class ReshapeKerasSubgraph : public Subgraph +class ReshapeKerasSubgraph : public TFSubgraph { public: ReshapeKerasSubgraph(int _numOutDims) : numOutDims(_numOutDims) @@ -402,15 +288,15 @@ public: setFusedNode("Reshape", ids); } - virtual bool match(const tensorflow::GraphDef& net, int nodeId, + virtual bool match(const Ptr& net, int nodeId, std::vector& matchedNodesIds, std::vector& targetNodesIds) CV_OVERRIDE { - const tensorflow::NodeDef& node = net.node(nodeId); - if (node.input_size() == 0) + Ptr node = net->getNode(nodeId); + if (node->getNumInputs() == 0) return false; - inpName = node.input(0); + inpName = node->getInputName(0); return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); } @@ -457,7 +343,7 @@ public: } }; -class DeconvolutionValidKerasSubgraph : public Subgraph +class DeconvolutionValidKerasSubgraph : public TFSubgraph { public: DeconvolutionValidKerasSubgraph() @@ -518,7 +404,7 @@ public: } }; -class DeconvolutionSameKerasSubgraph : public Subgraph +class DeconvolutionSameKerasSubgraph : public TFSubgraph { public: DeconvolutionSameKerasSubgraph() @@ -608,7 +494,7 @@ public: }; // In case of resizing by factor. -class UpsamplingKerasSubgraph : public Subgraph +class UpsamplingKerasSubgraph : public TFSubgraph { public: UpsamplingKerasSubgraph(const std::string& type) @@ -703,7 +589,7 @@ public: } }; -class KerasMVNSubgraph : public Subgraph +class KerasMVNSubgraph : public TFSubgraph { public: KerasMVNSubgraph() @@ -758,20 +644,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new ReshapeAsShapeSubgraph())); subgraphs.push_back(Ptr(new KerasMVNSubgraph())); - int numNodes = net.node_size(); - std::vector matchedNodesIds, targetNodesIds; - for (int i = 0; i < numNodes; ++i) - { - for (int j = 0; j < subgraphs.size(); ++j) - { - if (subgraphs[j]->match(net, i, matchedNodesIds, targetNodesIds)) - { - subgraphs[j]->replace(net, matchedNodesIds, targetNodesIds); - numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added. - break; - } - } - } + simplifySubgraphs(Ptr(new TFGraphWrapper(net)), subgraphs); } void RemoveIdentityOps(tensorflow::GraphDef& net) diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 2122813195..3f821ddf34 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -396,6 +396,7 @@ TEST_P(Test_ONNX_layers, Softmax) { testONNXModels("softmax"); testONNXModels("log_softmax", npy, 0, 0, false, false); + testONNXModels("softmax_unfused"); } TEST_P(Test_ONNX_layers, Split_EltwiseMax)