Merge pull request #18716 from dmatveev:dm/upstream_onnx

* G-API: Introduce ONNX backend for Inference

- Basic operations are implemented (Infer, -ROI, -List, -List2);
- Implemented automatic preprocessing for ONNX models;
- Test suite is extended with `OPENCV_GAPI_ONNX_MODEL_PATH` env for test data
  (test data is an ONNX Model Zoo repo snapshot);
- Fixed kernel lookup logic in core G-API:
  - Lookup NN kernels not in the default package, but in the associated
    backend's aux package. Now two NN backends can work in the same graph.
- Added Infer SSD demo and a combined ONNX/IE demo;

* G-API/ONNX: Fix some of CMake issues

Co-authored-by: Pashchenkov, Maxim <maxim.pashchenkov@intel.com>
This commit is contained in:
Dmitry Matveev 2020-11-03 21:39:16 +03:00 committed by GitHub
parent 2a3cdba724
commit a110ede0a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1920 additions and 6 deletions

View File

@ -439,6 +439,9 @@ OCV_OPTION(WITH_ANDROID_MEDIANDK "Use Android Media NDK for Video I/O (Android)"
OCV_OPTION(WITH_TENGINE "Include Arm Inference Tengine support" OFF
VISIBLE_IF (ARM OR AARCH64) AND (UNIX OR ANDROID) AND NOT IOS
VERIFY HAVE_TENGINE)
OCV_OPTION(WITH_ONNX "Include Microsoft ONNX Runtime support" OFF
VISIBLE_IF TRUE
VERIFY HAVE_ONNX)
# OpenCV build components
# ===================================================
@ -775,6 +778,11 @@ if(WITH_QUIRC)
add_subdirectory(3rdparty/quirc)
set(HAVE_QUIRC TRUE)
endif()
if(WITH_ONNX)
include(cmake/FindONNX.cmake)
endif()
# ----------------------------------------------------------------------------
# OpenCV HAL
# ----------------------------------------------------------------------------
@ -1556,6 +1564,15 @@ if(WITH_OPENCL OR HAVE_OPENCL)
endif()
endif()
if(WITH_ONNX OR HAVE_ONNX)
status("")
status(" ONNX:" HAVE_ONNX THEN "YES" ELSE "NO")
if(HAVE_ONNX)
status(" Include path:" ONNX_INCLUDE_DIR THEN "${ONNX_INCLUDE_DIR}" ELSE "NO")
status(" Link libraries:" ONNX_LIBRARIES THEN "${ONNX_LIBRARIES}" ELSE "NO")
endif()
endif()
# ========================== python ==========================
if(BUILD_opencv_python2)
status("")

36
cmake/FindONNX.cmake Normal file
View File

@ -0,0 +1,36 @@
ocv_clear_vars(HAVE_ONNX)
set(ONNXRT_ROOT_DIR "" CACHE PATH "ONNX Runtime install directory")
# For now, check the old name ORT_INSTALL_DIR
if(ORT_INSTALL_DIR AND NOT ONNXRT_ROOT_DIR)
set(ONNXRT_ROOT_DIR ORT_INSTALL_DIR)
endif()
if(ONNXRT_ROOT_DIR)
find_library(ORT_LIB onnxruntime
${ONNXRT_ROOT_DIR}/lib
CMAKE_FIND_ROOT_PATH_BOTH)
find_path(ORT_INCLUDE onnxruntime_cxx_api.h
${ONNXRT_ROOT_DIR}/include/onnxruntime/core/session
CMAKE_FIND_ROOT_PATH_BOTH)
endif()
if(ORT_LIB AND ORT_INCLUDE)
set(HAVE_ONNX TRUE)
# For CMake output only
set(ONNX_LIBRARIES "${ORT_LIB}" CACHE STRING "ONNX Runtime libraries")
set(ONNX_INCLUDE_DIR "${ORT_INCLUDE}" CACHE STRING "ONNX Runtime include path")
# Link target with associated interface headers
set(ONNX_LIBRARY "onnxruntime" CACHE STRING "ONNX Link Target")
ocv_add_library(${ONNX_LIBRARY} SHARED IMPORTED)
set_target_properties(${ONNX_LIBRARY} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${ORT_INCLUDE}
IMPORTED_LOCATION ${ORT_LIB}
IMPORTED_IMPLIB ${ORT_LIB})
endif()
if(NOT HAVE_ONNX)
ocv_clear_vars(HAVE_ONNX ORT_LIB ORT_INCLUDE_DIR)
endif()

View File

@ -131,6 +131,9 @@ set(gapi_srcs
src/backends/ie/giebackend.cpp
src/backends/ie/giebackend/giewrapper.cpp
# ONNX Backend.
src/backends/onnx/gonnxbackend.cpp
# Render Backend.
src/backends/render/grenderocv.cpp
src/backends/render/ft_render.cpp
@ -205,10 +208,20 @@ if(HAVE_PLAIDML)
ocv_target_include_directories(${the_module} SYSTEM PRIVATE ${PLAIDML_INCLUDE_DIRS})
endif()
if(WIN32)
# Required for htonl/ntohl on Windows
ocv_target_link_libraries(${the_module} PRIVATE wsock32 ws2_32)
endif()
if(HAVE_ONNX)
ocv_target_link_libraries(${the_module} PRIVATE ${ONNX_LIBRARY})
ocv_target_compile_definitions(${the_module} PRIVATE HAVE_ONNX=1)
if(TARGET opencv_test_gapi)
ocv_target_compile_definitions(opencv_test_gapi PRIVATE HAVE_ONNX=1)
ocv_target_link_libraries(opencv_test_gapi PRIVATE ${ONNX_LIBRARY})
endif()
endif()
ocv_add_perf_tests()
ocv_add_samples()

View File

@ -0,0 +1,138 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2020 Intel Corporation
#ifndef OPENCV_GAPI_INFER_ONNX_HPP
#define OPENCV_GAPI_INFER_ONNX_HPP
#include <unordered_map>
#include <string>
#include <array>
#include <tuple> // tuple, tuple_size
#include <opencv2/gapi/opencv_includes.hpp>
#include <opencv2/gapi/util/any.hpp>
#include <opencv2/core/cvdef.h> // GAPI_EXPORTS
#include <opencv2/gapi/gkernel.hpp> // GKernelPackage
namespace cv {
namespace gapi {
namespace onnx {
GAPI_EXPORTS cv::gapi::GBackend backend();
enum class TraitAs: int {
TENSOR, //!< G-API traits an associated cv::Mat as a raw tensor
// and passes dimensions as-is
IMAGE //!< G-API traits an associated cv::Mat as an image so
// creates an "image" blob (NCHW/NHWC, etc)
};
using PostProc = std::function<void(const std::unordered_map<std::string, cv::Mat> &,
std::unordered_map<std::string, cv::Mat> &)>;
namespace detail {
struct ParamDesc {
std::string model_path;
// NB: nun_* may differ from topology's real input/output port numbers
// (e.g. topology's partial execution)
std::size_t num_in; // How many inputs are defined in the operation
std::size_t num_out; // How many outputs are defined in the operation
// NB: Here order follows the `Net` API
std::vector<std::string> input_names;
std::vector<std::string> output_names;
using ConstInput = std::pair<cv::Mat, TraitAs>;
std::unordered_map<std::string, ConstInput> const_inputs;
std::vector<cv::Scalar> mean;
std::vector<cv::Scalar> stdev;
std::vector<cv::GMatDesc> out_metas;
PostProc custom_post_proc;
std::vector<bool> normalize;
};
} // namespace detail
template<typename Net>
struct PortCfg {
using In = std::array
< std::string
, std::tuple_size<typename Net::InArgs>::value >;
using Out = std::array
< std::string
, std::tuple_size<typename Net::OutArgs>::value >;
using NormCoefs = std::array
< cv::Scalar
, std::tuple_size<typename Net::InArgs>::value >;
using Normalize = std::array
< bool
, std::tuple_size<typename Net::InArgs>::value >;
};
template<typename Net> class Params {
public:
Params(const std::string &model) {
desc.model_path = model;
desc.num_in = std::tuple_size<typename Net::InArgs>::value;
desc.num_out = std::tuple_size<typename Net::OutArgs>::value;
};
// BEGIN(G-API's network parametrization API)
GBackend backend() const { return cv::gapi::onnx::backend(); }
std::string tag() const { return Net::tag(); }
cv::util::any params() const { return { desc }; }
// END(G-API's network parametrization API)
Params<Net>& cfgInputLayers(const typename PortCfg<Net>::In &ll) {
desc.input_names.assign(ll.begin(), ll.end());
return *this;
}
Params<Net>& cfgOutputLayers(const typename PortCfg<Net>::Out &ll) {
desc.output_names.assign(ll.begin(), ll.end());
return *this;
}
Params<Net>& constInput(const std::string &layer_name,
const cv::Mat &data,
TraitAs hint = TraitAs::TENSOR) {
desc.const_inputs[layer_name] = {data, hint};
return *this;
}
Params<Net>& cfgMeanStd(const typename PortCfg<Net>::NormCoefs &m,
const typename PortCfg<Net>::NormCoefs &s) {
desc.mean.assign(m.begin(), m.end());
desc.stdev.assign(s.begin(), s.end());
return *this;
}
Params<Net>& cfgPostProc(const std::vector<cv::GMatDesc> &outs,
const PostProc &pp) {
desc.out_metas = outs;
desc.custom_post_proc = pp;
return *this;
}
Params<Net>& cfgNormalize(const typename PortCfg<Net>::Normalize &n) {
desc.normalize.assign(n.begin(), n.end());
return *this;
}
protected:
detail::ParamDesc desc;
};
} // namespace onnx
} // namespace gapi
} // namespace cv
#endif // OPENCV_GAPI_INFER_HPP

View File

@ -0,0 +1,195 @@
#include <chrono>
#include <iomanip>
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/gapi.hpp"
#include "opencv2/gapi/core.hpp"
#include "opencv2/gapi/imgproc.hpp"
#include "opencv2/gapi/infer.hpp"
#include "opencv2/gapi/infer/ie.hpp"
#include "opencv2/gapi/infer/onnx.hpp"
#include "opencv2/gapi/cpu/gcpukernel.hpp"
#include "opencv2/gapi/streaming/cap.hpp"
namespace {
const std::string keys =
"{ h help | | print this help message }"
"{ input | | Path to an input video file }"
"{ fdm | | IE face detection model IR }"
"{ fdw | | IE face detection model weights }"
"{ fdd | | IE face detection device }"
"{ emom | | ONNX emotions recognition model }"
"{ output | | (Optional) Path to an output video file }"
;
} // namespace
namespace custom {
G_API_NET(Faces, <cv::GMat(cv::GMat)>, "face-detector");
G_API_NET(Emotions, <cv::GMat(cv::GMat)>, "emotions-recognition");
G_API_OP(PostProc, <cv::GArray<cv::Rect>(cv::GMat, cv::GMat)>, "custom.fd_postproc") {
static cv::GArrayDesc outMeta(const cv::GMatDesc &, const cv::GMatDesc &) {
return cv::empty_array_desc();
}
};
GAPI_OCV_KERNEL(OCVPostProc, PostProc) {
static void run(const cv::Mat &in_ssd_result,
const cv::Mat &in_frame,
std::vector<cv::Rect> &out_faces) {
const int MAX_PROPOSALS = 200;
const int OBJECT_SIZE = 7;
const cv::Size upscale = in_frame.size();
const cv::Rect surface({0,0}, upscale);
out_faces.clear();
const float *data = in_ssd_result.ptr<float>();
for (int i = 0; i < MAX_PROPOSALS; i++) {
const float image_id = data[i * OBJECT_SIZE + 0]; // batch id
const float confidence = data[i * OBJECT_SIZE + 2];
const float rc_left = data[i * OBJECT_SIZE + 3];
const float rc_top = data[i * OBJECT_SIZE + 4];
const float rc_right = data[i * OBJECT_SIZE + 5];
const float rc_bottom = data[i * OBJECT_SIZE + 6];
if (image_id < 0.f) { // indicates end of detections
break;
}
if (confidence < 0.5f) {
continue;
}
cv::Rect rc;
rc.x = static_cast<int>(rc_left * upscale.width);
rc.y = static_cast<int>(rc_top * upscale.height);
rc.width = static_cast<int>(rc_right * upscale.width) - rc.x;
rc.height = static_cast<int>(rc_bottom * upscale.height) - rc.y;
out_faces.push_back(rc & surface);
}
}
};
//! [Postproc]
} // namespace custom
namespace labels {
// Labels as defined in
// https://github.com/onnx/models/tree/master/vision/body_analysis/emotion_ferplus
//
const std::string emotions[] = {
"neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"
};
namespace {
template<typename Iter>
std::vector<float> softmax(Iter begin, Iter end) {
std::vector<float> prob(end - begin, 0.f);
std::transform(begin, end, prob.begin(), [](float x) { return std::exp(x); });
float sum = std::accumulate(prob.begin(), prob.end(), 0.0f);
for (int i = 0; i < static_cast<int>(prob.size()); i++)
prob[i] /= sum;
return prob;
}
void DrawResults(cv::Mat &frame,
const std::vector<cv::Rect> &faces,
const std::vector<cv::Mat> &out_emotions) {
CV_Assert(faces.size() == out_emotions.size());
for (auto it = faces.begin(); it != faces.end(); ++it) {
const auto idx = std::distance(faces.begin(), it);
const auto &rc = *it;
const float *emotions_data = out_emotions[idx].ptr<float>();
auto sm = softmax(emotions_data, emotions_data + 8);
const auto emo_id = std::max_element(sm.begin(), sm.end()) - sm.begin();
const int ATTRIB_OFFSET = 15;
cv::rectangle(frame, rc, {0, 255, 0}, 4);
cv::putText(frame, emotions[emo_id],
cv::Point(rc.x, rc.y - ATTRIB_OFFSET),
cv::FONT_HERSHEY_COMPLEX_SMALL,
1,
cv::Scalar(0, 0, 255));
std::cout << emotions[emo_id] << " at " << rc << std::endl;
}
}
} // anonymous namespace
} // namespace labels
int main(int argc, char *argv[])
{
cv::CommandLineParser cmd(argc, argv, keys);
if (cmd.has("help")) {
cmd.printMessage();
return 0;
}
const std::string input = cmd.get<std::string>("input");
const std::string output = cmd.get<std::string>("output");
// OpenVINO FD parameters here
auto det_net = cv::gapi::ie::Params<custom::Faces> {
cmd.get<std::string>("fdm"), // read cmd args: path to topology IR
cmd.get<std::string>("fdw"), // read cmd args: path to weights
cmd.get<std::string>("fdd"), // read cmd args: device specifier
};
// ONNX Emotions parameters here
auto emo_net = cv::gapi::onnx::Params<custom::Emotions> {
cmd.get<std::string>("emom"), // read cmd args: path to the ONNX model
}.cfgNormalize({false}); // model accepts 0..255 range in FP32
auto kernels = cv::gapi::kernels<custom::OCVPostProc>();
auto networks = cv::gapi::networks(det_net, emo_net);
cv::GMat in;
cv::GMat bgr = cv::gapi::copy(in);
cv::GMat frame = cv::gapi::streaming::desync(bgr);
cv::GMat detections = cv::gapi::infer<custom::Faces>(frame);
cv::GArray<cv::Rect> faces = custom::PostProc::on(detections, frame);
cv::GArray<cv::GMat> emotions = cv::gapi::infer<custom::Emotions>(faces, frame);
auto pipeline = cv::GComputation(cv::GIn(in), cv::GOut(bgr, faces, emotions))
.compileStreaming(cv::compile_args(kernels, networks));
auto in_src = cv::gapi::wip::make_src<cv::gapi::wip::GCaptureSource>(input);
pipeline.setSource(cv::gin(in_src));
pipeline.start();
cv::util::optional<cv::Mat> out_frame;
cv::util::optional<std::vector<cv::Rect>> out_faces;
cv::util::optional<std::vector<cv::Mat>> out_emotions;
cv::Mat last_mat;
std::vector<cv::Rect> last_faces;
std::vector<cv::Mat> last_emotions;
cv::VideoWriter writer;
while (pipeline.pull(cv::gout(out_frame, out_faces, out_emotions))) {
if (out_faces && out_emotions) {
last_faces = *out_faces;
last_emotions = *out_emotions;
}
if (out_frame) {
last_mat = *out_frame;
labels::DrawResults(last_mat, last_faces, last_emotions);
if (!output.empty()) {
if (!writer.isOpened()) {
const auto sz = cv::Size{last_mat.cols, last_mat.rows};
writer.open(output, cv::VideoWriter::fourcc('M','J','P','G'), 25.0, sz);
CV_Assert(writer.isOpened());
}
writer << last_mat;
}
}
if (!last_mat.empty()) {
cv::imshow("Out", last_mat);
cv::waitKey(1);
}
}
return 0;
}

View File

@ -0,0 +1,213 @@
#include <algorithm>
#include <iostream>
#include <sstream>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/gapi.hpp>
#include <opencv2/gapi/core.hpp>
#include <opencv2/gapi/imgproc.hpp>
#include <opencv2/gapi/infer.hpp>
#include <opencv2/gapi/render.hpp>
#include <opencv2/gapi/infer/onnx.hpp>
#include <opencv2/gapi/cpu/gcpukernel.hpp>
#include <opencv2/gapi/streaming/cap.hpp>
#include <opencv2/highgui.hpp>
namespace custom {
G_API_NET(ObjDetector, <cv::GMat(cv::GMat)>, "object-detector");
using GDetections = cv::GArray<cv::Rect>;
using GSize = cv::GOpaque<cv::Size>;
using GPrims = cv::GArray<cv::gapi::wip::draw::Prim>;
G_API_OP(GetSize, <GSize(cv::GMat)>, "sample.custom.get-size") {
static cv::GOpaqueDesc outMeta(const cv::GMatDesc &) {
return cv::empty_gopaque_desc();
}
};
G_API_OP(ParseSSD, <GDetections(cv::GMat, GSize)>, "sample.custom.parse-ssd") {
static cv::GArrayDesc outMeta(const cv::GMatDesc &, const cv::GOpaqueDesc &) {
return cv::empty_array_desc();
}
};
G_API_OP(BBoxes, <GPrims(GDetections)>, "sample.custom.b-boxes") {
static cv::GArrayDesc outMeta(const cv::GArrayDesc &) {
return cv::empty_array_desc();
}
};
GAPI_OCV_KERNEL(OCVGetSize, GetSize) {
static void run(const cv::Mat &in, cv::Size &out) {
out = {in.cols, in.rows};
}
};
GAPI_OCV_KERNEL(OCVParseSSD, ParseSSD) {
static void run(const cv::Mat &in_ssd_result,
const cv::Size &in_parent_size,
std::vector<cv::Rect> &out_objects) {
const auto &in_ssd_dims = in_ssd_result.size;
CV_Assert(in_ssd_dims.dims() == 4u);
const int MAX_PROPOSALS = in_ssd_dims[2];
const int OBJECT_SIZE = in_ssd_dims[3];
CV_Assert(OBJECT_SIZE == 7); // fixed SSD object size
const cv::Rect surface({0,0}, in_parent_size);
out_objects.clear();
const float *data = in_ssd_result.ptr<float>();
for (int i = 0; i < MAX_PROPOSALS; i++) {
const float image_id = data[i * OBJECT_SIZE + 0];
const float label = data[i * OBJECT_SIZE + 1];
const float confidence = data[i * OBJECT_SIZE + 2];
const float rc_left = data[i * OBJECT_SIZE + 3];
const float rc_top = data[i * OBJECT_SIZE + 4];
const float rc_right = data[i * OBJECT_SIZE + 5];
const float rc_bottom = data[i * OBJECT_SIZE + 6];
(void) label; // unused
if (image_id < 0.f) {
break; // marks end-of-detections
}
if (confidence < 0.5f) {
continue; // skip objects with low confidence
}
// map relative coordinates to the original image scale
cv::Rect rc;
rc.x = static_cast<int>(rc_left * in_parent_size.width);
rc.y = static_cast<int>(rc_top * in_parent_size.height);
rc.width = static_cast<int>(rc_right * in_parent_size.width) - rc.x;
rc.height = static_cast<int>(rc_bottom * in_parent_size.height) - rc.y;
out_objects.emplace_back(rc & surface);
}
}
};
GAPI_OCV_KERNEL(OCVBBoxes, BBoxes) {
// This kernel converts the rectangles into G-API's
// rendering primitives
static void run(const std::vector<cv::Rect> &in_obj_rcs,
std::vector<cv::gapi::wip::draw::Prim> &out_prims) {
out_prims.clear();
const auto cvt = [](const cv::Rect &rc, const cv::Scalar &clr) {
return cv::gapi::wip::draw::Rect(rc, clr, 2);
};
for (auto &&rc : in_obj_rcs) {
out_prims.emplace_back(cvt(rc, CV_RGB(0,255,0))); // green
}
std::cout << "Detections:";
for (auto &&rc : in_obj_rcs) std::cout << ' ' << rc;
std::cout << std::endl;
}
};
} // namespace custom
namespace {
void remap_ssd_ports(const std::unordered_map<std::string, cv::Mat> &onnx,
std::unordered_map<std::string, cv::Mat> &gapi) {
// Assemble ONNX-processed outputs back to a single 1x1x200x7 blob
// to preserve compatibility with OpenVINO-based SSD pipeline
const cv::Mat &num_detections = onnx.at("num_detections:0");
const cv::Mat &detection_boxes = onnx.at("detection_boxes:0");
const cv::Mat &detection_scores = onnx.at("detection_scores:0");
const cv::Mat &detection_classes = onnx.at("detection_classes:0");
GAPI_Assert(num_detections.depth() == CV_32F);
GAPI_Assert(detection_boxes.depth() == CV_32F);
GAPI_Assert(detection_scores.depth() == CV_32F);
GAPI_Assert(detection_classes.depth() == CV_32F);
cv::Mat &ssd_output = gapi.at("detection_output");
const int num_objects = static_cast<int>(num_detections.ptr<float>()[0]);
const float *in_boxes = detection_boxes.ptr<float>();
const float *in_scores = detection_scores.ptr<float>();
const float *in_classes = detection_classes.ptr<float>();
float *ptr = ssd_output.ptr<float>();
for (int i = 0; i < num_objects; i++) {
ptr[0] = 0.f; // "image_id"
ptr[1] = in_classes[i]; // "label"
ptr[2] = in_scores[i]; // "confidence"
ptr[3] = in_boxes[4*i + 1]; // left
ptr[4] = in_boxes[4*i + 0]; // top
ptr[5] = in_boxes[4*i + 3]; // right
ptr[6] = in_boxes[4*i + 2]; // bottom
ptr += 7;
in_boxes += 4;
}
if (num_objects < ssd_output.size[2]-1) {
// put a -1 mark at the end of output blob if there is space left
ptr[0] = -1.f;
}
}
} // anonymous namespace
const std::string keys =
"{ h help | | Print this help message }"
"{ input | | Path to the input video file }"
"{ output | | (Optional) path to output video file }"
"{ detm | | Path to an ONNX SSD object detection model (.onnx) }"
;
int main(int argc, char *argv[])
{
cv::CommandLineParser cmd(argc, argv, keys);
if (cmd.has("help")) {
cmd.printMessage();
return 0;
}
// Prepare parameters first
const std::string input = cmd.get<std::string>("input");
const std::string output = cmd.get<std::string>("output");
const auto obj_model_path = cmd.get<std::string>("detm");
auto obj_net = cv::gapi::onnx::Params<custom::ObjDetector>{obj_model_path}
.cfgOutputLayers({"detection_output"})
.cfgPostProc({cv::GMatDesc{CV_32F, {1,1,200,7}}}, remap_ssd_ports);
auto kernels = cv::gapi::kernels< custom::OCVGetSize
, custom::OCVParseSSD
, custom::OCVBBoxes>();
auto networks = cv::gapi::networks(obj_net);
// Now build the graph
cv::GMat in;
auto blob = cv::gapi::infer<custom::ObjDetector>(in);
auto rcs = custom::ParseSSD::on(blob, custom::GetSize::on(in));
auto out = cv::gapi::wip::draw::render3ch(in, custom::BBoxes::on(rcs));
cv::GStreamingCompiled pipeline = cv::GComputation(cv::GIn(in), cv::GOut(out))
.compileStreaming(cv::compile_args(kernels, networks));
auto inputs = cv::gin(cv::gapi::wip::make_src<cv::gapi::wip::GCaptureSource>(input));
// The execution part
pipeline.setSource(std::move(inputs));
pipeline.start();
cv::VideoWriter writer;
cv::Mat outMat;
while (pipeline.pull(cv::gout(outMat))) {
cv::imshow("Out", outMat);
cv::waitKey(1);
if (!output.empty()) {
if (!writer.isOpened()) {
const auto sz = cv::Size{outMat.cols, outMat.rows};
writer.open(output, cv::VideoWriter::fourcc('M','J','P','G'), 25.0, sz);
CV_Assert(writer.isOpened());
}
writer << outMat;
}
}
return 0;
}

View File

@ -0,0 +1,955 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2020 Intel Corporation
#include "precomp.hpp"
#include "backends/onnx/gonnxbackend.hpp"
#ifdef HAVE_ONNX
#include <ade/util/algorithm.hpp> // any_of
#include <ade/util/zip_range.hpp>
#include <opencv2/gapi/infer.hpp>
#include <opencv2/gapi/own/convert.hpp>
#include "api/gbackend_priv.hpp" // FIXME: Make it part of Backend SDK!
namespace cv {
namespace gimpl {
namespace onnx {
enum TensorPosition : int {
INPUT,
OUTPUT
};
struct TensorInfo {
TensorInfo() = default;
explicit TensorInfo(const Ort::TensorTypeAndShapeInfo& info)
: dims(info.GetShape())
, type(info.GetElementType())
, is_dynamic(std::find(dims.begin(), dims.end(), -1) != dims.end()) {
if (!is_dynamic) {
size = std::accumulate(dims.begin(),
dims.end(),
static_cast<int64_t>(1),
std::multiplies<int64_t>());
}
// Heuristic: check if the tensor is grayscale input
if (dims.size() == 4u
&& dims[0] == 1
&& dims[1] == 1
&& dims[2] > 1
&& dims[3] > 1) {
is_grayscale = true;
}
}
std::string name;
std::vector<int64_t> dims;
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
int64_t size = -1;
bool normalize = true;
bool is_dynamic = false;
bool is_grayscale = false;
struct MeanStdev {
cv::Scalar mean;
cv::Scalar stdev;
};
cv::util::optional<MeanStdev> mstd;
};
class ONNXCompiled {
// ONNX Resources
// NOTE: Env must live with the session, otherwise segfaults.
Ort::Env this_env{nullptr};
Ort::Session this_session{nullptr};
Ort::MemoryInfo this_memory_info{nullptr};
std::vector<TensorInfo> in_tensor_info;
std::vector<TensorInfo> out_tensor_info;
bool is_dynamic = false;
// G-API <Net> description
gapi::onnx::detail::ParamDesc params;
// Input/output tensor information
std::vector<TensorInfo> getTensorInfo(TensorPosition pos);
// Run-time data structures
std::vector<cv::Mat> in_data;
std::vector<cv::Mat> out_data;
void Run(const std::vector<cv::Mat>& ins,
const std::vector<cv::Mat>& outs);
public:
explicit ONNXCompiled(const gapi::onnx::detail::ParamDesc &pp);
// Extract the information about output layer #i
cv::GMatDesc outMeta(int i) const;
// Assign input/output info
std::size_t numInputs() const { return params.num_in; }
std::size_t numOutputs() const { return params.num_out; }
void setInput(int i, const cv::Mat &m);
void setOutput(int i, cv::Mat &m);
cv::Mat allocOutput(int i) const;
// Run with the assigned inputs/outputs
void run();
};
} // namespace onnx
} // namespace gimpl
} // namespace cv
namespace {
inline std::vector<const char*> getCharNames(const std::vector<std::string>& names) {
std::vector<const char*> out_vec;
for (const auto& el : names) {
out_vec.push_back(el.data());
}
return out_vec;
}
inline int getIdxByName(const std::vector<cv::gimpl::onnx::TensorInfo>& info, const std::string& name) {
// FIXME: Cache the ordering
const auto it = std::find_if(info.begin(), info.end(), [&](const cv::gimpl::onnx::TensorInfo &i) {
return i.name == name;
});
GAPI_Assert(it != info.end());
return std::distance(info.begin(), it);
}
inline int toCV(ONNXTensorElementDataType prec) {
switch (prec) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return CV_8U;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return CV_32F;
default: GAPI_Assert(false && "Unsupported data type");
}
return -1;
}
inline std::vector<int> toCV(const std::vector<int64_t> &vsz) {
std::vector<int> result;
result.reserve(vsz.size());
for (auto sz : vsz) {
result.push_back(ade::util::checked_cast<int>(sz));
}
return result;
}
inline cv::Mat toCV(Ort::Value &v) {
auto info = v.GetTensorTypeAndShapeInfo();
return cv::Mat(toCV(info.GetShape()),
toCV(info.GetElementType()),
reinterpret_cast<void*>(v.GetTensorMutableData<uint8_t*>()));
}
inline std::vector<int64_t> toORT(const cv::MatSize &sz) {
return cv::to_own<int64_t>(sz);
}
inline void preprocess(const cv::Mat& src,
const cv::gimpl::onnx::TensorInfo& ti,
cv::Mat& dst) {
GAPI_Assert(src.depth() == CV_32F || src.depth() == CV_8U);
if (src.depth() == CV_32F) {
// Just pass the tensor as-is.
// No layout or dimension transformations done here!
// TODO: This needs to be aligned across all NN backends.
GAPI_Assert(toCV(ti.type) == CV_32F && "Only 32F model input is supported for 32F data");
GAPI_Assert(toORT(src.size) == ti.dims && "32F tensor dimensions should match with NN input");
GAPI_Assert(!ti.is_dynamic && "Dynamic inputs are not supported for this case");
dst = src;
} else {
// 8U input: full preprocessing path
GAPI_Assert(src.depth() == CV_8U && "Only 8U data type is supported for preproc");
GAPI_Assert(ti.dims.size() == 4u && "Only NCHW/NHWC layouts are supported for preproc");
const auto ddepth = toCV(ti.type);
GAPI_Assert((ddepth == CV_8U || ddepth == CV_32F)
&& "Only 8U and 32F model input is supported for 8U data");
// Assess the expected input layout
const bool is_hwc = [&](int ch) {
if (ti.is_grayscale) return false; // 1,1,h,w
else if (ti.dims[3] == ch) return true; // _,_,_,c
else if (ti.dims[1] == ch) return false; // _,c,_,_
else cv::util::throw_error(std::logic_error("Couldn't identify input tensor layout"));
} (src.channels());
int new_c = src.channels();
cv::Mat csc;
if (ti.is_grayscale && new_c == 3) {
cv::cvtColor(src, csc, cv::COLOR_BGR2GRAY);
new_c = 1;
} else {
csc = src;
}
// NHWC vs NCHW
int new_h = -1, new_w = -1;
if (ti.is_dynamic) {
// reuse h & w from the input image
new_h = src.rows;
new_w = src.cols;
} else {
// take h & w from the ONNX tensor info
new_h = ti.dims[is_hwc ? 1 : 2];
new_w = ti.dims[is_hwc ? 2 : 3];
}
GAPI_Assert(new_h != -1 && new_w != -1);
cv::Mat rsz, pp;
cv::resize(csc, rsz, cv::Size(new_w, new_h));
if (src.depth() == CV_8U && ddepth == CV_32F) {
rsz.convertTo(pp, ddepth, ti.normalize ? 1.f / 255 : 1.f);
if (ti.mstd.has_value()) {
pp -= ti.mstd->mean;
pp /= ti.mstd->stdev;
}
} else {
pp = rsz;
}
if (!is_hwc && new_c > 1) {
// Convert to CHW
dst.create(cv::Size(new_w, new_h * new_c), ddepth);
std::vector<cv::Mat> planes(new_c);
for (int ch = 0; ch < new_c; ++ch) {
planes[ch] = dst.rowRange(ch * new_h, (ch + 1) * new_h);
}
cv::split(pp, planes);
} else {
// Keep HWC
dst = pp;
}
// Ensure dst is a tensor shape (not a 2D image)
if (ti.is_dynamic) {
// Reshape to input dimensions
const std::vector<int> out_dims = is_hwc
? std::vector<int>{1, new_h, new_w, new_c}
: std::vector<int>{1, new_c, new_h, new_w};
dst = dst.reshape(1, out_dims);
} else {
// Reshape to ONNX dimensions (no -1s there!)
dst = dst.reshape(1, toCV(ti.dims));
}
}
}
template <typename T>
inline Ort::Value createTensor(const Ort::MemoryInfo& memory_info,
const cv::gimpl::onnx::TensorInfo& tensor_params,
const cv::Mat& data) {
(void) tensor_params;
auto ort_dims = toORT(data.size);
return Ort::Value::CreateTensor<T>(memory_info,
const_cast<T*>(data.ptr<T>()),
data.total(),
ort_dims.data(),
ort_dims.size());
}
inline Ort::Value createTensor(const Ort::MemoryInfo& memory_info,
const cv::gimpl::onnx::TensorInfo& tensor_params,
const cv::Mat& data) {
GAPI_Assert(data.isContinuous ());
switch (tensor_params.type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return createTensor<uint8_t>(memory_info, tensor_params, data);
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return createTensor<float>(memory_info, tensor_params, data);
default:
GAPI_Assert(false && "Unsupported data type");
}
return Ort::Value{nullptr};
}
struct ONNXUnit {
static const char *name() { return "ONNXModelConfig"; }
std::shared_ptr<cv::gimpl::onnx::ONNXCompiled> oc;
explicit ONNXUnit(const cv::gapi::onnx::detail::ParamDesc &pp)
: oc(new cv::gimpl::onnx::ONNXCompiled(pp)) {
}
};
struct ONNXCallContext {
// Input parameters passed to an inference operation.
std::vector<cv::GArg> args;
//FIXME: avoid conversion of arguments from internal representation to OpenCV one on each call
//to OCV kernel. (This can be achieved by a two single time conversions in GCPUExecutable::run,
//once on enter for input and output arguments, and once before return for output arguments only
//FIXME: check if the above applies to this backend (taken from CPU)
std::unordered_map<std::size_t, cv::GRunArgP> results;
// Generic accessor API
template<typename T>
const T& inArg(std::size_t input) { return args.at(input).get<T>(); }
// Syntax sugar
const cv::Mat& inMat(std::size_t input) {
return inArg<cv::Mat>(input);
}
cv::Mat& outMatR(std::size_t output) {
return *cv::util::get<cv::Mat*>(results.at(output));
}
template<typename T> std::vector<T>& outVecR(std::size_t output) { // FIXME: the same issue
return outVecRef(output).wref<T>();
}
cv::detail::VectorRef& outVecRef(std::size_t output) {
return cv::util::get<cv::detail::VectorRef>(results.at(output));
}
};
struct ONNXCallable {
static const char *name() { return "ONNXRequestCallable"; }
using Run = std::function<void(const ONNXUnit &, ONNXCallContext &)>;
Run run;
};
struct KImpl {
cv::gimpl::CustomMetaFunction::CM customMetaFunc;
ONNXCallable::Run run;
};
// FIXME: Is there a way to take a typed graph (our GModel),
// and create a new typed graph _ATOP_ of that (by extending with a couple of
// new types?).
// Alternatively, is there a way to compose types graphs?
//
// If not, we need to introduce that!
using GONNXModel = ade::TypedGraph
< cv::gimpl::Protocol
, cv::gimpl::Op
, cv::gimpl::NetworkParams
, cv::gimpl::CustomMetaFunction
, ONNXUnit
, ONNXCallable
>;
// FIXME: Same issue with Typed and ConstTyped
using GConstGONNXModel = ade::ConstTypedGraph
< cv::gimpl::Protocol
, cv::gimpl::Op
, cv::gimpl::NetworkParams
, cv::gimpl::CustomMetaFunction
, ONNXUnit
, ONNXCallable
>;
} // anonymous namespace
// GCPUExcecutable implementation //////////////////////////////////////////////
cv::gimpl::onnx::GONNXExecutable::GONNXExecutable(const ade::Graph &g,
const std::vector<ade::NodeHandle> &nodes)
: m_g(g), m_gm(m_g) {
// FIXME: Currently this backend is capable to run a single inference node only.
// Need to extend our island fusion with merge/not-to-merge decision making parametrization
GConstGONNXModel iem(g);
for (auto &nh : nodes) {
switch (m_gm.metadata(nh).get<NodeType>().t) {
case NodeType::OP:
if (this_nh == nullptr) {
this_nh = nh;
}
else {
util::throw_error(std::logic_error("Multi-node inference is not supported!"));
}
break;
case NodeType::DATA: {
m_dataNodes.push_back(nh);
const auto &desc = m_gm.metadata(nh).get<Data>();
if (desc.storage == Data::Storage::CONST_VAL) {
util::throw_error(std::logic_error("No const data supported in backend!"));
}
if (desc.storage == Data::Storage::INTERNAL) {
util::throw_error(std::logic_error("No internal data supported in backend!"));
}
break;
}
default: util::throw_error(std::logic_error("Unsupported NodeType"));
}
}
}
// FIXME: Document what it does
cv::GArg cv::gimpl::onnx::GONNXExecutable::packArg(const cv::GArg &arg) {
// No API placeholders allowed at this point
// FIXME: this check has to be done somewhere in compilation stage.
GAPI_Assert( arg.kind != cv::detail::ArgKind::GMAT
&& arg.kind != cv::detail::ArgKind::GSCALAR
&& arg.kind != cv::detail::ArgKind::GARRAY
&& arg.kind != cv::detail::ArgKind::GOPAQUE);
if (arg.kind != cv::detail::ArgKind::GOBJREF) {
util::throw_error(std::logic_error("Inference supports G-types ONLY!"));
}
GAPI_Assert(arg.kind == cv::detail::ArgKind::GOBJREF);
// Wrap associated CPU object (either host or an internal one)
// FIXME: object can be moved out!!! GExecutor faced that.
const cv::gimpl::RcDesc &ref = arg.get<cv::gimpl::RcDesc>();
switch (ref.shape)
{
case GShape::GMAT: return GArg(m_res.slot<cv::Mat>()[ref.id]);
// Note: .at() is intentional for GArray as object MUST be already there
// (and constructed by either bindIn/Out or resetInternal)
case GShape::GARRAY: return GArg(m_res.slot<cv::detail::VectorRef>().at(ref.id));
// Note: .at() is intentional for GOpaque as object MUST be already there
// (and constructed by either bindIn/Out or resetInternal)
case GShape::GOPAQUE: return GArg(m_res.slot<cv::detail::OpaqueRef>().at(ref.id));
default:
util::throw_error(std::logic_error("Unsupported GShape type"));
break;
}
}
void cv::gimpl::onnx::GONNXExecutable::run(std::vector<InObj> &&input_objs,
std::vector<OutObj> &&output_objs) {
// Update resources with run-time information - what this Island
// has received from user (or from another Island, or mix...)
// FIXME: Check input/output objects against GIsland protocol
for (auto& it : input_objs) magazine::bindInArg (m_res, it.first, it.second);
for (auto& it : output_objs) magazine::bindOutArg(m_res, it.first, it.second);
// FIXME: Running just a single node now.
// Not sure if need to support many of them, though
// FIXME: Make this island-unmergeable?
const auto &op = m_gm.metadata(this_nh).get<Op>();
// Initialize kernel's execution context:
// - Input parameters
ONNXCallContext context;
context.args.reserve(op.args.size());
using namespace std::placeholders;
ade::util::transform(op.args,
std::back_inserter(context.args),
std::bind(&GONNXExecutable::packArg, this, _1));
// - Output parameters.
for (const auto &out_it : ade::util::indexed(op.outs)) {
// FIXME: Can the same GArg type resolution mechanism be reused here?
const auto out_port = ade::util::index(out_it);
const auto out_desc = ade::util::value(out_it);
context.results[out_port] = magazine::getObjPtr(m_res, out_desc);
}
// And now trigger the execution
GConstGONNXModel giem(m_g);
const auto &uu = giem.metadata(this_nh).get<ONNXUnit>();
const auto &kk = giem.metadata(this_nh).get<ONNXCallable>();
kk.run(uu, context);
for (auto &it : output_objs) magazine::writeBack(m_res, it.first, it.second);
}
namespace cv {
namespace gimpl {
namespace onnx {
ONNXCompiled::ONNXCompiled(const gapi::onnx::detail::ParamDesc &pp)
: params(pp) {
// Validate input parameters before allocating any resources
if (params.num_in > 1u && params.num_in != params.input_names.size()) {
cv::util::throw_error(std::logic_error("Please specify input layer names for "
+ params.model_path));
}
if (params.num_out > 1u && params.num_out != params.output_names.size()) {
cv::util::throw_error(std::logic_error("Please specify output layer names for "
+ params.model_path));
}
// Create and initialize the ONNX session
Ort::SessionOptions session_options;
this_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
this_session = Ort::Session(this_env, params.model_path.data(), session_options);
this_memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
in_tensor_info = getTensorInfo(INPUT);
out_tensor_info = getTensorInfo(OUTPUT);
const auto is_dyn = [](const TensorInfo &ti) {
return ti.is_dynamic;
};
is_dynamic = ade::util::any_of(in_tensor_info, is_dyn)
|| ade::util::any_of(out_tensor_info, is_dyn);
if (is_dynamic && !params.custom_post_proc) {
util::throw_error(std::logic_error("This network has dynamic shapes. "
"Please provide a custom post-processing function "
"(.cfgPostProc) in network parameters"));
}
// Update parameters based on session information
if (params.num_in == 1u && params.input_names.empty()) {
params.input_names = { in_tensor_info.front().name };
}
if (params.num_out == 1u && params.output_names.empty()) {
params.output_names = { out_tensor_info.front().name };
}
// Validate what is supported currently
GAPI_Assert(params.const_inputs.empty()
&& "Const inputs are not currently supported");
GAPI_Assert(std::all_of(in_tensor_info.begin(),
in_tensor_info.end(),
[](const cv::gimpl::onnx::TensorInfo &p) {
return p.type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
|| p.type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
})
&& "Only FP32 and U8 inputs for NN are supported");
// Put mean and std in appropriate tensor params
if (!params.mean.empty() || !params.stdev.empty()) {
GAPI_Assert(params.mean.size() == params.stdev.size() &&
params.mean.size() == params.input_names.size());
for (auto idx : ade::util::iota(params.num_in)) {
const auto ort_idx = getIdxByName(in_tensor_info, params.input_names[idx]);
using M = TensorInfo::MeanStdev;
in_tensor_info[ort_idx].mstd = util::make_optional(M{ params.mean[idx]
, params.stdev[idx] });
}
}
// Update normalize flags for input tensors
if (!params.normalize.empty()) {
for (auto idx : ade::util::iota(params.num_in)) {
const auto ort_idx = getIdxByName(in_tensor_info, params.input_names[idx]);
in_tensor_info[ort_idx].normalize = params.normalize[idx];
}
}
// Pre-allocate vectors (not buffers) for runtime info
in_data.resize(params.num_in);
out_data.resize(params.num_out);
}
std::vector<TensorInfo> ONNXCompiled::getTensorInfo(TensorPosition pos) {
GAPI_Assert(pos == INPUT || pos == OUTPUT);
const auto num_nodes = pos == INPUT
? this_session.GetInputCount()
: this_session.GetOutputCount();
std::vector<TensorInfo> tensor_info;
tensor_info.reserve(num_nodes);
Ort::AllocatorWithDefaultOptions allocator;
for (auto i : ade::util::iota(num_nodes)) {
const auto info = pos == INPUT
? this_session.GetInputTypeInfo(i)
: this_session.GetOutputTypeInfo(i);
tensor_info.emplace_back(info.GetTensorTypeAndShapeInfo());
char *name_p = pos == INPUT
? this_session.GetInputName(i, allocator)
: this_session.GetOutputName(i, allocator);
tensor_info.back().name = name_p;
allocator.Free(name_p);
}
return tensor_info;
}
cv::GMatDesc ONNXCompiled::outMeta(int idx) const {
if (is_dynamic) {
GAPI_Assert(!params.out_metas.empty()
&& "Metadata must be specified if NN has dynamic inputs!");
return params.out_metas.at(idx);
}
const auto ort_idx = getIdxByName(out_tensor_info, params.output_names[idx]);
return cv::GMatDesc(toCV(out_tensor_info[ort_idx].type),
toCV(out_tensor_info[ort_idx].dims));
}
void ONNXCompiled::setInput(int i, const cv::Mat &m) {
const auto in_idx = i;
const auto in_name = params.input_names[in_idx];
const auto ort_idx = getIdxByName(in_tensor_info, in_name);
preprocess(m, in_tensor_info[ort_idx], in_data[in_idx]);
}
void ONNXCompiled::setOutput(int i, cv::Mat &m) {
// FIXME: No need in double-indexing?
out_data[i] = m;
}
cv::Mat ONNXCompiled::allocOutput(int i) const {
cv::Mat m;
m.create(toCV(out_tensor_info[i].dims),
toCV(out_tensor_info[i].type));
return m;
}
void ONNXCompiled::Run(const std::vector<cv::Mat>& ins,
const std::vector<cv::Mat>& outs) {
std::vector<Ort::Value> in_tensors, out_tensors;
auto in_run_names = getCharNames(params.input_names);
for (const auto it : ade::util::indexed(params.input_names)) {
auto i = ade::util::index(it);
auto in_name = ade::util::value(it);
const auto idx = getIdxByName(in_tensor_info, in_name);
in_tensors.emplace_back(createTensor(this_memory_info,
in_tensor_info[idx],
ins[i]));
}
if (!is_dynamic) {
// Easy path - just run the session which is bound to G-API's
// internal data
for (auto i : ade::util::iota(params.output_names.size())) {
out_tensors.emplace_back(createTensor(this_memory_info,
out_tensor_info[i],
outs[i]));
}
auto out_run_names = getCharNames(params.output_names);
this_session.Run(Ort::RunOptions{nullptr},
in_run_names.data(),
&in_tensors.front(),
params.input_names.size(),
out_run_names.data(),
&out_tensors.front(),
params.output_names.size());
} else {
// Hard path - run session & user-defined post-processing
// NOTE: use another list of output names here
std::vector<const char*> out_names;
for (auto &&ti : out_tensor_info) {
out_names.push_back(ti.name.c_str());
}
auto outputs = this_session.Run(Ort::RunOptions{nullptr},
in_run_names.data(),
&in_tensors.front(),
params.input_names.size(),
out_names.data(),
out_names.size());
std::unordered_map<std::string, cv::Mat> onnx_outputs;
std::unordered_map<std::string, cv::Mat> gapi_outputs;
GAPI_Assert(outputs.size() == out_names.size());
// Fill in ONNX tensors
for (auto &&iter : ade::util::zip(ade::util::toRange(out_tensor_info),
ade::util::toRange(outputs))) {
const auto &out_name = std::get<0>(iter).name;
auto &out_tensor = std::get<1>(iter);
onnx_outputs[out_name] = toCV(out_tensor);
}
// Fill in G-API outputs
for (auto &&it: ade::util::indexed(params.output_names)) {
gapi_outputs[ade::util::value(it)] = outs[ade::util::index(it)];
}
params.custom_post_proc(onnx_outputs, gapi_outputs);
}
}
void ONNXCompiled::run() {
Run(in_data, out_data);
}
struct Infer: public cv::detail::KernelTag {
using API = cv::GInferBase;
static cv::gapi::GBackend backend() { return cv::gapi::onnx::backend(); }
static KImpl kernel() { return KImpl{outMeta, run}; }
static cv::GMetaArgs outMeta(const ade::Graph &gr,
const ade::NodeHandle &nh,
const cv::GMetaArgs &in_metas,
const cv::GArgs &/*in_args*/) {
cv::GMetaArgs result;
GConstGONNXModel gm(gr);
const auto &uu = gm.metadata(nh).get<ONNXUnit>();
GAPI_Assert(uu.oc->numInputs() == in_metas.size()
&& "Known input layers count doesn't match input meta count");
for (auto &&mm : in_metas) {
GAPI_Assert(util::holds_alternative<cv::GMatDesc>(mm)
&& "Non-GMat inputs are not supported");
}
for (auto &&idx : ade::util::iota(uu.oc->numOutputs())) {
result.emplace_back(uu.oc->outMeta(idx));
}
return result;
}
static void run(const ONNXUnit &uu, ONNXCallContext &ctx) {
for (auto &&idx : ade::util::iota(uu.oc->numInputs())) {
uu.oc->setInput(idx, ctx.inMat(idx));
}
for (auto &&idx : ade::util::iota(uu.oc->numOutputs())) {
uu.oc->setOutput(idx, ctx.outMatR(idx));
}
uu.oc->run();
}
};
struct InferROI: public cv::detail::KernelTag {
using API = cv::GInferROIBase;
static cv::gapi::GBackend backend() { return cv::gapi::onnx::backend(); }
static KImpl kernel() { return KImpl{outMeta, run}; }
static cv::GMetaArgs outMeta(const ade::Graph &gr,
const ade::NodeHandle &nh,
const cv::GMetaArgs &in_metas,
const cv::GArgs &/*in_args*/) {
cv::GMetaArgs result;
GConstGONNXModel gm(gr);
const auto &uu = gm.metadata(nh).get<ONNXUnit>();
GAPI_Assert(1u == uu.oc->numInputs());
GAPI_Assert(2u == in_metas.size());
for (auto &&idx : ade::util::iota(uu.oc->numOutputs())) {
result.emplace_back(uu.oc->outMeta(idx));
}
return result;
}
static void run(const ONNXUnit &uu, ONNXCallContext &ctx) {
// non-generic version for now, per the InferROI's definition
GAPI_Assert(uu.oc->numInputs() == 1u);
const auto& this_roi = ctx.inArg<cv::detail::OpaqueRef>(0).rref<cv::Rect>();
const auto this_mat = ctx.inMat(1);
uu.oc->setInput(0, this_mat(this_roi));
for (auto &&idx : ade::util::iota(uu.oc->numOutputs())) {
uu.oc->setOutput(idx, ctx.outMatR(idx));
}
uu.oc->run();
}
};
struct InferList: public cv::detail::KernelTag {
using API = cv::GInferListBase;
static cv::gapi::GBackend backend() { return cv::gapi::onnx::backend(); }
static KImpl kernel() { return KImpl{outMeta, run}; }
static cv::GMetaArgs outMeta(const ade::Graph &gr,
const ade::NodeHandle &nh,
const cv::GMetaArgs &in_metas,
const cv::GArgs &/*in_args*/) {
GConstGONNXModel gm(gr);
const auto &uu = gm.metadata(nh).get<ONNXUnit>();
// Note our input layers list order matches the API order and so
// meta order.
GAPI_Assert(uu.oc->numInputs() == (in_metas.size() - 1u)
&& "Known input layers count doesn't match input meta count");
for (auto i : ade::util::iota(uu.oc->numInputs())) {
const auto & mm = in_metas[i + 1];
GAPI_Assert(util::holds_alternative<cv::GMatDesc>(mm)
&& "Non-GMat inputs are not supported");
}
// roi-list version is much easier at the moment.
// All our outputs are vectors which don't have
// metadata at the moment - so just create a vector of
// "empty" array metadatas of the required size.
return cv::GMetaArgs(uu.oc->numOutputs(),
cv::GMetaArg{cv::empty_array_desc()});
}
static void run(const ONNXUnit &uu, ONNXCallContext &ctx) {
// non-generic version for now:
// - assumes input 0 is always ROI list
// - assumes all inputs/outputs are always Mats
GAPI_Assert(uu.oc->numInputs() == 1); // roi list is not counted in net's inputs
const auto& in_roi_vec = ctx.inArg<cv::detail::VectorRef>(0u).rref<cv::Rect>();
const cv::Mat this_mat = ctx.inMat(1u);
for (auto i : ade::util::iota(uu.oc->numOutputs())) {
ctx.outVecR<cv::Mat>(i).clear();
}
for (const auto &rc : in_roi_vec) {
uu.oc->setInput(0, this_mat(rc));
std::vector<cv::Mat> out_mats(uu.oc->numOutputs());
for (auto i : ade::util::iota(uu.oc->numOutputs())) {
out_mats[i] = uu.oc->allocOutput(i);
uu.oc->setOutput(i, out_mats[i]);
}
uu.oc->run();
for (auto i : ade::util::iota(uu.oc->numOutputs())) {
std::vector<cv::Mat> &out_vec = ctx.outVecR<cv::Mat>(i);
out_vec.push_back(std::move(out_mats[i]));
}
}
}
};
struct InferList2: public cv::detail::KernelTag {
using API = cv::GInferList2Base;
static cv::gapi::GBackend backend() { return cv::gapi::onnx::backend(); }
static KImpl kernel() { return KImpl{outMeta, run}; }
static cv::GMetaArgs outMeta(const ade::Graph &gr,
const ade::NodeHandle &nh,
const cv::GMetaArgs &in_metas,
const cv::GArgs &/*in_args*/) {
GConstGONNXModel gm(gr);
const auto &uu = gm.metadata(nh).get<ONNXUnit>();
// Note our input layers list order matches the API order and so
// meta order.
GAPI_Assert(uu.oc->numInputs() == (in_metas.size() - 1u)
&& "Known input layers count doesn't match input meta count");
// In contrast to InferList, the InferList2 has only one
// "full-frame" image argument, and all the rest are arrays of
// ether ROI or blobs. So here we set the 0th arg image format
// to all inputs which are ROI-based (skipping the
// "blob"-based ones)
// FIXME: this is filtering not done, actually! GArrayDesc has
// no hint for type!
const auto &mm_0 = in_metas[0u];
const auto &meta_0 = util::get<cv::GMatDesc>(mm_0);
GAPI_Assert( !meta_0.isND()
&& !meta_0.planar
&& "Only images are supported as the 0th argument");
for (auto i : ade::util::iota(uu.oc->numInputs())) {
const auto &mm = in_metas[i + 1];
GAPI_Assert(util::holds_alternative<cv::GArrayDesc>(mm)
&& "Non-array inputs are not supported");
}
// roi-list version is much easier at the moment.
// All our outputs are vectors which don't have
// metadata at the moment - so just create a vector of
// "empty" array metadatas of the required size.
return cv::GMetaArgs(uu.oc->numOutputs(),
cv::GMetaArg{cv::empty_array_desc()});
}
static void run(const ONNXUnit &uu, ONNXCallContext &ctx) {
GAPI_Assert(ctx.args.size() > 1u
&& "This operation must have at least two arguments");
// Since we do a ROI list inference, always assume our input buffer is image
const cv::Mat mat_0 = ctx.inMat(0u);
// Take the next argument, which must be vector (of any kind).
// Use this only to obtain the ROI list size (sizes of all
// other vectors must be equal to this one)
const auto list_size = ctx.inArg<cv::detail::VectorRef>(1u).size();
for (auto i : ade::util::iota(uu.oc->numOutputs())) {
ctx.outVecR<cv::Mat>(i).clear();
}
// For every ROI in the list {{{
for (const auto &list_idx : ade::util::iota(list_size)) {
std::vector<Ort::Value> in_tensors, out_tensors;
std::vector<cv::Mat> in_mats(uu.oc->numInputs());
// For every input of the net {{{
for (auto in_idx : ade::util::iota(uu.oc->numInputs())) {
const auto &this_vec = ctx.inArg<cv::detail::VectorRef>(in_idx+1u);
GAPI_Assert(this_vec.size() == list_size);
// Prepare input {{{
// FIXME: Terrible run-time logic based on RTTI!
// FIXME: Will never work on non-RTTI systems!
// FIXME: Need to replace with a static type tags
// (like with serialization) instead!
if (this_vec.holds<cv::Rect>()) {
// ROI case - create an ROI blob
const auto &vec = this_vec.rref<cv::Rect>();
uu.oc->setInput(in_idx, mat_0(vec[list_idx]));
} else if (this_vec.holds<cv::Mat>()) {
// Mat case - create a regular blob
// FIXME: NOW Assume Mats are always BLOBS (not
// images)
const auto &vec = this_vec.rref<cv::Mat>();
uu.oc->setInput(in_idx, vec[list_idx]);
} else {
GAPI_Assert(false && "Only Rect and Mat types are supported for infer list 2!");
}
// }}} (Preapre input)
} // }}} (For every input of the net)
std::vector<cv::Mat> out_mats(uu.oc->numOutputs());
for (auto i : ade::util::iota(uu.oc->numOutputs())) {
out_mats[i] = uu.oc->allocOutput(i);
uu.oc->setOutput(i, out_mats[i]);
}
uu.oc->run();
for (auto i : ade::util::iota(uu.oc->numOutputs())) {
std::vector<cv::Mat> &out_vec = ctx.outVecR<cv::Mat>(i);
out_vec.push_back(std::move(out_mats[i]));
}
} // }}} (For every ROI in the list)
}
};
} // namespace onnx
} // namespace gapi
} // namespace cv
namespace {
class GONNXBackendImpl final: public cv::gapi::GBackend::Priv {
virtual void unpackKernel(ade::Graph &gr,
const ade::NodeHandle &nh,
const cv::GKernelImpl &ii) override {
using namespace cv::gimpl;
// FIXME: Introduce a DNNBackend interface which'd specify
// the framework for this???
GONNXModel gm(gr);
const auto &np = gm.metadata(nh).get<NetworkParams>();
const auto &pp = cv::util::any_cast<cv::gapi::onnx::detail::ParamDesc>(np.opaque);
const auto &ki = cv::util::any_cast<KImpl>(ii.opaque);
gm.metadata(nh).set(ONNXUnit{pp});
gm.metadata(nh).set(ONNXCallable{ki.run});
gm.metadata(nh).set(CustomMetaFunction{ki.customMetaFunc});
}
virtual EPtr compile(const ade::Graph &graph,
const cv::GCompileArgs &,
const std::vector<ade::NodeHandle> &nodes) const override {
return EPtr{new cv::gimpl::onnx::GONNXExecutable(graph, nodes)};
}
virtual cv::gapi::GKernelPackage auxiliaryKernels() const override {
return cv::gapi::kernels< cv::gimpl::onnx::Infer
, cv::gimpl::onnx::InferROI
, cv::gimpl::onnx::InferList
, cv::gimpl::onnx::InferList2
>();
}
};
}
cv::gapi::GBackend cv::gapi::onnx::backend() {
static cv::gapi::GBackend this_backend(std::make_shared<GONNXBackendImpl>());
return this_backend;
}
#else // HAVE_ONNX
cv::gapi::GBackend cv::gapi::onnx::backend() {
// Still provide this symbol to avoid linking issues
util::throw_error(std::runtime_error("G-API has been compiled without ONNX support"));
}
#endif // HAVE_ONNX

View File

@ -0,0 +1,56 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2020 Intel Corporation
#ifndef OPENCV_GAPI_GONNXBACKEND_HPP
#define OPENCV_GAPI_GONNXBACKEND_HPP
#include "opencv2/gapi/infer/onnx.hpp"
#ifdef HAVE_ONNX
#include <onnxruntime_cxx_api.h>
#include <ade/util/algorithm.hpp> // type_list_index
#include "backends/common/gbackend.hpp"
namespace cv {
namespace gimpl {
namespace onnx {
class GONNXExecutable final: public GIslandExecutable
{
const ade::Graph &m_g;
GModel::ConstGraph m_gm;
// The only executable stuff in this graph
// (assuming it is always single-op)
ade::NodeHandle this_nh;
// List of all resources in graph (both internal and external)
std::vector<ade::NodeHandle> m_dataNodes;
// Actual data of all resources in graph (both internal and external)
Mag m_res;
// Execution helpers
GArg packArg(const GArg &arg);
public:
GONNXExecutable(const ade::Graph &graph,
const std::vector<ade::NodeHandle> &nodes);
virtual inline bool canReshape() const override { return false; }
virtual inline void reshape(ade::Graph&, const GCompileArgs&) override {
GAPI_Assert(false); // Not implemented yet
}
virtual void run(std::vector<InObj> &&input_objs,
std::vector<OutObj> &&output_objs) override;
};
}}} // namespace cv::gimpl::onnx
#endif // HAVE_ONNX
#endif // OPENCV_GAPI_GONNXBACKEND_HPP

View File

@ -141,6 +141,7 @@ void cv::gimpl::passes::bindNetParams(ade::passes::PassContext &ctx,
continue;
pgr.metadata(nh).set(NetworkParams{it->params});
op.backend = it->backend;
}
}
}
@ -181,13 +182,25 @@ void cv::gimpl::passes::resolveKernels(ade::passes::PassContext &ctx,
// of the same kernel to be presented in the kernel
// package (as it was designed originally).
cv::gapi::GBackend selected_backend;
cv::GKernelImpl selected_impl;
std::tie(selected_backend, selected_impl) = kernels.lookup(op.k.name);
cv::GKernelImpl selected_impl;
selected_backend.priv().unpackKernel(ctx.graph, nh, selected_impl);
op.backend = selected_backend;
active_backends.insert(selected_backend);
if (op.backend == cv::gapi::GBackend()) {
std::tie(op.backend, selected_impl) = kernels.lookup(op.k.name);
} else {
// FIXME: This needs to be reworked properly
// Lookup for implementation from the pre-assinged backend
cv::gapi::GBackend dummy;
std::tie(dummy, selected_impl) = op.backend.priv()
.auxiliaryKernels().lookup(op.k.name);
// FIXME: Warning here!
// This situation may happen when NN (infer) backend was assigned
// by tag in bindNetParams (see above) but at this stage the operation
// lookup resulted in another backend (and it is perfectly valid when
// we have multiple NN backends available).
}
op.backend.priv().unpackKernel(ctx.graph, nh, selected_impl);
active_backends.insert(op.backend);
if (gr.metadata().contains<Deserialized>())
{

View File

@ -0,0 +1,278 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2020 Intel Corporation
#include "../test_precomp.hpp"
#ifdef HAVE_ONNX
#include <stdexcept>
#include <onnxruntime_cxx_api.h>
#include <ade/util/iota_range.hpp>
#include <opencv2/gapi/infer/onnx.hpp>
namespace {
struct ONNXInitPath {
ONNXInitPath() {
const char* env_path = getenv("OPENCV_GAPI_ONNX_MODEL_PATH");
if (env_path)
cvtest::addDataSearchPath(env_path);
}
};
static ONNXInitPath g_init_path;
cv::Mat initMatrixRandU(int type, cv::Size sz_in)
{
cv::Mat in_mat1 = cv::Mat(sz_in, type);
if (CV_MAT_DEPTH(type) < CV_32F)
{
cv::randu(in_mat1, cv::Scalar::all(0), cv::Scalar::all(255));
}
else
{
const int fscale = 256; // avoid bits near ULP, generate stable test input
cv::Mat in_mat32s(in_mat1.size(), CV_MAKE_TYPE(CV_32S, CV_MAT_CN(type)));
cv::randu(in_mat32s, cv::Scalar::all(0), cv::Scalar::all(255 * fscale));
in_mat32s.convertTo(in_mat1, type, 1.0f / fscale, 0);
}
return in_mat1;
}
}
namespace opencv_test
{
namespace {
// FIXME: taken from the DNN module
void normAssert(cv::InputArray ref, cv::InputArray test,
const char *comment /*= ""*/,
double l1 = 0.00001, double lInf = 0.0001)
{
double normL1 = cvtest::norm(ref, test, cv::NORM_L1) / ref.getMat().total();
EXPECT_LE(normL1, l1) << comment;
double normInf = cvtest::norm(ref, test, cv::NORM_INF);
EXPECT_LE(normInf, lInf) << comment;
}
std::string findModel(const std::string &model_name)
{
return findDataFile("vision/classification/squeezenet/model/" + model_name + ".onnx", false);
}
inline void preprocess(const cv::Mat& src,
cv::Mat& dst,
const cv::Scalar& mean,
const cv::Scalar& std) {
int new_h = 224;
int new_w = 224;
cv::Mat tmp, nmat, cvt;
cv::resize(src, dst, cv::Size(new_w, new_h));
dst.convertTo(cvt, CV_32F, 1.f / 255);
nmat = cvt - mean;
tmp = nmat / std;
dst.create(cv::Size(new_w, new_h * src.channels()), CV_32F);
std::vector<cv::Mat> planes;
for (int i = 0; i < src.channels(); ++i) {
planes.push_back(dst.rowRange(i * new_h, (i + 1) * new_h));
}
cv::split(tmp, planes);
}
void InferONNX(const std::string& model_path,
const cv::Mat& in,
cv::Mat& out,
const cv::Scalar& mean,
const cv::Scalar& std)
{
// FIXME: It must be a FIXTURE test!
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;
Ort::Session session(env, model_path.data(), session_options);
auto input_node_dims = // 0 - one input
session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
auto output_node_dims = // 0 - one output
session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
Ort::AllocatorWithDefaultOptions allocator;
char* in_node_name_p = session.GetInputName(0, allocator);
char* out_node_name_p = session.GetOutputName(0, allocator);
std::string in_node_name(in_node_name_p);
std::string out_node_name(out_node_name_p);
allocator.Free(in_node_name_p);
allocator.Free(out_node_name_p);
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
cv::Mat dst;
preprocess(in, dst, mean, std);
out.create(std::vector<int>(output_node_dims.begin(),
output_node_dims.end()), CV_32F); // empty output Mat
auto in_tensor = Ort::Value::CreateTensor<float>(memory_info,
dst.ptr<float>(),
dst.total(),
input_node_dims.data(),
input_node_dims.size());
auto out_tensor = Ort::Value::CreateTensor<float>(memory_info,
out.ptr<float>(),
out.total(),
output_node_dims.data(),
output_node_dims.size());
std::vector<const char *> in_names = {in_node_name.data()};
std::vector<const char *> out_names = {out_node_name.data()};
session.Run(Ort::RunOptions{nullptr},
in_names.data(),
&in_tensor,
session.GetInputCount(),
out_names.data(),
&out_tensor,
session.GetOutputCount());
}
} // anonymous namespace
TEST(ONNX, Infer)
{
cv::Mat in_mat1, out_gapi, out_onnx;
std::string model_path = findModel("squeezenet1.0-9");
// NOTE: All tests chek "random" image
// Ideally it should be a real image
in_mat1 = initMatrixRandU(CV_8UC3, cv::Size{640, 480});
cv::Scalar mean = { 0.485, 0.456, 0.406 };
cv::Scalar std = { 0.229, 0.224, 0.225 };
// ONNX_API code
InferONNX(model_path, in_mat1, out_onnx, mean, std);
// G_API code
G_API_NET(SqueezNet, <cv::GMat(cv::GMat)>, "squeeznet");
cv::GMat in;
cv::GMat out = cv::gapi::infer<SqueezNet>(in);
cv::GComputation comp(cv::GIn(in), cv::GOut(out));
// NOTE: We have to normalize U8 tensor
// so cfgMeanStd() is here
auto net = cv::gapi::onnx::Params<SqueezNet> { model_path }.cfgMeanStd({mean},{std});
comp.apply(cv::gin(in_mat1),
cv::gout(out_gapi),
cv::compile_args(cv::gapi::networks(net)));
// Validate
ASSERT_EQ(1000u, out_onnx.total());
ASSERT_EQ(1000u, out_gapi.total());
normAssert(out_onnx, out_gapi, "Test classification output");
}
TEST(ONNX, InferROI)
{
cv::Mat in_mat1, out_gapi, out_onnx;
std::string model_path = findModel("squeezenet1.0-9");
in_mat1 = initMatrixRandU(CV_8UC3, cv::Size{640, 480});
cv::Scalar mean = { 0.485, 0.456, 0.406 }; // squeeznet mean
cv::Scalar std = { 0.229, 0.224, 0.225 }; // squeeznet std
cv::Rect ROI(cv::Point{0, 0}, cv::Size{250, 250});
// ONNX_API code
InferONNX(model_path, in_mat1(ROI), out_onnx, mean, std);
// G_API code
G_API_NET(SqueezNet, <cv::GMat(cv::GMat)>, "squeeznet");
cv::GMat in;
cv::GOpaque<cv::Rect> rect;
cv::GMat out = cv::gapi::infer<SqueezNet>(rect, in);
cv::GComputation comp(cv::GIn(in, rect), cv::GOut(out));
auto net = cv::gapi::onnx::Params<SqueezNet> { model_path }.cfgMeanStd({mean},{std});
comp.apply(cv::gin(in_mat1, ROI),
cv::gout(out_gapi),
cv::compile_args(cv::gapi::networks(net)));
// Validate
ASSERT_EQ(1000u, out_onnx.total());
ASSERT_EQ(1000u, out_gapi.total());
normAssert(out_onnx, out_gapi, "Test classification output");
}
TEST(ONNX, InferROIList)
{
cv::Mat in_mat1;
std::string model_path = findModel("squeezenet1.0-9");
in_mat1 = initMatrixRandU(CV_8UC3, cv::Size{640, 480});
cv::Scalar mean = { 0.485, 0.456, 0.406 }; // squeeznet mean
cv::Scalar std = { 0.229, 0.224, 0.225 }; // squeeznet std
std::vector<cv::Rect> rois = {
cv::Rect(cv::Point{ 0, 0}, cv::Size{80, 120}),
cv::Rect(cv::Point{50, 100}, cv::Size{250, 360}),
};
std::vector<cv::Mat> out_gapi;
std::vector<cv::Mat> out_onnx(rois.size());
// ONNX_API code
for (size_t i = 0; i < rois.size(); ++i) {
InferONNX(model_path, in_mat1(rois[i]), out_onnx[i], mean, std);
}
// G_API code
G_API_NET(SqueezNet, <cv::GMat(cv::GMat)>, "squeeznet");
cv::GMat in;
cv::GArray<cv::Rect> rr;
cv::GArray<cv::GMat> out = cv::gapi::infer<SqueezNet>(rr, in);
cv::GComputation comp(cv::GIn(in, rr), cv::GOut(out));
auto net = cv::gapi::onnx::Params<SqueezNet> { model_path }.cfgMeanStd({mean},{std});
comp.apply(cv::gin(in_mat1, rois),
cv::gout(out_gapi),
cv::compile_args(cv::gapi::networks(net)));
// Validate
for (size_t i = 0; i < rois.size(); ++i) {
ASSERT_EQ(1000u, out_onnx[i].total());
ASSERT_EQ(1000u, out_gapi[i].total());
normAssert(out_onnx[i], out_gapi[i], "Test classification output");
}
}
TEST(ONNX, Infer2ROIList)
{
cv::Mat in_mat1;
std::string model_path = findModel("squeezenet1.0-9");
in_mat1 = initMatrixRandU(CV_8UC3, cv::Size{640, 480});
cv::Scalar mean = { 0.485, 0.456, 0.406 }; // squeeznet mean
cv::Scalar std = { 0.229, 0.224, 0.225 }; // squeeznet std
std::vector<cv::Rect> rois = {
cv::Rect(cv::Point{ 0, 0}, cv::Size{80, 120}),
cv::Rect(cv::Point{50, 100}, cv::Size{250, 360}),
};
std::vector<cv::Mat> out_gapi;
std::vector<cv::Mat> out_onnx(rois.size());
// ONNX_API code
for (size_t i = 0; i < rois.size(); ++i) {
InferONNX(model_path, in_mat1(rois[i]), out_onnx[i], mean, std);
}
// G_API code
G_API_NET(SqueezNet, <cv::GMat(cv::GMat)>, "squeeznet");
cv::GMat in;
cv::GArray<cv::Rect> rr;
cv::GArray<cv::GMat> out = cv::gapi::infer2<SqueezNet>(in,rr);
cv::GComputation comp(cv::GIn(in, rr), cv::GOut(out));
auto net = cv::gapi::onnx::Params<SqueezNet> { model_path }.cfgMeanStd({mean},{std});
comp.apply(cv::gin(in_mat1, rois),
cv::gout(out_gapi),
cv::compile_args(cv::gapi::networks(net)));
// Validate
for (size_t i = 0; i < rois.size(); ++i) {
ASSERT_EQ(1000u, out_onnx[i].total());
ASSERT_EQ(1000u, out_gapi[i].total());
normAssert(out_onnx[i], out_gapi[i], "Test classification output");
}
}
} // namespace opencv_test
#endif // HAVE_ONNX