Remove Switch and Merge nodes from TensorFlow networks

pull/14251/head
Dmitry Kurtaev 6 years ago
parent df1f62b34c
commit ec41a4897a
  1. 81
      modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
  2. 2
      modules/dnn/src/tensorflow/tf_graph_simplifier.hpp
  3. 3
      modules/dnn/src/tensorflow/tf_importer.cpp
  4. 10
      modules/dnn/test/test_tf_importer.cpp

@ -10,6 +10,7 @@
#ifdef HAVE_PROTOBUF
#include "tf_graph_simplifier.hpp"
#include <queue>
namespace cv { namespace dnn {
CV__DNN_EXPERIMENTAL_NS_BEGIN
@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
nodesToAdd.pop_back();
permIds.push_back(nodeToAdd);
// std::cout << net.node(nodeToAdd).name() << '\n';
for (int i = 0; i < edges[nodeToAdd].size(); ++i)
{
@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
permute(net.mutable_node(), permIds);
}
// Remove training switches (Switch and Merge nodes and corresponding subgraphs).
void removePhaseSwitches(tensorflow::GraphDef& net)
{
std::vector<int> nodesToRemove;
std::map<std::string, int> nodesMap;
std::map<std::string, int>::iterator nodesMapIt;
std::queue<int> mergeOpSubgraphNodes;
for (int i = 0; i < net.node_size(); ++i)
{
const tensorflow::NodeDef& node = net.node(i);
nodesMap.insert(std::make_pair(node.name(), i));
if (node.op() == "Switch" || node.op() == "Merge")
{
CV_Assert(node.input_size() > 0);
// Replace consumers' inputs.
for (int j = 0; j < net.node_size(); ++j)
{
tensorflow::NodeDef* consumer = net.mutable_node(j);
for (int k = 0; k < consumer->input_size(); ++k)
{
std::string inpName = consumer->input(k);
inpName = inpName.substr(0, inpName.rfind(':'));
if (inpName == node.name())
{
consumer->set_input(k, node.input(0));
}
}
}
nodesToRemove.push_back(i);
if (node.op() == "Merge")
mergeOpSubgraphNodes.push(i);
}
}
std::vector<int> numConsumers(net.node_size(), 0);
for (int i = 0; i < net.node_size(); ++i)
{
const tensorflow::NodeDef& node = net.node(i);
for (int j = 0; j < node.input_size(); ++j)
{
std::string inpName = node.input(j);
inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
nodesMapIt = nodesMap.find(inpName);
CV_Assert(nodesMapIt != nodesMap.end());
numConsumers[nodesMapIt->second] += 1;
}
}
// Remove subgraphs of unused nodes which are terminated by Merge nodes.
while (!mergeOpSubgraphNodes.empty())
{
const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front());
mergeOpSubgraphNodes.pop();
for (int i = 0; i < node.input_size(); ++i)
{
std::string inpName = node.input(i);
inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
nodesMapIt = nodesMap.find(inpName);
CV_Assert(nodesMapIt != nodesMap.end());
int inpNodeId = nodesMapIt->second;
if (numConsumers[inpNodeId] == 1)
{
mergeOpSubgraphNodes.push(inpNodeId);
nodesToRemove.push_back(inpNodeId);
}
else if (numConsumers[inpNodeId] > 0)
numConsumers[inpNodeId] -= 1;
}
}
std::sort(nodesToRemove.begin(), nodesToRemove.end());
for (int i = nodesToRemove.size() - 1; i >= 0; --i)
{
if (nodesToRemove[i] < net.node_size()) // Ids might be repeated.
net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1);
}
}
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv

@ -27,6 +27,8 @@ void releaseTensor(tensorflow::TensorProto* tensor);
void sortByExecutionOrder(tensorflow::GraphDef& net);
void removePhaseSwitches(tensorflow::GraphDef& net);
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv

@ -657,6 +657,9 @@ static int predictOutputDataLayout(const tensorflow::GraphDef& net,
void TFImporter::populateNet(Net dstNet)
{
if (!netTxt.ByteSize())
removePhaseSwitches(netBin);
RemoveIdentityOps(netBin);
RemoveIdentityOps(netTxt);

@ -185,6 +185,16 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
runTensorFlowNet("mvn_batch_norm_1x1");
}
TEST_P(Test_TensorFlow_layers, slim_batch_norm)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE)
throw SkipTestException("Test is disabled for DLIE");
// Output values range: [-40.0597, 207.827]
double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.041 : default_l1;
double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.33 : default_lInf;
runTensorFlowNet("slim_batch_norm", false, l1, lInf);
}
TEST_P(Test_TensorFlow_layers, pooling)
{
runTensorFlowNet("max_pool_even");

Loading…
Cancel
Save