Read raw floats data from Caffe models

This commit is contained in:
Dmitry Kurtaev 2019-02-11 20:08:17 +03:00
parent 5f57842575
commit 1606137df2

View File

@ -259,16 +259,25 @@ public:
Mat(dstBlob.dims, &dstBlob.size[0], CV_32F, (void*)pbBlob.data().data()).copyTo(dstBlob); Mat(dstBlob.dims, &dstBlob.size[0], CV_32F, (void*)pbBlob.data().data()).copyTo(dstBlob);
} }
else else
{
CV_Assert(pbBlob.has_raw_data());
const std::string& raw_data = pbBlob.raw_data();
if (pbBlob.raw_data_type() == caffe::FLOAT16)
{ {
// Half precision floats. // Half precision floats.
CV_Assert(pbBlob.raw_data_type() == caffe::FLOAT16);
std::string raw_data = pbBlob.raw_data();
CV_Assert(raw_data.size() / 2 == (int)dstBlob.total()); CV_Assert(raw_data.size() / 2 == (int)dstBlob.total());
Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str()); Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str());
convertFp16(halfs, dstBlob); convertFp16(halfs, dstBlob);
} }
else if (pbBlob.raw_data_type() == caffe::FLOAT)
{
CV_Assert(raw_data.size() / 4 == (int)dstBlob.total());
Mat((int)shape.size(), &shape[0], CV_32FC1, (void*)raw_data.c_str()).copyTo(dstBlob);
}
else
CV_Error(Error::StsNotImplemented, "Unexpected blob data type");
}
} }
void extractBinaryLayerParams(const caffe::LayerParameter& layer, LayerParams& layerParams) void extractBinaryLayerParams(const caffe::LayerParameter& layer, LayerParams& layerParams)