mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 21:20:18 +08:00
add loading TensorFlow/Caffe net from memory buffer
add a corresponding test
This commit is contained in:
parent
6e4f9433d0
commit
f723cede2e
@ -634,11 +634,33 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
|
||||
|
||||
/** @brief Reads a network model stored in Caffe model in memory.
|
||||
* @details This is an overloaded member function, provided for convenience.
|
||||
* It differs from the above function only in what argument(s) it accepts.
|
||||
* @param bufferProto buffer containing the content of the .prototxt file
|
||||
* @param lenProto length of bufferProto
|
||||
* @param bufferModel buffer containing the content of the .caffemodel file
|
||||
* @param lenModel length of bufferModel
|
||||
*/
|
||||
CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
|
||||
const char *bufferModel = NULL, size_t lenModel = 0);
|
||||
|
||||
/** @brief Reads a network model stored in Tensorflow model file.
|
||||
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
|
||||
|
||||
/** @brief Reads a network model stored in Tensorflow model in memory.
|
||||
* @details This is an overloaded member function, provided for convenience.
|
||||
* It differs from the above function only in what argument(s) it accepts.
|
||||
* @param bufferModel buffer containing the content of the pb file
|
||||
* @param lenModel length of bufferModel
|
||||
* @param bufferConfig buffer containing the content of the pbtxt file
|
||||
* @param lenConfig length of bufferConfig
|
||||
*/
|
||||
CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
|
||||
const char *bufferConfig = NULL, size_t lenConfig = 0);
|
||||
|
||||
/** @brief Reads a network model stored in Torch model file.
|
||||
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
|
||||
*/
|
||||
|
@ -94,6 +94,17 @@ public:
|
||||
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
|
||||
}
|
||||
|
||||
CaffeImporter(const char *dataProto, size_t lenProto,
|
||||
const char *dataModel, size_t lenModel)
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
|
||||
ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);
|
||||
|
||||
if (dataModel != NULL && lenModel > 0)
|
||||
ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);
|
||||
}
|
||||
|
||||
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams ¶ms)
|
||||
{
|
||||
const Reflection *refl = msg.GetReflection();
|
||||
@ -400,6 +411,15 @@ Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String
|
||||
return net;
|
||||
}
|
||||
|
||||
Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
|
||||
const char *bufferModel, size_t lenModel)
|
||||
{
|
||||
CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);
|
||||
Net net;
|
||||
caffeImporter.populateNet(net);
|
||||
return net;
|
||||
}
|
||||
|
||||
#endif //HAVE_PROTOBUF
|
||||
|
||||
CV__DNN_EXPERIMENTAL_NS_END
|
||||
|
@ -1108,28 +1108,37 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
|
||||
|
||||
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
|
||||
|
||||
bool ReadProtoFromBinary(ZeroCopyInputStream* input, Message *proto) {
|
||||
CodedInputStream coded_input(input);
|
||||
coded_input.SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
|
||||
|
||||
return proto->ParseFromCodedStream(&coded_input);
|
||||
}
|
||||
|
||||
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
|
||||
std::ifstream fs(filename, std::ifstream::in);
|
||||
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
|
||||
IstreamInputStream input(&fs);
|
||||
bool success = google::protobuf::TextFormat::Parse(&input, proto);
|
||||
fs.close();
|
||||
return success;
|
||||
return google::protobuf::TextFormat::Parse(&input, proto);
|
||||
}
|
||||
|
||||
bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
|
||||
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
|
||||
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
|
||||
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
|
||||
CodedInputStream* coded_input = new CodedInputStream(raw_input);
|
||||
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
|
||||
IstreamInputStream raw_input(&fs);
|
||||
|
||||
bool success = proto->ParseFromCodedStream(coded_input);
|
||||
return ReadProtoFromBinary(&raw_input, proto);
|
||||
}
|
||||
|
||||
delete coded_input;
|
||||
delete raw_input;
|
||||
fs.close();
|
||||
return success;
|
||||
bool ReadProtoFromTextBuffer(const char* data, size_t len, Message* proto) {
|
||||
ArrayInputStream input(data, len);
|
||||
return google::protobuf::TextFormat::Parse(&input, proto);
|
||||
}
|
||||
|
||||
|
||||
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, Message* proto) {
|
||||
ArrayInputStream raw_input(data, len);
|
||||
return ReadProtoFromBinary(&raw_input, proto);
|
||||
}
|
||||
|
||||
void ReadNetParamsFromTextFileOrDie(const char* param_file,
|
||||
@ -1139,6 +1148,13 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
|
||||
UpgradeNetAsNeeded(param_file, param);
|
||||
}
|
||||
|
||||
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
|
||||
NetParameter* param) {
|
||||
CHECK(ReadProtoFromTextBuffer(data, len, param))
|
||||
<< "Failed to parse NetParameter buffer";
|
||||
UpgradeNetAsNeeded("memory buffer", param);
|
||||
}
|
||||
|
||||
void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
NetParameter* param) {
|
||||
CHECK(ReadProtoFromBinaryFile(param_file, param))
|
||||
@ -1146,6 +1162,13 @@ void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
UpgradeNetAsNeeded(param_file, param);
|
||||
}
|
||||
|
||||
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
|
||||
NetParameter* param) {
|
||||
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
|
||||
<< "Failed to parse NetParameter buffer";
|
||||
UpgradeNetAsNeeded("memory buffer", param);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -102,6 +102,18 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
|
||||
void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
caffe::NetParameter* param);
|
||||
|
||||
// Read parameters from a memory buffer into a NetParammeter proto message.
|
||||
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
|
||||
caffe::NetParameter* param);
|
||||
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
|
||||
caffe::NetParameter* param);
|
||||
|
||||
// Utility functions used internally by Caffe and TensorFlow loaders
|
||||
bool ReadProtoFromTextFile(const char* filename, ::google::protobuf::Message* proto);
|
||||
bool ReadProtoFromBinaryFile(const char* filename, ::google::protobuf::Message* proto);
|
||||
bool ReadProtoFromTextBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
|
||||
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
|
||||
class TFImporter : public Importer {
|
||||
public:
|
||||
TFImporter(const char *model, const char *config = NULL);
|
||||
TFImporter(const char *dataModel, size_t lenModel,
|
||||
const char *dataConfig = NULL, size_t lenConfig = 0);
|
||||
|
||||
void populateNet(Net dstNet);
|
||||
~TFImporter() {}
|
||||
|
||||
@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
|
||||
ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
|
||||
}
|
||||
|
||||
TFImporter::TFImporter(const char *dataModel, size_t lenModel,
|
||||
const char *dataConfig, size_t lenConfig)
|
||||
{
|
||||
if (dataModel != NULL && lenModel > 0)
|
||||
ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
|
||||
if (dataConfig != NULL && lenConfig > 0)
|
||||
ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
|
||||
}
|
||||
|
||||
void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
|
||||
{
|
||||
MatShape shape;
|
||||
@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
|
||||
return net;
|
||||
}
|
||||
|
||||
Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
|
||||
const char* bufferConfig, size_t lenConfig)
|
||||
{
|
||||
TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
|
||||
Net net;
|
||||
importer.populateNet(net);
|
||||
return net;
|
||||
}
|
||||
|
||||
CV__DNN_EXPERIMENTAL_NS_END
|
||||
}} // namespace
|
||||
|
@ -23,6 +23,7 @@ Implementation of various functions which are related to Tensorflow models readi
|
||||
|
||||
#include "graph.pb.h"
|
||||
#include "tf_io.hpp"
|
||||
#include "../caffe/caffe_io.hpp"
|
||||
#include "../caffe/glog_emulator.hpp"
|
||||
|
||||
namespace cv {
|
||||
@ -36,41 +37,28 @@ using namespace ::google::protobuf::io;
|
||||
|
||||
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
|
||||
|
||||
// TODO: remove Caffe duplicate
|
||||
bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
|
||||
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
|
||||
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
|
||||
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
|
||||
CodedInputStream* coded_input = new CodedInputStream(raw_input);
|
||||
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
|
||||
|
||||
bool success = proto->ParseFromCodedStream(coded_input);
|
||||
|
||||
delete coded_input;
|
||||
delete raw_input;
|
||||
fs.close();
|
||||
return success;
|
||||
}
|
||||
|
||||
bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
|
||||
std::ifstream fs(filename, std::ifstream::in);
|
||||
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
|
||||
IstreamInputStream input(&fs);
|
||||
bool success = google::protobuf::TextFormat::Parse(&input, proto);
|
||||
fs.close();
|
||||
return success;
|
||||
}
|
||||
|
||||
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
|
||||
<< "Failed to parse GraphDef file: " << param_file;
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromBinaryFile(param_file, param))
|
||||
<< "Failed to parse GraphDef file: " << param_file;
|
||||
}
|
||||
|
||||
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
|
||||
<< "Failed to parse GraphDef buffer";
|
||||
}
|
||||
|
||||
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromTextFileTF(param_file, param))
|
||||
<< "Failed to parse GraphDef file: " << param_file;
|
||||
CHECK(ReadProtoFromTextFile(param_file, param))
|
||||
<< "Failed to parse GraphDef file: " << param_file;
|
||||
}
|
||||
|
||||
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
|
||||
tensorflow::GraphDef* param) {
|
||||
CHECK(ReadProtoFromTextBuffer(data, len, param))
|
||||
<< "Failed to parse GraphDef buffer";
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
|
||||
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
|
||||
tensorflow::GraphDef* param);
|
||||
|
||||
// Read parameters from a memory buffer into a GraphDef proto message.
|
||||
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
|
||||
tensorflow::GraphDef* param);
|
||||
|
||||
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
|
||||
tensorflow::GraphDef* param);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,6 +55,24 @@ static std::string _tf(TString filename)
|
||||
return (getOpenCVExtraDir() + "/dnn/") + filename;
|
||||
}
|
||||
|
||||
TEST(Test_Caffe, memory_read)
|
||||
{
|
||||
const string proto = findDataFile("dnn/bvlc_googlenet.prototxt", false);
|
||||
const string model = findDataFile("dnn/bvlc_googlenet.caffemodel", false);
|
||||
|
||||
string dataProto;
|
||||
ASSERT_TRUE(readFileInMemory(proto, dataProto));
|
||||
string dataModel;
|
||||
ASSERT_TRUE(readFileInMemory(model, dataModel));
|
||||
|
||||
Net net = readNetFromCaffe(dataProto.c_str(), dataProto.size());
|
||||
ASSERT_FALSE(net.empty());
|
||||
|
||||
Net net2 = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
|
||||
dataModel.c_str(), dataModel.size());
|
||||
ASSERT_FALSE(net2.empty());
|
||||
}
|
||||
|
||||
TEST(Test_Caffe, read_gtsrb)
|
||||
{
|
||||
Net net = readNetFromCaffe(_tf("gtsrb.prototxt"));
|
||||
@ -67,13 +85,26 @@ TEST(Test_Caffe, read_googlenet)
|
||||
ASSERT_FALSE(net.empty());
|
||||
}
|
||||
|
||||
TEST(Reproducibility_AlexNet, Accuracy)
|
||||
typedef testing::TestWithParam<tuple<bool> > Reproducibility_AlexNet;
|
||||
TEST_P(Reproducibility_AlexNet, Accuracy)
|
||||
{
|
||||
bool readFromMemory = get<0>(GetParam());
|
||||
Net net;
|
||||
{
|
||||
const string proto = findDataFile("dnn/bvlc_alexnet.prototxt", false);
|
||||
const string model = findDataFile("dnn/bvlc_alexnet.caffemodel", false);
|
||||
net = readNetFromCaffe(proto, model);
|
||||
if (readFromMemory)
|
||||
{
|
||||
string dataProto;
|
||||
ASSERT_TRUE(readFileInMemory(proto, dataProto));
|
||||
string dataModel;
|
||||
ASSERT_TRUE(readFileInMemory(model, dataModel));
|
||||
|
||||
net = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
|
||||
dataModel.c_str(), dataModel.size());
|
||||
}
|
||||
else
|
||||
net = readNetFromCaffe(proto, model);
|
||||
ASSERT_FALSE(net.empty());
|
||||
}
|
||||
|
||||
@ -86,6 +117,8 @@ TEST(Reproducibility_AlexNet, Accuracy)
|
||||
normAssert(ref, out);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(Test_Caffe, Reproducibility_AlexNet, testing::Values(true, false));
|
||||
|
||||
#if !defined(_WIN32) || defined(_WIN64)
|
||||
TEST(Reproducibility_FCN, Accuracy)
|
||||
{
|
||||
|
@ -57,4 +57,23 @@ inline void normAssert(cv::InputArray ref, cv::InputArray test, const char *comm
|
||||
EXPECT_LE(normInf, lInf) << comment;
|
||||
}
|
||||
|
||||
inline bool readFileInMemory(const std::string& filename, std::string& content)
|
||||
{
|
||||
std::ios::openmode mode = std::ios::in | std::ios::binary;
|
||||
std::ifstream ifs(filename.c_str(), mode);
|
||||
if (!ifs.is_open())
|
||||
return false;
|
||||
|
||||
content.clear();
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
content.reserve(ifs.tellg());
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
|
||||
content.assign((std::istreambuf_iterator<char>(ifs)),
|
||||
std::istreambuf_iterator<char>());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -75,14 +75,32 @@ static std::string path(const std::string& file)
|
||||
}
|
||||
|
||||
static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
|
||||
double l1 = 1e-5, double lInf = 1e-4)
|
||||
double l1 = 1e-5, double lInf = 1e-4,
|
||||
bool memoryLoad = false)
|
||||
{
|
||||
std::string netPath = path(prefix + "_net.pb");
|
||||
std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
|
||||
std::string inpPath = path(prefix + "_in.npy");
|
||||
std::string outPath = path(prefix + "_out.npy");
|
||||
|
||||
Net net = readNetFromTensorflow(netPath, netConfig);
|
||||
Net net;
|
||||
if (memoryLoad)
|
||||
{
|
||||
// Load files into a memory buffers
|
||||
string dataModel;
|
||||
ASSERT_TRUE(readFileInMemory(netPath, dataModel));
|
||||
|
||||
string dataConfig;
|
||||
if (hasText)
|
||||
ASSERT_TRUE(readFileInMemory(netConfig, dataConfig));
|
||||
|
||||
net = readNetFromTensorflow(dataModel.c_str(), dataModel.size(),
|
||||
dataConfig.c_str(), dataConfig.size());
|
||||
}
|
||||
else
|
||||
net = readNetFromTensorflow(netPath, netConfig);
|
||||
|
||||
ASSERT_FALSE(net.empty());
|
||||
|
||||
cv::Mat input = blobFromNPY(inpPath);
|
||||
cv::Mat target = blobFromNPY(outPath);
|
||||
@ -216,4 +234,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
|
||||
runTensorFlowNet("resize_nearest_neighbor");
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, memory_read)
|
||||
{
|
||||
double l1 = 1e-5;
|
||||
double lInf = 1e-4;
|
||||
runTensorFlowNet("lstm", true, l1, lInf, true);
|
||||
|
||||
runTensorFlowNet("batch_norm", false, l1, lInf, true);
|
||||
runTensorFlowNet("fused_batch_norm", false, l1, lInf, true);
|
||||
runTensorFlowNet("batch_norm_text", true, l1, lInf, true);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user