|
|
|
@ -770,43 +770,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Mat getTensorContent(const tensorflow::TensorProto &tensor) |
|
|
|
|
Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy) |
|
|
|
|
{ |
|
|
|
|
const std::string& content = tensor.tensor_content(); |
|
|
|
|
Mat m; |
|
|
|
|
switch (tensor.dtype()) |
|
|
|
|
{ |
|
|
|
|
case tensorflow::DT_FLOAT: |
|
|
|
|
{ |
|
|
|
|
if (!content.empty()) |
|
|
|
|
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); |
|
|
|
|
m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()); |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
const RepeatedField<float>& field = tensor.float_val(); |
|
|
|
|
CV_Assert(!field.empty()); |
|
|
|
|
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); |
|
|
|
|
m = Mat(1, field.size(), CV_32FC1, (void*)field.data()); |
|
|
|
|
} |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
case tensorflow::DT_DOUBLE: |
|
|
|
|
{ |
|
|
|
|
if (!content.empty()) |
|
|
|
|
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); |
|
|
|
|
m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()); |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
const RepeatedField<double>& field = tensor.double_val(); |
|
|
|
|
CV_Assert(!field.empty()); |
|
|
|
|
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); |
|
|
|
|
m = Mat(1, field.size(), CV_64FC1, (void*)field.data()); |
|
|
|
|
} |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
case tensorflow::DT_INT32: |
|
|
|
|
{ |
|
|
|
|
if (!content.empty()) |
|
|
|
|
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); |
|
|
|
|
m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()); |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
const RepeatedField<int32_t>& field = tensor.int_val(); |
|
|
|
|
CV_Assert(!field.empty()); |
|
|
|
|
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); |
|
|
|
|
m = Mat(1, field.size(), CV_32SC1, (void*)field.data()); |
|
|
|
|
} |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
case tensorflow::DT_HALF: |
|
|
|
|
{ |
|
|
|
@ -825,20 +829,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor) |
|
|
|
|
} |
|
|
|
|
// Reinterpret as a signed shorts just for a convertFp16 call.
|
|
|
|
|
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); |
|
|
|
|
Mat floats(halfs.size(), CV_32FC1); |
|
|
|
|
convertFp16(halfsSigned, floats); |
|
|
|
|
return floats; |
|
|
|
|
convertFp16(halfsSigned, m); |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
case tensorflow::DT_QUINT8: |
|
|
|
|
{ |
|
|
|
|
CV_Assert(!content.empty()); |
|
|
|
|
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); |
|
|
|
|
m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str()); |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
default: |
|
|
|
|
CV_Error(Error::StsError, "Tensor's data type is not supported"); |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
return Mat(); |
|
|
|
|
return copy ? m.clone() : m; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void releaseTensor(tensorflow::TensorProto* tensor) |
|
|
|
|