From 54f0616a13b9bde2c17d71096303ba3faab0f555 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 11 Aug 2017 16:23:41 +0300 Subject: [PATCH] Deconvolution layer from TensorFlow --- modules/dnn/src/tensorflow/tf_importer.cpp | 44 ++++++++++++++++++++++ modules/dnn/test/test_tf_importer.cpp | 5 +++ 2 files changed, 49 insertions(+) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 0f07f33d3d..6c5faa4caa 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -863,6 +863,50 @@ void TFImporter::populateNet(Net dstNet) // one input only connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); } + else if (type == "Conv2DBackpropInput") + { + // op: "Conv2DBackpropInput" + // input: "conv2d_transpose/output_shape" + // input: "weights" + // input: "input" + if (layer.input_size() != 3) + CV_Error(Error::StsNotImplemented, + "Expected output shape, weights and input nodes"); + + layerParams.set("bias_term", false); + layerParams.blobs.resize(1); + + StrIntVector next_layers = getNextLayers(net, name, "BiasAdd"); + if (next_layers.size() == 1) + { + layerParams.set("bias_term", true); + layerParams.blobs.resize(2); + + int weights_layer_index = next_layers[0].second; + + blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]); + ExcludeLayer(net, weights_layer_index, 0, false); + layers_to_ignore[weights_layer_index] = next_layers[0].first; + } + + kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]); + // Swap just numbers of input and output channels. + std::swap(layerParams.blobs[0].size[0], layerParams.blobs[0].size[1]); + + const int* kshape = layerParams.blobs[0].size.p; + layerParams.set("kernel_h", kshape[2]); + layerParams.set("kernel_w", kshape[3]); + layerParams.set("num_output", kshape[0]); + + setStrides(layerParams, layer); + setPadding(layerParams, layer); + + int id = dstNet.addLayer(name, "Deconvolution", layerParams); + layer_id[name] = id; + + // one input only + connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0); + } else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" || type == "Relu" || type == "Elu" || type == "Softmax" || type == "Identity") diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 8a6d495584..9df211cfbc 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -125,4 +125,9 @@ TEST(Test_TensorFlow, pooling) runTensorFlowNet("max_pool_odd_same"); } +TEST(Test_TensorFlow, deconvolution) +{ + runTensorFlowNet("deconvolution"); +} + }