diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 086f0ae637..e56073788b 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -770,43 +770,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) } } -Mat getTensorContent(const tensorflow::TensorProto &tensor) +Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy) { const std::string& content = tensor.tensor_content(); + Mat m; switch (tensor.dtype()) { case tensorflow::DT_FLOAT: { if (!content.empty()) - return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()); else { const RepeatedField& field = tensor.float_val(); CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); + m = Mat(1, field.size(), CV_32FC1, (void*)field.data()); } + break; } case tensorflow::DT_DOUBLE: { if (!content.empty()) - return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()); else { const RepeatedField& field = tensor.double_val(); CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); + m = Mat(1, field.size(), CV_64FC1, (void*)field.data()); } + break; } case tensorflow::DT_INT32: { if (!content.empty()) - return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()); else { const RepeatedField& field = tensor.int_val(); CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); + m = Mat(1, field.size(), CV_32SC1, (void*)field.data()); } + break; } case tensorflow::DT_HALF: { @@ -825,20 +829,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor) } // Reinterpret as a signed shorts just for a convertFp16 call. Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); - Mat floats(halfs.size(), CV_32FC1); - convertFp16(halfsSigned, floats); - return floats; + convertFp16(halfsSigned, m); + break; } case tensorflow::DT_QUINT8: { CV_Assert(!content.empty()); - return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); + m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str()); + break; } default: CV_Error(Error::StsError, "Tensor's data type is not supported"); break; } - return Mat(); + return copy ? m.clone() : m; } void releaseTensor(tensorflow::TensorProto* tensor) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp index 5929d1f857..55f36cdb44 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp @@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net); void simplifySubgraphs(tensorflow::GraphDef& net); -Mat getTensorContent(const tensorflow::TensorProto &tensor); +Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy = true); void releaseTensor(tensorflow::TensorProto* tensor); diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index b802c4e131..058dfe0403 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob) dstBlob.create(shape, CV_32F); - Mat tensorContent = getTensorContent(tensor); + Mat tensorContent = getTensorContent(tensor, /*no copy*/false); int size = tensorContent.total(); CV_Assert(size == (int)dstBlob.total()); @@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds dstBlob.create(shape, CV_32F); - Mat tensorContent = getTensorContent(tensor); + Mat tensorContent = getTensorContent(tensor, /*no copy*/false); int size = tensorContent.total(); CV_Assert(size == (int)dstBlob.total());