Merge pull request #24463 from dkurt:dnn_shared_nodes_fusion

DNN graph fusion with shared nodes #24463

### Pull Request Readiness Checklist

Currently, nodes from a matched pattern are removed during the matching process, so if some of those nodes are shared with another similar subgraph, that subgraph can no longer be matched.

required for https://github.com/opencv/opencv/pull/24397

**Merge with extra**: https://github.com/opencv/opencv_extra/pull/1115

A part of [fcn-resnet101-11.onnx](https://github.com/onnx/models/blob/main/vision/object_detection_segmentation/fcn/model/fcn-resnet101-11.onnx) with two Resize subgraphs that share nodes:
![image](https://github.com/opencv/opencv/assets/25801568/611d89d9-12fb-4add-9218-13b10d2c086a)

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
pull/24490/head
Dmitry Kurtaev 1 year ago committed by GitHub
parent fe4d518d85
commit fa56623458
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 50
      modules/dnn/src/graph_simplifier.cpp
  2. 28
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  3. 10
      modules/dnn/test/test_onnx_importer.cpp

@ -165,10 +165,7 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
inputsNames[i] = inpName; inputsNames[i] = inpName;
} }
// Remove matched nodes except the last one. Indices in ascending order are expected.
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds.back()); Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds.back());
for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
net->removeNode(matchedNodesIds[i]);
// Modify the last node to be a fused one. // Modify the last node to be a fused one.
node->setType(fusedNodeOp); node->setType(fusedNodeOp);
@ -191,6 +188,7 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
{ {
int numNodes = net->getNumNodes(); int numNodes = net->getNumNodes();
std::vector<int> matchedNodesIds, targetNodesIds; std::vector<int> matchedNodesIds, targetNodesIds;
std::vector<int> nodesToRemove;
for (int j = 0; j < patterns.size(); ++j) for (int j = 0; j < patterns.size(); ++j)
{ {
for (int i = 0; i < numNodes; ++i) for (int i = 0; i < numNodes; ++i)
@ -198,10 +196,54 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
{ {
patterns[j]->replace(net, matchedNodesIds, targetNodesIds); patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added. // Remove matched nodes except the last one.
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
} }
} }
} }
if (nodesToRemove.empty())
return;
// Collect reference counts for every node
std::vector<int> refcounts(net->getNumNodes(), 0);
std::map<std::string, int> nodeIds;
// Register node outputs.
// Every usage of one of the node's outputs should be counted.
for (int nodeId = 0; nodeId < refcounts.size(); ++nodeId) {
for (int i = 0; i < net->getNumOutputs(nodeId); ++i) {
std::string name = net->getOutputName(nodeId, i);
nodeIds[name] = nodeId;
}
}
for (int nodeId = 0; nodeId < refcounts.size(); ++nodeId) {
// Increase counters for node's inputs
auto node = net->getNode(nodeId);
for (int i = 0; i < node->getNumInputs(); ++i) {
std::string inpName = node->getInputName(i);
if (inpName.empty())
continue;
CV_Assert(nodeIds.find(inpName) != nodeIds.end());
refcounts[nodeIds[inpName]] += 1;
}
}
// Remove fused nodes that are no longer referenced by anything.
// Ids are visited in descending order (presumably so that removing a node
// never shifts the id of a node that is still waiting to be processed).
std::sort(nodesToRemove.begin(), nodesToRemove.end(), [](int a, int b) { return a > b; });
for (int nodeId : nodesToRemove) {
    // A non-zero counter means the node is either still consumed by some
    // other node (e.g. it is shared with another subgraph) or it has already
    // been removed (marked with -1 below).
    if (refcounts[nodeId] != 0)
        continue;

    // Release the references this node holds on its producers so that they
    // can become removable later in this same pass.
    auto node = net->getNode(nodeId);
    for (int i = 0; i < node->getNumInputs(); ++i) {
        std::string inpName = node->getInputName(i);
        // Empty input names were not counted when refcounts were collected,
        // so they must not be un-counted here either. Without this check
        // nodeIds[""] would normally default-insert id 0 and corrupt the
        // reference counter of node 0.
        if (inpName.empty())
            continue;
        refcounts[nodeIds[inpName]] -= 1;
    }
    net->removeNode(nodeId);
    refcounts[nodeId] = -1;  // Same node cannot be removed twice
}
} }
}} // namespace cv::dnn }} // namespace cv::dnn

@ -1136,6 +1136,33 @@ public:
} }
}; };
// Pattern for a Resize whose last input is assembled at runtime from tensor
// shapes, fused into a single "Upsample" node that takes the resized tensor
// and the tensor the sizes are derived from.
//
// NOTE: the order of addNodeToMatch() calls defines the pattern's node ids,
// so the sequence below must not be reordered.
class ResizeSubgraph3 : public Subgraph
{
public:
    ResizeSubgraph3() : Subgraph()
    {
        // Two free (untyped) pattern inputs.
        const int sizeSource = addNodeToMatch("");
        const int data = addNodeToMatch("");

        // Two parallel Shape -> Gather -> Unsqueeze chains fed by the same source
        // (presumably extracting height and width -- TODO confirm against the model).
        const int hShape = addNodeToMatch("Shape", sizeSource);
        const int wShape = addNodeToMatch("Shape", sizeSource);
        const int hGather = addNodeToMatch("Gather", hShape, addNodeToMatch("Constant"));
        const int wGather = addNodeToMatch("Gather", wShape, addNodeToMatch("Constant"));
        const int hUnsqueezed = addNodeToMatch("Unsqueeze", hGather);
        const int wUnsqueezed = addNodeToMatch("Unsqueeze", wGather);

        // Both chains are joined and cast.
        const int joinedDims = addNodeToMatch("Concat", hUnsqueezed, wUnsqueezed);
        const int joinedDimsCast = addNodeToMatch("Cast", joinedDims);

        // A slice of the data tensor's own shape is prepended to the joined dims.
        const int dataShape = addNodeToMatch("Shape", data);
        const int dataShapeSlice = addNodeToMatch("Slice", dataShape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
        const int resizeSizes = addNodeToMatch("Concat", dataShapeSlice, joinedDimsCast);

        // The Resize itself: data, two constants, and the computed sizes.
        addNodeToMatch("Resize", data, addNodeToMatch("Constant"), addNodeToMatch("Constant"), resizeSizes);

        setFusedNode("Upsample", data, sizeSource);
    }
};
class BatchNormalizationSubgraphBase : public Subgraph class BatchNormalizationSubgraphBase : public Subgraph
{ {
public: public:
@ -1207,6 +1234,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<UpsampleSubgraph>()); subgraphs.push_back(makePtr<UpsampleSubgraph>());
subgraphs.push_back(makePtr<ResizeSubgraph1>()); subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>()); subgraphs.push_back(makePtr<ResizeSubgraph2>());
subgraphs.push_back(makePtr<ResizeSubgraph3>());
subgraphs.push_back(makePtr<SoftMaxSubgraph>()); subgraphs.push_back(makePtr<SoftMaxSubgraph>());
subgraphs.push_back(makePtr<SoftMaxSubgraph2>()); subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
subgraphs.push_back(makePtr<LogSoftMaxSubgraph>()); subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());

@ -54,7 +54,8 @@ public:
void testONNXModels(const String& basename, const Extension ext = npy, void testONNXModels(const String& basename, const Extension ext = npy,
double l1 = 0, double lInf = 0, const bool useSoftmax = false, double l1 = 0, double lInf = 0, const bool useSoftmax = false,
bool checkNoFallbacks = true, int numInps = 1) bool checkNoFallbacks = true, int numInps = 1,
bool testShapes = true)
{ {
String onnxmodel = _tf("models/" + basename + ".onnx", required); String onnxmodel = _tf("models/" + basename + ".onnx", required);
std::vector<Mat> inps(numInps); std::vector<Mat> inps(numInps);
@ -76,7 +77,8 @@ public:
Net net = readNetFromONNX(onnxmodel); Net net = readNetFromONNX(onnxmodel);
ASSERT_FALSE(net.empty()); ASSERT_FALSE(net.empty());
testInputShapes(net, inps); if (testShapes)
testInputShapes(net, inps);
net.setPreferableBackend(backend); net.setPreferableBackend(backend);
net.setPreferableTarget(target); net.setPreferableTarget(target);
@ -248,6 +250,10 @@ TEST_P(Test_ONNX_layers, Gather_shared_indices) {
testONNXModels("gather_shared_indices", npy, 0, 0, false, false, 1); testONNXModels("gather_shared_indices", npy, 0, 0, false, false, 1);
} }
// Runs the ONNX model test for "two_resizes_with_shared_subgraphs" (judging by
// the name, a model with two Resize subgraphs that share nodes).
// Positional arguments: ext = npy, l1 = 0, lInf = 0, useSoftmax = false,
// checkNoFallbacks = false, numInps = 3; the input-shape check is skipped
// via testShapes = false.
TEST_P(Test_ONNX_layers, Two_resizes_with_shared_subgraphs) {
testONNXModels("two_resizes_with_shared_subgraphs", npy, 0, 0, false, false, 3, /*testShapes*/ false);
}
TEST_P(Test_ONNX_layers, Convolution3D) TEST_P(Test_ONNX_layers, Convolution3D)
{ {
if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16) if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16)

Loading…
Cancel
Save