Fix dilated convolution from Keras

pull/11614/head
Dmitry Kurtaev 7 years ago
parent 06c1890639
commit 085be6a445
  1. 9
      modules/dnn/src/tensorflow/tf_importer.cpp
  2. 1
      modules/dnn/test/test_tf_importer.cpp

@ -644,8 +644,9 @@ void TFImporter::populateNet(Net dstNet)
CV_Assert(layer.input_size() == 3);
DictValue dilation = parseDims(getConstBlob(layer, value_id, 1));
CV_Assert(dilation.size() == 2 && dilation.get<int>(0) == dilation.get<int>(1));
layerParams.set("dilation", dilation.get<int>(0));
CV_Assert(dilation.size() == 2);
layerParams.set("dilation_h", dilation.get<int>(0));
layerParams.set("dilation_w", dilation.get<int>(1));
Mat paddings;
parseTensor<int>(getConstBlob(layer, value_id, 2), paddings);
@ -655,6 +656,10 @@ void TFImporter::populateNet(Net dstNet)
layerParams.set("pad_w", paddings.at<float>(2));
StrIntVector next_layers = getNextLayers(net, name, "Conv2D");
if (next_layers.empty())
{
next_layers = getNextLayers(net, name, "DepthwiseConv2dNative");
}
CV_Assert(next_layers.size() == 1);
layer = net.node(next_layers[0].second);
layers_to_ignore.insert(next_layers[0].first);

@ -124,6 +124,7 @@ TEST_P(Test_TensorFlow_layers, conv)
runTensorFlowNet("atrous_conv2d_valid", targetId);
runTensorFlowNet("atrous_conv2d_same", targetId);
runTensorFlowNet("depthwise_conv2d", targetId);
runTensorFlowNet("keras_atrous_conv2d_same", targetId);
}
TEST_P(Test_TensorFlow_layers, padding)

Loading…
Cancel
Save