From 38a49f92aba0212366e8ffe72458a5f6b29e75a7 Mon Sep 17 00:00:00 2001 From: Liubov Batanina Date: Fri, 22 Jan 2021 16:47:02 +0300 Subject: [PATCH] Added shared weights for MatMul --- modules/dnn/src/tensorflow/tf_importer.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 2484c8493c..057dc11289 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1228,8 +1228,18 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_) int kernel_blob_index = -1; const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index); - blobFromTensor(kernelTensor, layerParams.blobs[0]); - releaseTensor(const_cast(&kernelTensor)); + const String kernelTensorName = layer.input(kernel_blob_index); + std::map::iterator sharedWeightsIt = sharedWeights.find(kernelTensorName); + if (sharedWeightsIt == sharedWeights.end()) + { + blobFromTensor(kernelTensor, layerParams.blobs[0]); + releaseTensor(const_cast(&kernelTensor)); + sharedWeights[kernelTensorName] = layerParams.blobs[0]; + } + else + { + layerParams.blobs[0] = sharedWeightsIt->second; + } if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed Mat data = layerParams.blobs[0].t();