mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
Add python bindings for G-API onnx
This commit is contained in:
parent
9da9e8244b
commit
ea2527c2d1
@ -176,6 +176,7 @@ set(gapi_srcs
|
||||
|
||||
# Python bridge
|
||||
src/backends/ie/bindings_ie.cpp
|
||||
src/backends/onnx/bindings_onnx.cpp
|
||||
src/backends/python/gpythonbackend.cpp
|
||||
|
||||
# OpenVPL Streaming source
|
||||
|
43
modules/gapi/include/opencv2/gapi/infer/bindings_onnx.hpp
Normal file
43
modules/gapi/include/opencv2/gapi/infer/bindings_onnx.hpp
Normal file
@ -0,0 +1,43 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||
// directory of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP
|
||||
#define OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP
|
||||
|
||||
#include <opencv2/gapi/gkernel.hpp> // GKernelPackage
|
||||
#include <opencv2/gapi/infer/onnx.hpp> // Params
|
||||
#include "opencv2/gapi/own/exports.hpp" // GAPI_EXPORTS
|
||||
#include <opencv2/gapi/util/any.hpp>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace cv {
|
||||
namespace gapi {
|
||||
namespace onnx {
|
||||
|
||||
// NB: Used by python wrapper
|
||||
// This class can be marked as SIMPLE, because it's implemented as pimpl
|
||||
class GAPI_EXPORTS_W_SIMPLE PyParams {
|
||||
public:
|
||||
GAPI_WRAP
|
||||
PyParams() = default;
|
||||
|
||||
GAPI_WRAP
|
||||
PyParams(const std::string& tag, const std::string& model_path);
|
||||
|
||||
GBackend backend() const;
|
||||
std::string tag() const;
|
||||
cv::util::any params() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Params<cv::gapi::Generic>> m_priv;
|
||||
};
|
||||
|
||||
GAPI_EXPORTS_W PyParams params(const std::string& tag, const std::string& model_path);
|
||||
|
||||
} // namespace onnx
|
||||
} // namespace gapi
|
||||
} // namespace cv
|
||||
|
||||
#endif // OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP
|
@ -17,6 +17,7 @@
|
||||
|
||||
#include <opencv2/core/cvdef.h> // GAPI_EXPORTS
|
||||
#include <opencv2/gapi/gkernel.hpp> // GKernelPackage
|
||||
#include <opencv2/gapi/infer.hpp> // Generic
|
||||
|
||||
namespace cv {
|
||||
namespace gapi {
|
||||
@ -67,6 +68,8 @@ struct ParamDesc {
|
||||
std::vector<bool> normalize; //!< Vector of bool values that enabled or disabled normalize of input data.
|
||||
|
||||
std::vector<std::string> names_to_remap; //!< Names of output layers that will be processed in PostProc function.
|
||||
|
||||
bool is_generic;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
@ -103,6 +106,7 @@ public:
|
||||
desc.model_path = model;
|
||||
desc.num_in = std::tuple_size<typename Net::InArgs>::value;
|
||||
desc.num_out = std::tuple_size<typename Net::OutArgs>::value;
|
||||
desc.is_generic = false;
|
||||
};
|
||||
|
||||
/** @brief Specifies sequence of network input layers names for inference.
|
||||
@ -277,6 +281,35 @@ protected:
|
||||
detail::ParamDesc desc;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This structure provides functions for generic network type that
|
||||
* fill inference parameters.
|
||||
* @see struct Generic
|
||||
*/
|
||||
template<>
|
||||
class Params<cv::gapi::Generic> {
|
||||
public:
|
||||
/** @brief Class constructor.
|
||||
|
||||
Constructs Params based on input information and sets default values for other
|
||||
inference description parameters.
|
||||
|
||||
@param tag string tag of the network for which these parameters are intended.
|
||||
@param model_path path to model file (.onnx file).
|
||||
*/
|
||||
Params(const std::string& tag, const std::string& model_path)
|
||||
: desc{model_path, 0u, 0u, {}, {}, {}, {}, {}, {}, {}, {}, {}, true}, m_tag(tag) {}
|
||||
|
||||
// BEGIN(G-API's network parametrization API)
|
||||
GBackend backend() const { return cv::gapi::onnx::backend(); }
|
||||
std::string tag() const { return m_tag; }
|
||||
cv::util::any params() const { return { desc }; }
|
||||
// END(G-API's network parametrization API)
|
||||
protected:
|
||||
detail::ParamDesc desc;
|
||||
std::string m_tag;
|
||||
};
|
||||
|
||||
} // namespace onnx
|
||||
} // namespace gapi
|
||||
} // namespace cv
|
||||
|
@ -14,6 +14,7 @@
|
||||
using gapi_GKernelPackage = cv::GKernelPackage;
|
||||
using gapi_GNetPackage = cv::gapi::GNetPackage;
|
||||
using gapi_ie_PyParams = cv::gapi::ie::PyParams;
|
||||
using gapi_onnx_PyParams = cv::gapi::onnx::PyParams;
|
||||
using gapi_wip_IStreamSource_Ptr = cv::Ptr<cv::gapi::wip::IStreamSource>;
|
||||
using detail_ExtractArgsCallback = cv::detail::ExtractArgsCallback;
|
||||
using detail_ExtractMetaCallback = cv::detail::ExtractMetaCallback;
|
||||
|
@ -79,5 +79,6 @@ namespace streaming
|
||||
namespace detail
|
||||
{
|
||||
gapi::GNetParam GAPI_EXPORTS_W strip(gapi::ie::PyParams params);
|
||||
gapi::GNetParam GAPI_EXPORTS_W strip(gapi::onnx::PyParams params);
|
||||
} // namespace detail
|
||||
} // namespace cv
|
||||
|
74
modules/gapi/misc/python/test/test_gapi_infer_onnx.py
Normal file
74
modules/gapi/misc/python/test/test_gapi_infer_onnx.py
Normal file
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from tests_common import NewOpenCVTests
|
||||
|
||||
|
||||
try:
|
||||
|
||||
if sys.version_info[:2] < (3, 0):
|
||||
raise unittest.SkipTest('Python 2.x is not supported')
|
||||
|
||||
CLASSIFICATION_MODEL_PATH = "onnx_models/vision/classification/squeezenet/model/squeezenet1.0-9.onnx"
|
||||
|
||||
testdata_required = bool(os.environ.get('OPENCV_DNN_TEST_REQUIRE_TESTDATA', False))
|
||||
|
||||
class test_gapi_infer(NewOpenCVTests):
|
||||
def find_dnn_file(self, filename, required=None):
|
||||
if not required:
|
||||
required = testdata_required
|
||||
return self.find_file(filename, [os.environ.get('OPENCV_DNN_TEST_DATA_PATH', os.getcwd()),
|
||||
os.environ['OPENCV_TEST_DATA_PATH']],
|
||||
required=required)
|
||||
|
||||
def test_onnx_classification(self):
|
||||
model_path = self.find_dnn_file(CLASSIFICATION_MODEL_PATH)
|
||||
|
||||
if model_path is None:
|
||||
raise unittest.SkipTest("Missing DNN test file")
|
||||
|
||||
in_mat = cv.imread(
|
||||
self.find_file("cv/dpm/cat.png",
|
||||
[os.environ.get('OPENCV_TEST_DATA_PATH')]))
|
||||
|
||||
g_in = cv.GMat()
|
||||
g_infer_inputs = cv.GInferInputs()
|
||||
g_infer_inputs.setInput("data_0", g_in)
|
||||
g_infer_out = cv.gapi.infer("squeeze-net", g_infer_inputs)
|
||||
g_out = g_infer_out.at("softmaxout_1")
|
||||
|
||||
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out))
|
||||
|
||||
net = cv.gapi.onnx.params("squeeze-net", model_path)
|
||||
try:
|
||||
out_gapi = comp.apply(cv.gin(in_mat), cv.gapi.compile_args(cv.gapi.networks(net)))
|
||||
except cv.error as err:
|
||||
if err.args[0] == "G-API has been compiled without ONNX support":
|
||||
raise unittest.SkipTest("G-API has been compiled without ONNX support")
|
||||
else:
|
||||
raise
|
||||
|
||||
self.assertEqual((1, 1000, 1, 1), out_gapi.shape)
|
||||
|
||||
|
||||
except unittest.SkipTest as e:
|
||||
|
||||
message = str(e)
|
||||
|
||||
class TestSkip(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.skipTest('Skip tests: ' + message)
|
||||
|
||||
def test_skip():
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
NewOpenCVTests.bootstrap()
|
24
modules/gapi/src/backends/onnx/bindings_onnx.cpp
Normal file
24
modules/gapi/src/backends/onnx/bindings_onnx.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||
// directory of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include <opencv2/gapi/infer/bindings_onnx.hpp>
|
||||
|
||||
cv::gapi::onnx::PyParams::PyParams(const std::string& tag,
|
||||
const std::string& model_path)
|
||||
: m_priv(std::make_shared<Params<cv::gapi::Generic>>(tag, model_path)) {}
|
||||
|
||||
cv::gapi::GBackend cv::gapi::onnx::PyParams::backend() const {
|
||||
return m_priv->backend();
|
||||
}
|
||||
|
||||
std::string cv::gapi::onnx::PyParams::tag() const { return m_priv->tag(); }
|
||||
|
||||
cv::util::any cv::gapi::onnx::PyParams::params() const {
|
||||
return m_priv->params();
|
||||
}
|
||||
|
||||
cv::gapi::onnx::PyParams cv::gapi::onnx::params(
|
||||
const std::string& tag, const std::string& model_path) {
|
||||
return {tag, model_path};
|
||||
}
|
@ -735,7 +735,8 @@ void ONNXCompiled::extractMat(ONNXCallContext &ctx, const size_t in_idx, Views&
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXCompiled::setOutput(int i, cv::Mat &m) {
|
||||
void ONNXCompiled::setOutput(int i, cv::Mat &m)
|
||||
{
|
||||
// FIXME: No need in double-indexing?
|
||||
out_data[i] = m;
|
||||
}
|
||||
@ -1133,9 +1134,34 @@ namespace {
|
||||
// FIXME: Introduce a DNNBackend interface which'd specify
|
||||
// the framework for this???
|
||||
GONNXModel gm(gr);
|
||||
const auto &np = gm.metadata(nh).get<NetworkParams>();
|
||||
const auto &pp = cv::util::any_cast<cv::gapi::onnx::detail::ParamDesc>(np.opaque);
|
||||
auto &np = gm.metadata(nh).get<NetworkParams>();
|
||||
auto &pp = cv::util::any_cast<cv::gapi::onnx::detail::ParamDesc>(np.opaque);
|
||||
const auto &ki = cv::util::any_cast<KImpl>(ii.opaque);
|
||||
|
||||
GModel::Graph model(gr);
|
||||
auto& op = model.metadata(nh).get<Op>();
|
||||
if (pp.is_generic) {
|
||||
auto& info = cv::util::any_cast<cv::detail::InOutInfo>(op.params);
|
||||
|
||||
for (const auto& a : info.in_names)
|
||||
{
|
||||
pp.input_names.push_back(a);
|
||||
}
|
||||
// Adding const input is necessary because the definition of input_names
|
||||
// includes const input.
|
||||
for (const auto& a : pp.const_inputs)
|
||||
{
|
||||
pp.input_names.push_back(a.first);
|
||||
}
|
||||
pp.num_in = info.in_names.size();
|
||||
|
||||
for (const auto& a : info.out_names)
|
||||
{
|
||||
pp.output_names.push_back(a);
|
||||
}
|
||||
pp.num_out = info.out_names.size();
|
||||
}
|
||||
|
||||
gm.metadata(nh).set(ONNXUnit{pp});
|
||||
gm.metadata(nh).set(ONNXCallable{ki.run});
|
||||
gm.metadata(nh).set(CustomMetaFunction{ki.customMetaFunc});
|
||||
|
Loading…
Reference in New Issue
Block a user