mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Remove extra weights cloning from TensorFlow importer
This commit is contained in:
parent
77fa59c3da
commit
a6ed8f268a
@ -747,43 +747,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
|
||||
}
|
||||
}
|
||||
|
||||
Mat getTensorContent(const tensorflow::TensorProto &tensor)
|
||||
Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy)
|
||||
{
|
||||
const std::string& content = tensor.tensor_content();
|
||||
Mat m;
|
||||
switch (tensor.dtype())
|
||||
{
|
||||
case tensorflow::DT_FLOAT:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
|
||||
m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str());
|
||||
else
|
||||
{
|
||||
const RepeatedField<float>& field = tensor.float_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
|
||||
m = Mat(1, field.size(), CV_32FC1, (void*)field.data());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case tensorflow::DT_DOUBLE:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
|
||||
m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str());
|
||||
else
|
||||
{
|
||||
const RepeatedField<double>& field = tensor.double_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
|
||||
m = Mat(1, field.size(), CV_64FC1, (void*)field.data());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case tensorflow::DT_INT32:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
|
||||
m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str());
|
||||
else
|
||||
{
|
||||
const RepeatedField<int32_t>& field = tensor.int_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
|
||||
m = Mat(1, field.size(), CV_32SC1, (void*)field.data());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case tensorflow::DT_HALF:
|
||||
{
|
||||
@ -802,20 +806,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor)
|
||||
}
|
||||
// Reinterpret as a signed shorts just for a convertFp16 call.
|
||||
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
|
||||
Mat floats(halfs.size(), CV_32FC1);
|
||||
convertFp16(halfsSigned, floats);
|
||||
return floats;
|
||||
convertFp16(halfsSigned, m);
|
||||
break;
|
||||
}
|
||||
case tensorflow::DT_QUINT8:
|
||||
{
|
||||
CV_Assert(!content.empty());
|
||||
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
|
||||
m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CV_Error(Error::StsError, "Tensor's data type is not supported");
|
||||
break;
|
||||
}
|
||||
return Mat();
|
||||
return copy ? m.clone() : m;
|
||||
}
|
||||
|
||||
void releaseTensor(tensorflow::TensorProto* tensor)
|
||||
|
@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net);
|
||||
|
||||
void simplifySubgraphs(tensorflow::GraphDef& net);
|
||||
|
||||
Mat getTensorContent(const tensorflow::TensorProto &tensor);
|
||||
Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy = true);
|
||||
|
||||
void releaseTensor(tensorflow::TensorProto* tensor);
|
||||
|
||||
|
@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
|
||||
|
||||
dstBlob.create(shape, CV_32F);
|
||||
|
||||
Mat tensorContent = getTensorContent(tensor);
|
||||
Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
|
||||
int size = tensorContent.total();
|
||||
CV_Assert(size == (int)dstBlob.total());
|
||||
|
||||
@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
|
||||
|
||||
dstBlob.create(shape, CV_32F);
|
||||
|
||||
Mat tensorContent = getTensorContent(tensor);
|
||||
Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
|
||||
int size = tensorContent.total();
|
||||
CV_Assert(size == (int)dstBlob.total());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user