diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index 202958d733..66943dd0bb 100644
--- a/modules/dnn/src/tensorflow/tf_importer.cpp
+++ b/modules/dnn/src/tensorflow/tf_importer.cpp
@@ -42,6 +42,14 @@ namespace
 
 static int toNCHW[] = {0, 2, 3, 1};
 
+// These values indicate the data layout of a layer's output where it can be determined.
+enum DataLayout
+{
+    DATA_LAYOUT_NHWC,
+    DATA_LAYOUT_NCHW,
+    DATA_LAYOUT_UNKNOWN
+};
+
 typedef std::vector<std::pair<String, int> > StrIntVector;
 
 struct Pin
@@ -608,6 +616,31 @@ static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>
     }
 }
 
+// If all inputs of a layer that have a recorded layout agree on it, the layer's output is
+// assumed to have that data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
+static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
+{
+    int layout = DATA_LAYOUT_UNKNOWN;
+    std::map<String, int>::const_iterator it;
+    for (int i = 0, n = layer.input_size(); i < n; ++i)
+    {
+        it = data_layouts.find(layer.input(i));
+        if (it != data_layouts.end())
+        {
+            if (it->second == DATA_LAYOUT_UNKNOWN)
+                return DATA_LAYOUT_UNKNOWN;
+            else if (it->second != layout)
+            {
+                if (layout == DATA_LAYOUT_UNKNOWN)
+                    layout = it->second;
+                else
+                    return DATA_LAYOUT_UNKNOWN;
+            }
+        }
+    }
+    return layout;
+}
+
 void TFImporter::populateNet(Net dstNet)
 {
     RemoveIdentityOps(netBin);
@@ -619,6 +652,8 @@ void TFImporter::populateNet(Net dstNet)
 
     int layersSize = net.node_size();
 
+    std::map<String, int> data_layouts;
+
     // find all Const layers for params
     std::map<String, int> value_id;
     addConstNodes(netBin, value_id, layers_to_ignore);
@@ -636,6 +671,8 @@ void TFImporter::populateNet(Net dstNet)
         if(layers_to_ignore.find(name) != layers_to_ignore.end())
             continue;
 
+        data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
+
         if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
         {
             // The first node of dilated convolution subgraph.
@@ -731,6 +768,19 @@ void TFImporter::populateNet(Net dstNet)
 
             // one input only
             connect(layer_id, dstNet, parsePin(input), id, 0);
+
+            if (hasLayerAttr(layer, "data_format"))
+            {
+                std::string format = getLayerAttr(layer, "data_format").s();
+                if (format == "NHWC")
+                    data_layouts[name] = DATA_LAYOUT_NHWC;
+                else if (format == "NCHW")
+                    data_layouts[name] = DATA_LAYOUT_NCHW;
+                else
+                    CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
+            }
+            else
+                data_layouts[name] = DATA_LAYOUT_NHWC;
         }
         else if (type == "BiasAdd" || type == "Add")
         {
@@ -806,22 +856,55 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             int input_blob_index = kernel_blob_index == 0 ? 1 : 0;
             connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Reshape")
         {
-            layerParams.set("dim", parseDims(getConstBlob(layer, value_id, 1)));
+            Pin inpId = parsePin(layer.input(0));
+            DictValue newShape = parseDims(getConstBlob(layer, value_id, 1));
+
+            if (newShape.size() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            {
+                LayerParams permLP;
+                int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
+                permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+                std::string permName = name + "/nchw";
+                CV_Assert(layer_id.find(permName) == layer_id.end());
+                int permId = dstNet.addLayer(permName, "Permute", permLP);
+                layer_id[permName] = permId;
+                connect(layer_id, dstNet, inpId, permId, 0);
+                inpId = Pin(permName);
+            }
+            layerParams.set("dim", newShape);
 
             int id = dstNet.addLayer(name, "Reshape", layerParams);
             layer_id[name] = id;
 
             // one input only
-            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, inpId, id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Flatten")
         {
+            Pin inpId = parsePin(layer.input(0));
+            if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            {
+                LayerParams permLP;
+                int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
+                permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+                std::string permName = name + "/nchw";
+                CV_Assert(layer_id.find(permName) == layer_id.end());
+                int permId = dstNet.addLayer(permName, "Permute", permLP);
+                layer_id[permName] = permId;
+                connect(layer_id, dstNet, inpId, permId, 0);
+                inpId = Pin(permName);
+            }
             int id = dstNet.addLayer(name, "Flatten", layerParams);
             layer_id[name] = id;
-            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, inpId, id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Transpose")
         {
@@ -830,16 +913,57 @@ void TFImporter::populateNet(Net dstNet)
             int* permData = (int*)perm.data;
             if (perm.total() == 4)
             {
-                for (int i = 0; i < 4; ++i)
-                    permData[i] = toNCHW[permData[i]];
+                // Only NHWC <-> NCHW permutations are allowed, so OpenCV always keeps
+                // its data in NCHW layout and an Identity layer is emitted instead.
+                if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+                {
+                    if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
+                    {
+                        // in TensorFlow: NHWC->NCHW
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NCHW;
+                    }
+                    else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
+                    {
+                        // in TensorFlow: NHWC->NHWC
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NHWC;
+                    }
+                    else
+                        CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
+                }
+                else if (data_layouts[layer.input(0)] == DATA_LAYOUT_NCHW)
+                {
+                    if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
+                    {
+                        // in TensorFlow: NCHW->NHWC
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NHWC;
+                    }
+                    else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
+                    {
+                        // in TensorFlow: NCHW->NCHW
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NCHW;
+                    }
+                    else
+                        CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
+                }
+                int id = dstNet.addLayer(name, "Identity", layerParams);
+                layer_id[name] = id;
+                connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
             }
-            layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
+            else
+            {
+                layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
 
-            int id = dstNet.addLayer(name, "Permute", layerParams);
-            layer_id[name] = id;
+                int id = dstNet.addLayer(name, "Permute", layerParams);
+                layer_id[name] = id;
 
-            // one input only
-            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+                // one input only
+                connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+                data_layouts[name] = DATA_LAYOUT_UNKNOWN;
+            }
         }
         else if (type == "Const")
         {
@@ -1207,6 +1331,7 @@ void TFImporter::populateNet(Net dstNet)
 
             // one input only
             connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "ResizeNearestNeighbor")
         {
@@ -1258,6 +1383,7 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
             connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "DetectionOutput")
         {
@@ -1288,6 +1414,7 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             for (int i = 0; i < 3; ++i)
                 connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
                  type == "Relu" || type == "Elu" || type == "Softmax" ||
diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp
index 1badf74ab7..05adbf0bf5 100644
--- a/modules/dnn/test/test_tf_importer.cpp
+++ b/modules/dnn/test/test_tf_importer.cpp
@@ -159,6 +159,8 @@ TEST(Test_TensorFlow, deconvolution)
 TEST(Test_TensorFlow, matmul)
 {
     runTensorFlowNet("matmul");
+    runTensorFlowNet("nhwc_reshape_matmul");
+    runTensorFlowNet("nhwc_transpose_reshape_matmul");
 }
 
 TEST(Test_TensorFlow, defun)