Merge pull request #24577 from dkurt:dnn_graph_match_stack

Fix graph fusion with commutative ops #24577

### Pull Request Readiness Checklist

resolves https://github.com/opencv/opencv/issues/24568

**Merge with extra**: https://github.com/opencv/opencv_extra/pull/1125

TODO:
- [x]  replace recursive function to sequential

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
pull/24600/head
Dmitry Kurtaev 1 year ago committed by GitHub
parent 848dd12a1f
commit 332748dd55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 103
      modules/dnn/src/graph_simplifier.cpp
  2. 6
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  3. 2
      modules/dnn/src/onnx/onnx_importer.cpp
  4. 1
      modules/dnn/test/test_graph_simplifier.cpp

@ -81,26 +81,45 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
{
matchedNodesIds.clear();
std::queue<int> nodesToMatch;
std::queue<int> targetNodes;
std::vector<std::pair<int, int> > matchings;
matchings.reserve(nodes.size());
nodesToMatch.push(nodeId);
targetNodes.push(nodes.size() - 1);
while (!nodesToMatch.empty())
// Collection of all matchings states across branching.
// If there is no commutative ops in the subgraph - there would be just a single map.
std::vector<std::shared_ptr<std::map<int, int>>> matchCandidates;
matchCandidates.push_back(makePtr<std::map<int, int>>());
struct State
{
int nodeToMatch;
int targetNodeId;
// Every state refers to current matchings pairs as well as
// matchings from parent branches produced by commutative ops.
std::vector<std::shared_ptr<std::map<int, int>>> matchings;
// When we register a matching pair we should register it in every parent branch.
// This is actual for branching in case of commutative ops only.
void addMatch(std::pair<int, int> match)
{
for (auto& m : matchings)
m->insert(match);
}
};
std::queue<State> states;
states.push({nodeId, (int)nodes.size() - 1, matchCandidates});
while (!states.empty())
{
int nodeToMatch = nodesToMatch.front();
int targetNodeId = targetNodes.front();
nodesToMatch.pop();
targetNodes.pop();
auto state = states.front();
states.pop();
int nodeToMatch = state.nodeToMatch;
int targetNodeId = state.targetNodeId;
auto matchings = state.matchings.back();
if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
matchings.end())
if (matchings->find(targetNodeId) != matchings->end())
continue;
// Empty placeholder matches with any input type
if (nodes[targetNodeId].empty()) {
matchings.push_back({targetNodeId, nodeToMatch});
state.addMatch({targetNodeId, nodeToMatch});
continue;
}
@ -112,43 +131,51 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
if (inputNodes.size() != node->getNumInputs())
continue;
bool isCommutative = net->isCommutativeOp(node->getType());
state.addMatch({targetNodeId, nodeToMatch});
for (int j = 0; j < inputNodes.size(); ++j)
{
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
if (node->getInputName(j).empty())
continue;
nodeId = getInputNodeId(net, node, j);
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
bool isCommutative = net->isCommutativeOp(node->getType());
if (isCommutative)
{
for (int i = 0; i < inputNodes.size(); ++i)
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[i]);
}
if (inputNodes.size() != 2)
CV_Error(Error::StsNotImplemented, "Commutative op fusion with more than 2 inputs");
auto newMatchings = makePtr<std::map<int, int>>(*matchings);
matchCandidates.push_back(newMatchings);
state.matchings.push_back(newMatchings);
states.push({getInputNodeId(net, node, 0), inputNodes[0], state.matchings});
states.push({getInputNodeId(net, node, 1), inputNodes[1], state.matchings});
state.matchings.pop_back();
newMatchings = makePtr<std::map<int, int>>(*matchings);
matchCandidates.push_back(newMatchings);
state.matchings.push_back(newMatchings);
states.push({getInputNodeId(net, node, 0), inputNodes[1], state.matchings});
states.push({getInputNodeId(net, node, 1), inputNodes[0], state.matchings});
state.matchings.pop_back();
}
else
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[j]);
for (int j = 0; j < inputNodes.size(); ++j)
{
nodeId = getInputNodeId(net, node, j);
states.push({nodeId, inputNodes[j], state.matchings});
}
}
matchings.push_back({targetNodeId, nodeToMatch});
}
if (matchings.size() != nodes.size())
return false;
// Sort matched by pattern nodes order.
std::sort(matchings.begin(), matchings.end());
matchedNodesIds.resize(matchings.size());
for (int i = 0; i < matchings.size(); ++i)
for (auto& matchings : matchCandidates)
{
if (matchings->size() != nodes.size())
continue;
matchedNodesIds.resize(matchings->size());
for (int i = 0; i < matchings->size(); ++i)
{
matchedNodesIds[i] = matchings[i].second;
CV_Assert(matchings->find(i) != matchings->end());
matchedNodesIds[i] = matchings->at(i);
}
return true;
}
return false;
}
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
{

@ -64,6 +64,12 @@ class ONNXGraphWrapper : public ImportGraphWrapper
public:
ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
{
// Add a fake initializer with empty name.
// Some ONNX models skip their inputs. For example,
// Resize which has 4 inputs but 2 of them have empty names.
// So we add a fake empty node to which such ops may refer as input.
net.add_initializer();
numInputs = net.input_size();
numInitializers = net.initializer_size();
}

@ -3539,7 +3539,7 @@ void ONNXImporter::parseQGemm(LayerParams& layerParams, const opencv_onnx::NodeP
Mat bias;
if (constBlobs.find(node_proto.input(6)) != constBlobs.end())
bias = getBlob(node_proto, 6);
else
if (bias.empty())
bias = Mat::zeros(1, outCn, CV_32S);
Mat biasFused(1, outCn, CV_32S);

@ -35,6 +35,7 @@ class Test_Graph_Simplifier : public ::testing::Test {
TEST_F(Test_Graph_Simplifier, GeluSubGraph) {
test("gelu", "Gelu");
test("bias_gelu", std::vector<std::string>{"Gelu", "NaryEltwise"});
}
TEST_F(Test_Graph_Simplifier, GeluApproximationSubGraph) {

Loading…
Cancel
Save