Remove extra weights cloning from TensorFlow importer

pull/14460/head
Dmitry Kurtaev 6 years ago
parent 77fa59c3da
commit a6ed8f268a
  1. 28
      modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
  2. 2
      modules/dnn/src/tensorflow/tf_graph_simplifier.hpp
  3. 4
      modules/dnn/src/tensorflow/tf_importer.cpp

@ -747,43 +747,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
} }
} }
Mat getTensorContent(const tensorflow::TensorProto &tensor) Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy)
{ {
const std::string& content = tensor.tensor_content(); const std::string& content = tensor.tensor_content();
Mat m;
switch (tensor.dtype()) switch (tensor.dtype())
{ {
case tensorflow::DT_FLOAT: case tensorflow::DT_FLOAT:
{ {
if (!content.empty()) if (!content.empty())
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str());
else else
{ {
const RepeatedField<float>& field = tensor.float_val(); const RepeatedField<float>& field = tensor.float_val();
CV_Assert(!field.empty()); CV_Assert(!field.empty());
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); m = Mat(1, field.size(), CV_32FC1, (void*)field.data());
} }
break;
} }
case tensorflow::DT_DOUBLE: case tensorflow::DT_DOUBLE:
{ {
if (!content.empty()) if (!content.empty())
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str());
else else
{ {
const RepeatedField<double>& field = tensor.double_val(); const RepeatedField<double>& field = tensor.double_val();
CV_Assert(!field.empty()); CV_Assert(!field.empty());
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); m = Mat(1, field.size(), CV_64FC1, (void*)field.data());
} }
break;
} }
case tensorflow::DT_INT32: case tensorflow::DT_INT32:
{ {
if (!content.empty()) if (!content.empty())
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str());
else else
{ {
const RepeatedField<int32_t>& field = tensor.int_val(); const RepeatedField<int32_t>& field = tensor.int_val();
CV_Assert(!field.empty()); CV_Assert(!field.empty());
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); m = Mat(1, field.size(), CV_32SC1, (void*)field.data());
} }
break;
} }
case tensorflow::DT_HALF: case tensorflow::DT_HALF:
{ {
@ -802,20 +806,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor)
} }
// Reinterpret as a signed shorts just for a convertFp16 call. // Reinterpret as a signed shorts just for a convertFp16 call.
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
Mat floats(halfs.size(), CV_32FC1); convertFp16(halfsSigned, m);
convertFp16(halfsSigned, floats); break;
return floats;
} }
case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT8:
{ {
CV_Assert(!content.empty()); CV_Assert(!content.empty());
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str());
break;
} }
default: default:
CV_Error(Error::StsError, "Tensor's data type is not supported"); CV_Error(Error::StsError, "Tensor's data type is not supported");
break; break;
} }
return Mat(); return copy ? m.clone() : m;
} }
void releaseTensor(tensorflow::TensorProto* tensor) void releaseTensor(tensorflow::TensorProto* tensor)

@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net);
void simplifySubgraphs(tensorflow::GraphDef& net); void simplifySubgraphs(tensorflow::GraphDef& net);
Mat getTensorContent(const tensorflow::TensorProto &tensor); Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy = true);
void releaseTensor(tensorflow::TensorProto* tensor); void releaseTensor(tensorflow::TensorProto* tensor);

@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
dstBlob.create(shape, CV_32F); dstBlob.create(shape, CV_32F);
Mat tensorContent = getTensorContent(tensor); Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
int size = tensorContent.total(); int size = tensorContent.total();
CV_Assert(size == (int)dstBlob.total()); CV_Assert(size == (int)dstBlob.total());
@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
dstBlob.create(shape, CV_32F); dstBlob.create(shape, CV_32F);
Mat tensorContent = getTensorContent(tensor); Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
int size = tensorContent.total(); int size = tensorContent.total();
CV_Assert(size == (int)dstBlob.total()); CV_Assert(size == (int)dstBlob.total());

Loading…
Cancel
Save