Merge pull request #9349 from dkurt:tf_deconv

pull/9401/merge
Alexander Alekhin 8 years ago
commit 3202bbe17c
  1. 44
      modules/dnn/src/tensorflow/tf_importer.cpp
  2. 5
      modules/dnn/test/test_tf_importer.cpp

@ -863,6 +863,50 @@ void TFImporter::populateNet(Net dstNet)
// one input only // one input only
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
} }
else if (type == "Conv2DBackpropInput")
{
// op: "Conv2DBackpropInput"
// input: "conv2d_transpose/output_shape"
// input: "weights"
// input: "input"
if (layer.input_size() != 3)
CV_Error(Error::StsNotImplemented,
"Expected output shape, weights and input nodes");
layerParams.set("bias_term", false);
layerParams.blobs.resize(1);
StrIntVector next_layers = getNextLayers(net, name, "BiasAdd");
if (next_layers.size() == 1)
{
layerParams.set("bias_term", true);
layerParams.blobs.resize(2);
int weights_layer_index = next_layers[0].second;
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
ExcludeLayer(net, weights_layer_index, 0, false);
layers_to_ignore[weights_layer_index] = next_layers[0].first;
}
kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]);
// Swap just numbers of input and output channels.
std::swap(layerParams.blobs[0].size[0], layerParams.blobs[0].size[1]);
const int* kshape = layerParams.blobs[0].size.p;
layerParams.set("kernel_h", kshape[2]);
layerParams.set("kernel_w", kshape[3]);
layerParams.set("num_output", kshape[0]);
setStrides(layerParams, layer);
setPadding(layerParams, layer);
int id = dstNet.addLayer(name, "Deconvolution", layerParams);
layer_id[name] = id;
// one input only
connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0);
}
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" || else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
type == "Relu" || type == "Elu" || type == "Softmax" || type == "Relu" || type == "Elu" || type == "Softmax" ||
type == "Identity") type == "Identity")

@ -125,4 +125,9 @@ TEST(Test_TensorFlow, pooling)
runTensorFlowNet("max_pool_odd_same"); runTensorFlowNet("max_pool_odd_same");
} }
TEST(Test_TensorFlow, deconvolution)
{
runTensorFlowNet("deconvolution");
}
} }

Loading…
Cancel
Save