Dmitry Kurtaev 5 years ago
parent 46615ffc4a
commit d3f9ad1145
  1. 52
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 19
      modules/dnn/src/onnx/onnx_importer.cpp
  3. 1
      modules/dnn/test/test_onnx_importer.cpp

@ -154,16 +154,10 @@ private:
int axis; int axis;
}; };
class NormalizeSubgraph1 : public Subgraph class NormalizeSubgraphBase : public Subgraph
{ {
public: public:
NormalizeSubgraph1() : axis(1) NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
{
input = addNodeToMatch("");
norm = addNodeToMatch("ReduceL2", input);
addNodeToMatch("Div", input, norm);
setFusedNode("Normalize", input);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds,
@ -171,7 +165,7 @@ public:
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{ {
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[0]); Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = 0; i < node->attribute_size(); i++) for (int i = 0; i < node->attribute_size(); i++)
@ -204,20 +198,51 @@ public:
} }
protected: protected:
int input, norm; int axis, normNodeOrder;
int axis;
}; };
class NormalizeSubgraph1 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph1()
{
int input = addNodeToMatch("");
int norm = addNodeToMatch("ReduceL2", input);
addNodeToMatch("Div", input, norm);
setFusedNode("Normalize", input);
}
};
class NormalizeSubgraph2 : public NormalizeSubgraph1 class NormalizeSubgraph2 : public NormalizeSubgraphBase
{ {
public: public:
NormalizeSubgraph2() : NormalizeSubgraph1() NormalizeSubgraph2()
{ {
int input = addNodeToMatch("");
int norm = addNodeToMatch("ReduceL2", input);
int clip = addNodeToMatch("Clip", norm); int clip = addNodeToMatch("Clip", norm);
int shape = addNodeToMatch("Shape", input); int shape = addNodeToMatch("Shape", input);
int expand = addNodeToMatch("Expand", clip, shape); int expand = addNodeToMatch("Expand", clip, shape);
addNodeToMatch("Div", input, expand); addNodeToMatch("Div", input, expand);
setFusedNode("Normalize", input);
}
};
class NormalizeSubgraph3 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph3() : NormalizeSubgraphBase(1)
{
int input = addNodeToMatch("");
int power = addNodeToMatch("Constant");
int squared = addNodeToMatch("Pow", input, power);
int sum = addNodeToMatch("ReduceSum", squared);
int sqrtNode = addNodeToMatch("Sqrt", sum);
int eps = addNodeToMatch("Constant");
int add = addNodeToMatch("Add", sqrtNode, eps);
addNodeToMatch("Div", input, add);
setFusedNode("Normalize", input);
} }
}; };
@ -368,6 +393,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<SoftMaxSubgraph>()); subgraphs.push_back(makePtr<SoftMaxSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph1>()); subgraphs.push_back(makePtr<NormalizeSubgraph1>());
subgraphs.push_back(makePtr<NormalizeSubgraph2>()); subgraphs.push_back(makePtr<NormalizeSubgraph2>());
subgraphs.push_back(makePtr<NormalizeSubgraph3>());
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
} }

@ -1457,6 +1457,25 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.type = "Softmax"; layerParams.type = "Softmax";
layerParams.set("log_softmax", layer_type == "LogSoftmax"); layerParams.set("log_softmax", layer_type == "LogSoftmax");
} }
else if (layer_type == "DetectionOutput")
{
CV_CheckEQ(node_proto.input_size(), 3, "");
if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
{
Mat priors = getBlob(node_proto, constBlobs, 2);
LayerParams constParams;
constParams.name = layerParams.name + "/priors";
constParams.type = "Const";
constParams.blobs.push_back(priors);
opencv_onnx::NodeProto priorsProto;
priorsProto.add_output(constParams.name);
addLayer(dstNet, constParams, priorsProto, layer_id, outShapes);
node_proto.set_input(2, constParams.name);
}
}
else else
{ {
for (int j = 0; j < node_proto.input_size(); j++) { for (int j = 0; j < node_proto.input_size(); j++) {

@ -440,6 +440,7 @@ TEST_P(Test_ONNX_layers, ReduceL2)
{ {
testONNXModels("reduceL2"); testONNXModels("reduceL2");
testONNXModels("reduceL2_subgraph"); testONNXModels("reduceL2_subgraph");
testONNXModels("reduceL2_subgraph_2");
} }
TEST_P(Test_ONNX_layers, Split) TEST_P(Test_ONNX_layers, Split)

Loading…
Cancel
Save