|
|
|
@ -10,6 +10,7 @@ |
|
|
|
|
#ifdef HAVE_PROTOBUF |
|
|
|
|
|
|
|
|
|
#include "tf_graph_simplifier.hpp" |
|
|
|
|
#include <queue> |
|
|
|
|
|
|
|
|
|
namespace cv { namespace dnn { |
|
|
|
|
CV__DNN_EXPERIMENTAL_NS_BEGIN |
|
|
|
@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net) |
|
|
|
|
nodesToAdd.pop_back(); |
|
|
|
|
|
|
|
|
|
permIds.push_back(nodeToAdd); |
|
|
|
|
// std::cout << net.node(nodeToAdd).name() << '\n';
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < edges[nodeToAdd].size(); ++i) |
|
|
|
|
{ |
|
|
|
@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net) |
|
|
|
|
permute(net.mutable_node(), permIds); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Remove training switches (Switch and Merge nodes and corresponding subgraphs).
|
|
|
|
|
void removePhaseSwitches(tensorflow::GraphDef& net) |
|
|
|
|
{ |
|
|
|
|
std::vector<int> nodesToRemove; |
|
|
|
|
std::map<std::string, int> nodesMap; |
|
|
|
|
std::map<std::string, int>::iterator nodesMapIt; |
|
|
|
|
std::queue<int> mergeOpSubgraphNodes; |
|
|
|
|
for (int i = 0; i < net.node_size(); ++i) |
|
|
|
|
{ |
|
|
|
|
const tensorflow::NodeDef& node = net.node(i); |
|
|
|
|
nodesMap.insert(std::make_pair(node.name(), i)); |
|
|
|
|
if (node.op() == "Switch" || node.op() == "Merge") |
|
|
|
|
{ |
|
|
|
|
CV_Assert(node.input_size() > 0); |
|
|
|
|
// Replace consumers' inputs.
|
|
|
|
|
for (int j = 0; j < net.node_size(); ++j) |
|
|
|
|
{ |
|
|
|
|
tensorflow::NodeDef* consumer = net.mutable_node(j); |
|
|
|
|
for (int k = 0; k < consumer->input_size(); ++k) |
|
|
|
|
{ |
|
|
|
|
std::string inpName = consumer->input(k); |
|
|
|
|
inpName = inpName.substr(0, inpName.rfind(':')); |
|
|
|
|
if (inpName == node.name()) |
|
|
|
|
{ |
|
|
|
|
consumer->set_input(k, node.input(0)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
nodesToRemove.push_back(i); |
|
|
|
|
if (node.op() == "Merge") |
|
|
|
|
mergeOpSubgraphNodes.push(i); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
std::vector<int> numConsumers(net.node_size(), 0); |
|
|
|
|
for (int i = 0; i < net.node_size(); ++i) |
|
|
|
|
{ |
|
|
|
|
const tensorflow::NodeDef& node = net.node(i); |
|
|
|
|
for (int j = 0; j < node.input_size(); ++j) |
|
|
|
|
{ |
|
|
|
|
std::string inpName = node.input(j); |
|
|
|
|
inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':')); |
|
|
|
|
nodesMapIt = nodesMap.find(inpName); |
|
|
|
|
CV_Assert(nodesMapIt != nodesMap.end()); |
|
|
|
|
numConsumers[nodesMapIt->second] += 1; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Remove subgraphs of unused nodes which are terminated by Merge nodes.
|
|
|
|
|
while (!mergeOpSubgraphNodes.empty()) |
|
|
|
|
{ |
|
|
|
|
const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front()); |
|
|
|
|
mergeOpSubgraphNodes.pop(); |
|
|
|
|
for (int i = 0; i < node.input_size(); ++i) |
|
|
|
|
{ |
|
|
|
|
std::string inpName = node.input(i); |
|
|
|
|
inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':')); |
|
|
|
|
nodesMapIt = nodesMap.find(inpName); |
|
|
|
|
CV_Assert(nodesMapIt != nodesMap.end()); |
|
|
|
|
|
|
|
|
|
int inpNodeId = nodesMapIt->second; |
|
|
|
|
if (numConsumers[inpNodeId] == 1) |
|
|
|
|
{ |
|
|
|
|
mergeOpSubgraphNodes.push(inpNodeId); |
|
|
|
|
nodesToRemove.push_back(inpNodeId); |
|
|
|
|
} |
|
|
|
|
else if (numConsumers[inpNodeId] > 0) |
|
|
|
|
numConsumers[inpNodeId] -= 1; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
std::sort(nodesToRemove.begin(), nodesToRemove.end()); |
|
|
|
|
for (int i = nodesToRemove.size() - 1; i >= 0; --i) |
|
|
|
|
{ |
|
|
|
|
if (nodesToRemove[i] < net.node_size()) // Ids might be repeated.
|
|
|
|
|
net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CV__DNN_EXPERIMENTAL_NS_END |
|
|
|
|
}} // namespace dnn, namespace cv
|
|
|
|
|
|
|
|
|
|