Merge pull request #18296 from sl-sergei:fix_16783

Fix loading issue for Faster RCNN model from #16783

* Add a reproducer with multi-output Gather

* Fix an issue with ONNX graph simplifier

* fix build

* Move checks to correct class

* Minor changes for better code appearence
pull/18145/head
Sergei Slashchinin 4 years ago committed by GitHub
parent 564d1a0f79
commit 2b82f8f12c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 34
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 5
      modules/dnn/test/test_onnx_importer.cpp

@ -260,6 +260,40 @@ public:
addNodeToMatch("Cast", gather);
setFusedNode("Gather", input, index);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
{
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
size_t matchedNodesNum = matchedNodesIds.size();
// Now we check if merging can be made for these Gather and Cast nodes
if (!retVal || matchedNodesNum < 2)
return retVal;
else {
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
if (node->getType() == "Cast") {
int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
if (inpNode->getType() == "Gather") {
int numNodes = net->getNumNodes();
std::string inpNodeName = node->getInputName(0);
for (int i = 0; i < numNodes; ++i) {
const Ptr<ImportNodeWrapper> node_to_check = net->getNode(i);
int numInp = node_to_check->getNumInputs();
for (int inp = 0; inp < numInp; ++inp) {
if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(0)) {
// Another node has the same input node, so it cannot be merged.
return false;
}
}
}
}
}
}
return retVal;
}
};
class ExpandSubgraph : public Subgraph

@ -705,6 +705,11 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias)
normAssert(ref, out, "", default_l1, default_lInf);
}
TEST_P(Test_ONNX_layers, GatherMultiOutput)
{
testONNXModels("gather_multi_output");
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
class Test_ONNX_nets : public Test_ONNX_layers

Loading…
Cancel
Save