diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 9140368522..d1b5a85d6c 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -18,6 +18,7 @@ Implementation of Tensorflow models parser #include #include #include +#include #include "tf_graph_simplifier.hpp" #endif @@ -558,9 +559,7 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map& cons } } -// If all inputs of specific layer have the same data layout we can say that -// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise. -static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map& data_layouts) +static int getDataLayout(const tensorflow::NodeDef& layer) { if (hasLayerAttr(layer, "data_format")) { @@ -572,27 +571,48 @@ static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std:: else CV_Error(Error::StsParseError, "Unknown data_format value: " + format); } + return DATA_LAYOUT_UNKNOWN; +} + +static inline std::string getNodeName(const std::string& tensorName) +{ + return tensorName.substr(0, tensorName.rfind(':')); +} + +// If all inputs of specific layer have the same data layout we can say that +// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise. +static int predictOutputDataLayout(const tensorflow::GraphDef& net, + const tensorflow::NodeDef& layer, + const std::map& data_layouts) +{ + int layout = getDataLayout(layer); + if (layout != DATA_LAYOUT_UNKNOWN) + return layout; // Determine layout by layer's inputs - int layout = DATA_LAYOUT_UNKNOWN; std::map::const_iterator it; for (int i = 0, n = layer.input_size(); i < n; ++i) { - it = data_layouts.find(layer.input(i).substr(0, layer.input(i).rfind(':'))); + it = data_layouts.find(getNodeName(layer.input(i))); if (it != data_layouts.end()) { - if (it->second == DATA_LAYOUT_UNKNOWN) - return DATA_LAYOUT_UNKNOWN; - else if (it->second != layout) + if (layout != DATA_LAYOUT_UNKNOWN) { - if (layout == DATA_LAYOUT_UNKNOWN) - layout = it->second; - else + if (it->second != layout && it->second != DATA_LAYOUT_UNKNOWN) return DATA_LAYOUT_UNKNOWN; } + else + layout = it->second; } } - return layout; + + if (layout != DATA_LAYOUT_UNKNOWN) + return layout; + + // Determine layout by layer's consumers recursively. + it = data_layouts.find(layer.name()); + CV_Assert(it != data_layouts.end()); + return it->second; } void TFImporter::populateNet(Net dstNet) @@ -610,6 +630,52 @@ void TFImporter::populateNet(Net dstNet) int layersSize = net.node_size(); std::map data_layouts; + // Pre-fill data layouts where they are set explicitly. + // Assuming that nodes are in topological order + for (int i = net.node_size() - 1; i >= 0; --i) + { + const tensorflow::NodeDef& layer = net.node(i); + std::string name = layer.name(); + + int layout = getDataLayout(layer); + std::map::iterator it = data_layouts.find(name); + if (it != data_layouts.end()) + { + if (layout != DATA_LAYOUT_UNKNOWN) + { + if (it->second == DATA_LAYOUT_UNKNOWN) + it->second = layout; + else if (it->second != layout) + { + it->second = DATA_LAYOUT_UNKNOWN; + layout = DATA_LAYOUT_UNKNOWN; + } + } + else + layout = it->second; + } + else + data_layouts[name] = layout; + + // Specify input layers to have the same data layout. + for (int j = 0; j < layer.input_size(); ++j) + { + name = getNodeName(layer.input(j)); + it = data_layouts.find(name); + if (it != data_layouts.end()) + { + if (layout != DATA_LAYOUT_UNKNOWN) + { + if (it->second == DATA_LAYOUT_UNKNOWN) + it->second = layout; + else if (it->second != layout) + it->second = DATA_LAYOUT_UNKNOWN; + } + } + else + data_layouts[name] = layout; + } + } // find all Const layers for params std::map value_id; @@ -628,7 +694,8 @@ void TFImporter::populateNet(Net dstNet) if(layers_to_ignore.find(name) != layers_to_ignore.end()) continue; - data_layouts[name] = predictOutputDataLayout(layer, data_layouts); + int predictedLayout = predictOutputDataLayout(net, layer, data_layouts); + data_layouts[name] = predictedLayout; if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative") { @@ -885,6 +952,7 @@ void TFImporter::populateNet(Net dstNet) // one input only connect(layer_id, dstNet, inpId, id, 0); + data_layouts[name] = DATA_LAYOUT_UNKNOWN; } else if (type == "Flatten" || type == "Squeeze") { @@ -1013,7 +1081,10 @@ void TFImporter::populateNet(Net dstNet) { int axisId = (type == "Concat" ? 0 : layer.input_size() - 1); int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0); - layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW(axis) : axis); + + if (data_layouts[name] == DATA_LAYOUT_NHWC) + axis = toNCHW(axis); + layerParams.set("axis", axis); int id = dstNet.addLayer(name, "Concat", layerParams); layer_id[name] = id; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 5ac8890e50..33238c718e 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -142,9 +142,10 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_mul) runTensorFlowNet("eltwise_add_mul", GetParam()); } -TEST_P(Test_TensorFlow_layers, pad_and_concat) +TEST_P(Test_TensorFlow_layers, concat) { runTensorFlowNet("pad_and_concat", GetParam()); + runTensorFlowNet("concat_axis_1", GetParam()); } TEST_P(Test_TensorFlow_layers, batch_norm)