Remove extra weights cloning from TensorFlow importer

This commit is contained in:
Dmitry Kurtaev 2019-04-30 19:18:41 +03:00
parent 77fa59c3da
commit a6ed8f268a
3 changed files with 19 additions and 15 deletions

View File

@ -747,43 +747,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
}
}
Mat getTensorContent(const tensorflow::TensorProto &tensor)
Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy)
{
const std::string& content = tensor.tensor_content();
Mat m;
switch (tensor.dtype())
{
case tensorflow::DT_FLOAT:
{
if (!content.empty())
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str());
else
{
const RepeatedField<float>& field = tensor.float_val();
CV_Assert(!field.empty());
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
m = Mat(1, field.size(), CV_32FC1, (void*)field.data());
}
break;
}
case tensorflow::DT_DOUBLE:
{
if (!content.empty())
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str());
else
{
const RepeatedField<double>& field = tensor.double_val();
CV_Assert(!field.empty());
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
m = Mat(1, field.size(), CV_64FC1, (void*)field.data());
}
break;
}
case tensorflow::DT_INT32:
{
if (!content.empty())
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str());
else
{
const RepeatedField<int32_t>& field = tensor.int_val();
CV_Assert(!field.empty());
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
m = Mat(1, field.size(), CV_32SC1, (void*)field.data());
}
break;
}
case tensorflow::DT_HALF:
{
@ -802,20 +806,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor)
}
// Reinterpret as a signed shorts just for a convertFp16 call.
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
Mat floats(halfs.size(), CV_32FC1);
convertFp16(halfsSigned, floats);
return floats;
convertFp16(halfsSigned, m);
break;
}
case tensorflow::DT_QUINT8:
{
CV_Assert(!content.empty());
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str());
break;
}
default:
CV_Error(Error::StsError, "Tensor's data type is not supported");
break;
}
return Mat();
return copy ? m.clone() : m;
}
void releaseTensor(tensorflow::TensorProto* tensor)

View File

@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net);
void simplifySubgraphs(tensorflow::GraphDef& net);
Mat getTensorContent(const tensorflow::TensorProto &tensor);
Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy = true);
void releaseTensor(tensorflow::TensorProto* tensor);

View File

@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
dstBlob.create(shape, CV_32F);
Mat tensorContent = getTensorContent(tensor);
Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
int size = tensorContent.total();
CV_Assert(size == (int)dstBlob.total());
@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
dstBlob.create(shape, CV_32F);
Mat tensorContent = getTensorContent(tensor);
Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
int size = tensorContent.total();
CV_Assert(size == (int)dstBlob.total());