Merge pull request #9994 from r2d3:dnn_memory_load
Commit f37f4cf3b4
@@ -644,11 +644,33 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
      */
     CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
 
+    /** @brief Reads a network model stored in Caffe model in memory.
+      * @details This is an overloaded member function, provided for convenience.
+      * It differs from the above function only in what argument(s) it accepts.
+      * @param bufferProto buffer containing the content of the .prototxt file
+      * @param lenProto length of bufferProto
+      * @param bufferModel buffer containing the content of the .caffemodel file
+      * @param lenModel length of bufferModel
+      */
+    CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
+                                    const char *bufferModel = NULL, size_t lenModel = 0);
+
     /** @brief Reads a network model stored in Tensorflow model file.
       * @details This is a shortcut consisting of createTensorflowImporter and Net::populateNet calls.
       */
     CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
 
+    /** @brief Reads a network model stored in Tensorflow model in memory.
+      * @details This is an overloaded member function, provided for convenience.
+      * It differs from the above function only in what argument(s) it accepts.
+      * @param bufferModel buffer containing the content of the pb file
+      * @param lenModel length of bufferModel
+      * @param bufferConfig buffer containing the content of the pbtxt file
+      * @param lenConfig length of bufferConfig
+      */
+    CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
+                                         const char *bufferConfig = NULL, size_t lenConfig = 0);
+
     /** @brief Reads a network model stored in Torch model file.
       * @details This is a shortcut consisting of createTorchImporter and Net::populateNet calls.
       */
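A caller-side sketch of the new Caffe overload (the slurp helper, the file paths, and the error handling are illustrative assumptions, not part of this patch):

#include <fstream>
#include <iterator>
#include <string>
#include <opencv2/dnn.hpp>

// Hypothetical helper: read a whole file into a string buffer.
static std::string slurp(const std::string& path)
{
    std::ifstream f(path.c_str(), std::ios::in | std::ios::binary);
    return std::string((std::istreambuf_iterator<char>(f)),
                       std::istreambuf_iterator<char>());
}

int main()
{
    std::string proto = slurp("bvlc_googlenet.prototxt");   // placeholder path
    std::string model = slurp("bvlc_googlenet.caffemodel"); // placeholder path

    // New overload: parse the network directly from memory buffers.
    cv::dnn::Net net = cv::dnn::readNetFromCaffe(proto.c_str(), proto.size(),
                                                 model.c_str(), model.size());
    return net.empty() ? 1 : 0;
}

The same buffers could come from any source, for example a resource embedded in the binary or bytes received over the network, which is the kind of use case an in-memory loader serves.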
@@ -92,6 +92,17 @@ public:
         ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
     }
 
+    CaffeImporter(const char *dataProto, size_t lenProto,
+                  const char *dataModel, size_t lenModel)
+    {
+        CV_TRACE_FUNCTION();
+
+        ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);
+
+        if (dataModel != NULL && lenModel > 0)
+            ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);
+    }
+
     void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
     {
         const Reflection *refl = msg.GetReflection();
@@ -398,6 +409,15 @@ Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String
     return net;
 }
 
+Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
+                     const char *bufferModel, size_t lenModel)
+{
+    CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);
+    Net net;
+    caffeImporter.populateNet(net);
+    return net;
+}
+
 #endif //HAVE_PROTOBUF
 
 CV__DNN_EXPERIMENTAL_NS_END
@@ -1107,28 +1107,37 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
 
 const int kProtoReadBytesLimit = INT_MAX;  // Max size of 2 GB minus 1 byte.
 
+bool ReadProtoFromBinary(ZeroCopyInputStream* input, Message *proto) {
+    CodedInputStream coded_input(input);
+    coded_input.SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
+
+    return proto->ParseFromCodedStream(&coded_input);
+}
+
 bool ReadProtoFromTextFile(const char* filename, Message* proto) {
     std::ifstream fs(filename, std::ifstream::in);
     CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
     IstreamInputStream input(&fs);
-    bool success = google::protobuf::TextFormat::Parse(&input, proto);
-    fs.close();
-    return success;
+    return google::protobuf::TextFormat::Parse(&input, proto);
 }
 
 bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
     std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
     CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
-    ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
-    CodedInputStream* coded_input = new CodedInputStream(raw_input);
-    coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
-
-    bool success = proto->ParseFromCodedStream(coded_input);
-
-    delete coded_input;
-    delete raw_input;
-    fs.close();
-    return success;
+    IstreamInputStream raw_input(&fs);
+
+    return ReadProtoFromBinary(&raw_input, proto);
+}
+
+bool ReadProtoFromTextBuffer(const char* data, size_t len, Message* proto) {
+    ArrayInputStream input(data, len);
+    return google::protobuf::TextFormat::Parse(&input, proto);
+}
+
+bool ReadProtoFromBinaryBuffer(const char* data, size_t len, Message* proto) {
+    ArrayInputStream raw_input(data, len);
+    return ReadProtoFromBinary(&raw_input, proto);
 }
 
 void ReadNetParamsFromTextFileOrDie(const char* param_file,
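The new ReadProtoFrom*Buffer helpers follow the standard protobuf recipe for parsing out of a raw byte range: wrap the bytes in an ArrayInputStream and parse through a CodedInputStream with a raised size limit. A standalone sketch of that recipe, with the message type left generic (an illustration, not code from this patch):

#include <climits>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/message.h>

// Parse any protobuf message from an in-memory buffer. The 2 GB limit mirrors
// kProtoReadBytesLimit above; the second argument is a warning threshold
// (older protobuf releases take both arguments, newer ones only the limit).
static bool parseFromBuffer(const char* data, size_t len,
                            google::protobuf::Message* msg)
{
    google::protobuf::io::ArrayInputStream raw(data, static_cast<int>(len));
    google::protobuf::io::CodedInputStream coded(&raw);
    coded.SetTotalBytesLimit(INT_MAX, 536870912);
    return msg->ParseFromCodedStream(&coded);
}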
@@ -1138,6 +1147,13 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
     UpgradeNetAsNeeded(param_file, param);
 }
 
+void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
+                                      NetParameter* param) {
+    CHECK(ReadProtoFromTextBuffer(data, len, param))
+        << "Failed to parse NetParameter buffer";
+    UpgradeNetAsNeeded("memory buffer", param);
+}
+
 void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
                                       NetParameter* param) {
     CHECK(ReadProtoFromBinaryFile(param_file, param))
@@ -1145,6 +1161,13 @@ void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
     UpgradeNetAsNeeded(param_file, param);
 }
 
+void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
+                                        NetParameter* param) {
+    CHECK(ReadProtoFromBinaryBuffer(data, len, param))
+        << "Failed to parse NetParameter buffer";
+    UpgradeNetAsNeeded("memory buffer", param);
+}
+
 }
 }
 #endif
@@ -102,6 +102,18 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
 void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
                                       caffe::NetParameter* param);
 
+// Read parameters from a memory buffer into a NetParameter proto message.
+void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
+                                        caffe::NetParameter* param);
+void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
+                                      caffe::NetParameter* param);
+
+// Utility functions used internally by the Caffe and TensorFlow loaders.
+bool ReadProtoFromTextFile(const char* filename, ::google::protobuf::Message* proto);
+bool ReadProtoFromBinaryFile(const char* filename, ::google::protobuf::Message* proto);
+bool ReadProtoFromTextBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
+bool ReadProtoFromBinaryBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
+
 }
 }
 #endif
@@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
 class TFImporter : public Importer {
 public:
     TFImporter(const char *model, const char *config = NULL);
+    TFImporter(const char *dataModel, size_t lenModel,
+               const char *dataConfig = NULL, size_t lenConfig = 0);
+
     void populateNet(Net dstNet);
     ~TFImporter() {}
@@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
         ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
 }
 
+TFImporter::TFImporter(const char *dataModel, size_t lenModel,
+                       const char *dataConfig, size_t lenConfig)
+{
+    if (dataModel != NULL && lenModel > 0)
+        ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
+    if (dataConfig != NULL && lenConfig > 0)
+        ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
+}
+
 void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
 {
     MatShape shape;
@@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
     return net;
 }
 
+Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
+                          const char* bufferConfig, size_t lenConfig)
+{
+    TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
+    Net net;
+    importer.populateNet(net);
+    return net;
+}
+
 CV__DNN_EXPERIMENTAL_NS_END
 }} // namespace
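As with the Caffe path, a caller-side sketch of the new TensorFlow overload (the path and the surrounding program are illustrative assumptions, not part of this patch):

#include <fstream>
#include <iterator>
#include <string>
#include <opencv2/dnn.hpp>

int main()
{
    // Placeholder path: any frozen TensorFlow graph in .pb format would do.
    std::ifstream f("frozen_graph.pb", std::ios::in | std::ios::binary);
    std::string pb((std::istreambuf_iterator<char>(f)),
                   std::istreambuf_iterator<char>());

    // The config buffer stays optional, mirroring the file-based overload.
    cv::dnn::Net net = cv::dnn::readNetFromTensorflow(pb.c_str(), pb.size());
    return net.empty() ? 1 : 0;
}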
@@ -23,6 +23,7 @@ Implementation of various functions which are related to Tensorflow models readi
 
 #include "graph.pb.h"
 #include "tf_io.hpp"
+#include "../caffe/caffe_io.hpp"
 #include "../caffe/glog_emulator.hpp"
 
 namespace cv {
@@ -36,41 +37,28 @@ using namespace ::google::protobuf::io;
 
 const int kProtoReadBytesLimit = INT_MAX;  // Max size of 2 GB minus 1 byte.
 
-// TODO: remove Caffe duplicate
-bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
-    std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
-    CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
-    ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
-    CodedInputStream* coded_input = new CodedInputStream(raw_input);
-    coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
-
-    bool success = proto->ParseFromCodedStream(coded_input);
-
-    delete coded_input;
-    delete raw_input;
-    fs.close();
-    return success;
-}
-
-bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
-    std::ifstream fs(filename, std::ifstream::in);
-    CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
-    IstreamInputStream input(&fs);
-    bool success = google::protobuf::TextFormat::Parse(&input, proto);
-    fs.close();
-    return success;
-}
-
 void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
                                         tensorflow::GraphDef* param) {
-    CHECK(ReadProtoFromBinaryFileTF(param_file, param))
+    CHECK(ReadProtoFromBinaryFile(param_file, param))
         << "Failed to parse GraphDef file: " << param_file;
 }
 
+void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
+                                          tensorflow::GraphDef* param) {
+    CHECK(ReadProtoFromBinaryBuffer(data, len, param))
+        << "Failed to parse GraphDef buffer";
+}
+
 void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
                                       tensorflow::GraphDef* param) {
-    CHECK(ReadProtoFromTextFileTF(param_file, param))
+    CHECK(ReadProtoFromTextFile(param_file, param))
        << "Failed to parse GraphDef file: " << param_file;
 }
 
+void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
+                                        tensorflow::GraphDef* param) {
+    CHECK(ReadProtoFromTextBuffer(data, len, param))
+        << "Failed to parse GraphDef buffer";
+}
+
 }
 }
@@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
 void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
                                       tensorflow::GraphDef* param);
 
+// Read parameters from a memory buffer into a GraphDef proto message.
+void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
+                                          tensorflow::GraphDef* param);
+
+void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
+                                        tensorflow::GraphDef* param);
+
 }
 }
 
@@ -55,6 +55,24 @@ static std::string _tf(TString filename)
     return (getOpenCVExtraDir() + "/dnn/") + filename;
 }
 
+TEST(Test_Caffe, memory_read)
+{
+    const string proto = findDataFile("dnn/bvlc_googlenet.prototxt", false);
+    const string model = findDataFile("dnn/bvlc_googlenet.caffemodel", false);
+
+    string dataProto;
+    ASSERT_TRUE(readFileInMemory(proto, dataProto));
+    string dataModel;
+    ASSERT_TRUE(readFileInMemory(model, dataModel));
+
+    Net net = readNetFromCaffe(dataProto.c_str(), dataProto.size());
+    ASSERT_FALSE(net.empty());
+
+    Net net2 = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
+                                dataModel.c_str(), dataModel.size());
+    ASSERT_FALSE(net2.empty());
+}
+
 TEST(Test_Caffe, read_gtsrb)
 {
     Net net = readNetFromCaffe(_tf("gtsrb.prototxt"));
@@ -67,13 +85,26 @@ TEST(Test_Caffe, read_googlenet)
     ASSERT_FALSE(net.empty());
 }
 
-TEST(Reproducibility_AlexNet, Accuracy)
+typedef testing::TestWithParam<tuple<bool> > Reproducibility_AlexNet;
+TEST_P(Reproducibility_AlexNet, Accuracy)
 {
+    bool readFromMemory = get<0>(GetParam());
     Net net;
     {
         const string proto = findDataFile("dnn/bvlc_alexnet.prototxt", false);
         const string model = findDataFile("dnn/bvlc_alexnet.caffemodel", false);
-        net = readNetFromCaffe(proto, model);
+        if (readFromMemory)
+        {
+            string dataProto;
+            ASSERT_TRUE(readFileInMemory(proto, dataProto));
+            string dataModel;
+            ASSERT_TRUE(readFileInMemory(model, dataModel));
+
+            net = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
+                                   dataModel.c_str(), dataModel.size());
+        }
+        else
+            net = readNetFromCaffe(proto, model);
         ASSERT_FALSE(net.empty());
     }
 
@@ -86,6 +117,8 @@ TEST(Reproducibility_AlexNet, Accuracy)
     normAssert(ref, out);
 }
 
+INSTANTIATE_TEST_CASE_P(Test_Caffe, Reproducibility_AlexNet, testing::Values(true, false));
+
 #if !defined(_WIN32) || defined(_WIN64)
 TEST(Reproducibility_FCN, Accuracy)
 {
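The AlexNet check is now a value-parameterized googletest, so the same body runs once reading from files and once reading from memory. A minimal, self-contained sketch of that mechanism (the names are placeholders, independent of the OpenCV test suite):

#include <gtest/gtest.h>

// The fixture receives one bool per instantiated run.
class FromMemory : public testing::TestWithParam<bool> {};

TEST_P(FromMemory, Works)
{
    bool readFromMemory = GetParam();   // true on one run, false on the other
    EXPECT_TRUE(readFromMemory || !readFromMemory);
}

// Generates FromMemory.Works/0 and FromMemory.Works/1.
INSTANTIATE_TEST_CASE_P(Example, FromMemory, testing::Values(true, false));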
@@ -57,4 +57,23 @@ inline void normAssert(cv::InputArray ref, cv::InputArray test, const char *comm
     EXPECT_LE(normInf, lInf) << comment;
 }
 
+inline bool readFileInMemory(const std::string& filename, std::string& content)
+{
+    std::ios::openmode mode = std::ios::in | std::ios::binary;
+    std::ifstream ifs(filename.c_str(), mode);
+    if (!ifs.is_open())
+        return false;
+
+    content.clear();
+
+    ifs.seekg(0, std::ios::end);
+    content.reserve(ifs.tellg());
+    ifs.seekg(0, std::ios::beg);
+
+    content.assign((std::istreambuf_iterator<char>(ifs)),
+                   std::istreambuf_iterator<char>());
+
+    return true;
+}
+
 #endif
@@ -75,14 +75,32 @@ static std::string path(const std::string& file)
 }
 
 static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
-                             double l1 = 1e-5, double lInf = 1e-4)
+                             double l1 = 1e-5, double lInf = 1e-4,
+                             bool memoryLoad = false)
 {
     std::string netPath = path(prefix + "_net.pb");
     std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
     std::string inpPath = path(prefix + "_in.npy");
    std::string outPath = path(prefix + "_out.npy");
 
-    Net net = readNetFromTensorflow(netPath, netConfig);
+    Net net;
+    if (memoryLoad)
+    {
+        // Load the files into memory buffers
+        string dataModel;
+        ASSERT_TRUE(readFileInMemory(netPath, dataModel));
+
+        string dataConfig;
+        if (hasText)
+            ASSERT_TRUE(readFileInMemory(netConfig, dataConfig));
+
+        net = readNetFromTensorflow(dataModel.c_str(), dataModel.size(),
+                                    dataConfig.c_str(), dataConfig.size());
+    }
+    else
+        net = readNetFromTensorflow(netPath, netConfig);
+
+    ASSERT_FALSE(net.empty());
 
     cv::Mat input = blobFromNPY(inpPath);
     cv::Mat target = blobFromNPY(outPath);
@@ -216,4 +234,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
     runTensorFlowNet("resize_nearest_neighbor");
 }
 
+TEST(Test_TensorFlow, memory_read)
+{
+    double l1 = 1e-5;
+    double lInf = 1e-4;
+    runTensorFlowNet("lstm", true, l1, lInf, true);
+
+    runTensorFlowNet("batch_norm", false, l1, lInf, true);
+    runTensorFlowNet("fused_batch_norm", false, l1, lInf, true);
+    runTensorFlowNet("batch_norm_text", true, l1, lInf, true);
+}
+
 }