Dmitry Kurtaev 5 years ago
parent 46615ffc4a
commit d3f9ad1145
  1. 52
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 19
      modules/dnn/src/onnx/onnx_importer.cpp
  3. 1
      modules/dnn/test/test_onnx_importer.cpp

@ -154,16 +154,10 @@ private:
int axis;
};
class NormalizeSubgraph1 : public Subgraph
class NormalizeSubgraphBase : public Subgraph
{
public:
NormalizeSubgraph1() : axis(1)
{
input = addNodeToMatch("");
norm = addNodeToMatch("ReduceL2", input);
addNodeToMatch("Div", input, norm);
setFusedNode("Normalize", input);
}
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
@ -171,7 +165,7 @@ public:
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[0]);
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = 0; i < node->attribute_size(); i++)
@ -204,20 +198,51 @@ public:
}
protected:
int input, norm;
int axis;
int axis, normNodeOrder;
};
class NormalizeSubgraph1 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph1()
{
int input = addNodeToMatch("");
int norm = addNodeToMatch("ReduceL2", input);
addNodeToMatch("Div", input, norm);
setFusedNode("Normalize", input);
}
};
class NormalizeSubgraph2 : public NormalizeSubgraph1
class NormalizeSubgraph2 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph2() : NormalizeSubgraph1()
NormalizeSubgraph2()
{
int input = addNodeToMatch("");
int norm = addNodeToMatch("ReduceL2", input);
int clip = addNodeToMatch("Clip", norm);
int shape = addNodeToMatch("Shape", input);
int expand = addNodeToMatch("Expand", clip, shape);
addNodeToMatch("Div", input, expand);
setFusedNode("Normalize", input);
}
};
class NormalizeSubgraph3 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph3() : NormalizeSubgraphBase(1)
{
int input = addNodeToMatch("");
int power = addNodeToMatch("Constant");
int squared = addNodeToMatch("Pow", input, power);
int sum = addNodeToMatch("ReduceSum", squared);
int sqrtNode = addNodeToMatch("Sqrt", sum);
int eps = addNodeToMatch("Constant");
int add = addNodeToMatch("Add", sqrtNode, eps);
addNodeToMatch("Div", input, add);
setFusedNode("Normalize", input);
}
};
@ -368,6 +393,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
subgraphs.push_back(makePtr<NormalizeSubgraph3>());
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}

@ -1457,6 +1457,25 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.type = "Softmax";
layerParams.set("log_softmax", layer_type == "LogSoftmax");
}
else if (layer_type == "DetectionOutput")
{
CV_CheckEQ(node_proto.input_size(), 3, "");
if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
{
Mat priors = getBlob(node_proto, constBlobs, 2);
LayerParams constParams;
constParams.name = layerParams.name + "/priors";
constParams.type = "Const";
constParams.blobs.push_back(priors);
opencv_onnx::NodeProto priorsProto;
priorsProto.add_output(constParams.name);
addLayer(dstNet, constParams, priorsProto, layer_id, outShapes);
node_proto.set_input(2, constParams.name);
}
}
else
{
for (int j = 0; j < node_proto.input_size(); j++) {

@ -440,6 +440,7 @@ TEST_P(Test_ONNX_layers, ReduceL2)
{
testONNXModels("reduceL2");
testONNXModels("reduceL2_subgraph");
testONNXModels("reduceL2_subgraph_2");
}
TEST_P(Test_ONNX_layers, Split)

Loading…
Cancel
Save