Merge pull request #22448 from Ichini24:reshape-permutations-fix

changed names of permutations if Reshpe is in NHWC
pull/22227/head
Alexander Smorkalov 2 years ago committed by GitHub
commit c2c8da2517
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      modules/dnn/src/tensorflow/tf_importer.cpp
  2. 6
      modules/dnn/test/test_tf_importer.cpp

@ -1097,6 +1097,9 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2)); std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
hasSwap = true; hasSwap = true;
} }
bool changedType{false};
if (inpLayout == DATA_LAYOUT_NHWC) if (inpLayout == DATA_LAYOUT_NHWC)
{ {
if (newShapeSize >= 2 || newShape.at<int>(1) == 1) if (newShapeSize >= 2 || newShape.at<int>(1) == 1)
@ -1110,23 +1113,28 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD
else else
{ {
inpLayout = DATA_LAYOUT_NHWC; inpLayout = DATA_LAYOUT_NHWC;
changedType = newShapeSize == 4 && !hasSwap;
} }
} }
} }
layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShapeSize)); layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShapeSize));
int id = dstNet.addLayer(name, "Reshape", layerParams); std::string setName = changedType ? name + "/realReshape" : name;
layer_id[name] = id;
int id = dstNet.addLayer(setName, "Reshape", layerParams);
layer_id[setName] = id;
// one input only // one input only
connect(layer_id, dstNet, inpId, id, 0); connect(layer_id, dstNet, inpId, id, 0);
inpId = Pin(name); inpId = Pin(setName);
if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) && if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) &&
newShapeSize == 4 && !hasSwap) newShapeSize == 4 && !hasSwap)
{ {
int order[] = {0, 3, 1, 2}; // Transform back to OpenCV's NCHW. int order[] = {0, 3, 1, 2}; // Transform back to OpenCV's NCHW.
addPermuteLayer(order, name + "/nchw", inpId);
setName = changedType ? name : name + "/nchw";
addPermuteLayer(order, setName, inpId);
inpLayout = DATA_LAYOUT_NCHW; inpLayout = DATA_LAYOUT_NCHW;
} }

@ -337,6 +337,12 @@ TEST_P(Test_TensorFlow_layers, eltwise_mul_vec)
runTensorFlowNet("eltwise_mul_vec"); runTensorFlowNet("eltwise_mul_vec");
} }
TEST_P(Test_TensorFlow_layers, tf_reshape_nhwc)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
runTensorFlowNet("tf_reshape_nhwc");
}
TEST_P(Test_TensorFlow_layers, channel_broadcast) TEST_P(Test_TensorFlow_layers, channel_broadcast)
{ {

Loading…
Cancel
Save