From 337452b4c036311f5a0c77db79bf713cc84fc621 Mon Sep 17 00:00:00 2001 From: anton Date: Tue, 30 Aug 2022 18:43:37 +0200 Subject: [PATCH] changed names of permutations if Reshpe is in NHWC --- modules/dnn/src/tensorflow/tf_importer.cpp | 16 ++++++++++++---- modules/dnn/test/test_tf_importer.cpp | 6 ++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 80a8b6dfc5..96e0af99ec 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1097,6 +1097,9 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD std::swap(*newShape.ptr(0, 1), *newShape.ptr(0, 2)); hasSwap = true; } + + bool changedType{false}; + if (inpLayout == DATA_LAYOUT_NHWC) { if (newShapeSize >= 2 || newShape.at(1) == 1) @@ -1110,23 +1113,28 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD else { inpLayout = DATA_LAYOUT_NHWC; + changedType = newShapeSize == 4 && !hasSwap; } } } layerParams.set("dim", DictValue::arrayInt(newShape.ptr(), newShapeSize)); - int id = dstNet.addLayer(name, "Reshape", layerParams); - layer_id[name] = id; + std::string setName = changedType ? name + "/realReshape" : name; + + int id = dstNet.addLayer(setName, "Reshape", layerParams); + layer_id[setName] = id; // one input only connect(layer_id, dstNet, inpId, id, 0); - inpId = Pin(name); + inpId = Pin(setName); if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) && newShapeSize == 4 && !hasSwap) { int order[] = {0, 3, 1, 2}; // Transform back to OpenCV's NCHW. - addPermuteLayer(order, name + "/nchw", inpId); + + setName = changedType ? name : name + "/nchw"; + addPermuteLayer(order, setName, inpId); inpLayout = DATA_LAYOUT_NCHW; } diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 72a8989f6a..97060df563 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -337,6 +337,12 @@ TEST_P(Test_TensorFlow_layers, eltwise_mul_vec) runTensorFlowNet("eltwise_mul_vec"); } +TEST_P(Test_TensorFlow_layers, tf_reshape_nhwc) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + runTensorFlowNet("tf_reshape_nhwc"); +} TEST_P(Test_TensorFlow_layers, channel_broadcast) {