From a6ed8f268a80be33aa0de10a7e2812536bcab0f9 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Tue, 30 Apr 2019 19:18:41 +0300 Subject: [PATCH] Remove extra weights cloning from TensorFlow importer --- .../src/tensorflow/tf_graph_simplifier.cpp | 28 +++++++++++-------- .../src/tensorflow/tf_graph_simplifier.hpp | 2 +- modules/dnn/src/tensorflow/tf_importer.cpp | 4 +-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 7f1001888a..a40da7e174 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -747,43 +747,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) } } -Mat getTensorContent(const tensorflow::TensorProto &tensor) +Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy) { const std::string& content = tensor.tensor_content(); + Mat m; switch (tensor.dtype()) { case tensorflow::DT_FLOAT: { if (!content.empty()) - return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()); else { const RepeatedField& field = tensor.float_val(); CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); + m = Mat(1, field.size(), CV_32FC1, (void*)field.data()); } + break; } case tensorflow::DT_DOUBLE: { if (!content.empty()) - return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()); else { const RepeatedField& field = tensor.double_val(); CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); + m = Mat(1, field.size(), CV_64FC1, (void*)field.data()); } + break; } case tensorflow::DT_INT32: { if (!content.empty()) - return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()); else { const RepeatedField& field = tensor.int_val(); CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); + m = Mat(1, field.size(), CV_32SC1, (void*)field.data()); } + break; } case tensorflow::DT_HALF: { @@ -802,20 +806,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor) } // Reinterpret as a signed shorts just for a convertFp16 call. Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); - Mat floats(halfs.size(), CV_32FC1); - convertFp16(halfsSigned, floats); - return floats; + convertFp16(halfsSigned, m); + break; } case tensorflow::DT_QUINT8: { CV_Assert(!content.empty()); - return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str()); + break; } default: CV_Error(Error::StsError, "Tensor's data type is not supported"); break; } - return Mat(); + return copy ? m.clone() : m; } void releaseTensor(tensorflow::TensorProto* tensor) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp index 5929d1f857..55f36cdb44 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp @@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net); void simplifySubgraphs(tensorflow::GraphDef& net); -Mat getTensorContent(const tensorflow::TensorProto &tensor); +Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy = true); void releaseTensor(tensorflow::TensorProto* tensor); diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index a7a681c140..8acaaf8f7a 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob) dstBlob.create(shape, CV_32F); - Mat tensorContent = getTensorContent(tensor); + Mat tensorContent = getTensorContent(tensor, /*no copy*/false); int size = tensorContent.total(); CV_Assert(size == (int)dstBlob.total()); @@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds dstBlob.create(shape, CV_32F); - Mat tensorContent = getTensorContent(tensor); + Mat tensorContent = getTensorContent(tensor, /*no copy*/false); int size = tensorContent.total(); CV_Assert(size == (int)dstBlob.total());