// This file is part of OpenCV project. // It is subject to the license terms in the LICENSE file found in the top-level directory // of this distribution and at http://opencv.org/license.html. // // Copyright (C) 2017, Intel Corporation, all rights reserved. // Third party copyrights are property of their respective owners. #include "npy_blob.hpp" namespace cv { static std::string getType(const std::string& header) { std::string field = "'descr':"; int idx = header.find(field); CV_Assert(idx != -1); int from = header.find('\'', idx + field.size()) + 1; int to = header.find('\'', from); return header.substr(from, to - from); } static std::string getFortranOrder(const std::string& header) { std::string field = "'fortran_order':"; int idx = header.find(field); CV_Assert(idx != -1); int from = header.find_last_of(' ', idx + field.size()) + 1; int to = header.find(',', from); return header.substr(from, to - from); } static std::vector getShape(const std::string& header) { std::string field = "'shape':"; int idx = header.find(field); CV_Assert(idx != -1); int from = header.find('(', idx + field.size()) + 1; int to = header.find(')', from); std::string shapeStr = header.substr(from, to - from); if (shapeStr.empty()) return std::vector(1, 1); // Remove all commas. shapeStr.erase(std::remove(shapeStr.begin(), shapeStr.end(), ','), shapeStr.end()); std::istringstream ss(shapeStr); int value; std::vector shape; while (ss >> value) { shape.push_back(value); } return shape; } Mat blobFromNPY(const std::string& path) { std::ifstream ifs(path.c_str(), std::ios::binary); CV_Assert(ifs.is_open()); std::string magic(6, '*'); ifs.read(&magic[0], magic.size()); CV_Assert(magic == "\x93NUMPY"); ifs.ignore(1); // Skip major version byte. ifs.ignore(1); // Skip minor version byte. unsigned short headerSize; ifs.read((char*)&headerSize, sizeof(headerSize)); std::string header(headerSize, '*'); ifs.read(&header[0], header.size()); // Extract data type. CV_Assert(getType(header) == " shape = getShape(header); Mat blob(shape, CV_32F); ifs.read((char*)blob.data, blob.total() * blob.elemSize()); CV_Assert(ifs.gcount() == blob.total() * blob.elemSize()); return blob; } } // namespace cv