mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Merge pull request #24039 from dkurt:tflite_test_backends
TFLite models on different backends (tests and improvements) #24039 ### Pull Request Readiness Checklist * MaxUnpooling with OpenVINO * Fully connected with transposed inputs/weights with OpenVINO * Enable backends tests for TFLite (related to https://github.com/opencv/opencv/issues/23992#issuecomment-1640691722) * Increase existing tests thresholds See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
96f23e3da1
commit
4b8aeb1129
@ -180,15 +180,12 @@ public:
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE
|
||||
{
|
||||
bool tranAorB = transA || transB;
|
||||
#ifdef HAVE_INF_ENGINE
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
|
||||
return axis == 1 && !tranAorB;
|
||||
#endif
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_CUDA ||
|
||||
(backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !tranAorB) ||
|
||||
(backendId == DNN_BACKEND_WEBNN && axis == 1 && !tranAorB) ||
|
||||
backendId == DNN_BACKEND_CANN ||
|
||||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH ||
|
||||
(backendId == DNN_BACKEND_VKCOM && haveVulkan() && !tranAorB);
|
||||
}
|
||||
|
||||
@ -802,17 +799,26 @@ public:
|
||||
if (nodes.size() == 2)
|
||||
{
|
||||
auto& inp2 = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
|
||||
matmul = std::make_shared<ngraph::op::MatMul>(ieInpNode, inp2, false, false);
|
||||
matmul = std::make_shared<ngraph::op::MatMul>(ieInpNode, inp2, transA, transB);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<int64_t> data = {(int64_t)ieInpNode->get_shape()[0], (int64_t)blobs[0].size[1]};
|
||||
auto new_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, data.data());
|
||||
auto inp = std::make_shared<ngraph::op::v1::Reshape>(ieInpNode, new_shape, true);
|
||||
std::vector<int> shape(1 + normalize_axis(axis, ieInpNode->get_shape().size()), 0);
|
||||
shape[shape.size() - 1] = -1;
|
||||
auto inp = std::make_shared<ngraph::op::v1::Reshape>(
|
||||
ieInpNode,
|
||||
std::make_shared<ngraph::op::Constant>(ngraph::element::i32, ngraph::Shape{shape.size()}, shape.data()),
|
||||
true
|
||||
);
|
||||
|
||||
std::vector<size_t> weight_shape{(size_t)blobs[0].size[0], (size_t)blobs[0].size[1]};
|
||||
std::vector<size_t> weight_shape;
|
||||
if (isMatMul) {
|
||||
weight_shape = getShape<size_t>(oriMat);
|
||||
} else {
|
||||
weight_shape = {(size_t)blobs[0].size[0], (size_t)blobs[0].size[1]};
|
||||
}
|
||||
auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, weight_shape, blobs[0].data);
|
||||
matmul = std::make_shared<ngraph::op::MatMul>(inp, ieWeights, false, true);
|
||||
matmul = std::make_shared<ngraph::op::MatMul>(inp, ieWeights, transA, transB);
|
||||
}
|
||||
|
||||
if (bias) {
|
||||
|
@ -13,6 +13,7 @@ Implementation of Batch Normalization layer.
|
||||
#include "layers_common.hpp"
|
||||
#include "../op_cuda.hpp"
|
||||
#include "../op_halide.hpp"
|
||||
#include "../ie_ngraph.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
#include <opencv2/core/utils/logger.hpp>
|
||||
|
||||
@ -41,6 +42,7 @@ public:
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_CUDA ||
|
||||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH ||
|
||||
(backendId == DNN_BACKEND_HALIDE && haveHalide() && !poolPad.width && !poolPad.height);
|
||||
}
|
||||
|
||||
@ -181,6 +183,50 @@ public:
|
||||
#endif // HAVE_HALIDE
|
||||
return Ptr<BackendNode>();
|
||||
}
|
||||
|
||||
#ifdef HAVE_DNN_NGRAPH
|
||||
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
|
||||
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
|
||||
{
|
||||
auto features = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
|
||||
auto indices = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
|
||||
|
||||
std::vector<MatShape> inpShapes(nodes.size());
|
||||
std::vector<MatShape> outShapes, internals;
|
||||
for (int i = 0; i < nodes.size(); ++i) {
|
||||
std::vector<size_t> shape = nodes[i].dynamicCast<InfEngineNgraphNode>()->node->get_shape();
|
||||
inpShapes[i] = std::vector<int>(shape.begin(), shape.end());
|
||||
}
|
||||
getMemoryShapes(inpShapes, 1, outShapes, internals);
|
||||
|
||||
Mat zeros = Mat::zeros(1, total(outShapes[0]), CV_32F);
|
||||
auto zeroInp = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, ngraph::Shape{zeros.total()}, zeros.data);
|
||||
|
||||
int newShape = -1;
|
||||
features = std::make_shared<ngraph::op::v1::Reshape>(
|
||||
features,
|
||||
std::make_shared<ngraph::op::Constant>(ngraph::element::i32, ngraph::Shape{1}, &newShape),
|
||||
true
|
||||
);
|
||||
indices = std::make_shared<ngraph::op::v1::Reshape>(
|
||||
indices,
|
||||
std::make_shared<ngraph::op::Constant>(ngraph::element::i32, ngraph::Shape{1}, &newShape),
|
||||
true
|
||||
);
|
||||
if (indices->get_element_type() != ngraph::element::i32 && indices->get_element_type() != ngraph::element::i64) {
|
||||
indices = std::make_shared<ngraph::op::Convert>(indices, ngraph::element::i64);
|
||||
}
|
||||
|
||||
int axis = 0;
|
||||
std::shared_ptr<ngraph::Node> unpool = std::make_shared<ngraph::op::ScatterElementsUpdate>(zeroInp, indices, features,
|
||||
std::make_shared<ngraph::op::Constant>(ngraph::element::i32, ngraph::Shape{1}, &axis));
|
||||
|
||||
auto shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i32, ngraph::Shape{outShapes[0].size()}, outShapes[0].data());
|
||||
unpool = std::make_shared<ngraph::op::v1::Reshape>(unpool, shape, true);
|
||||
|
||||
return Ptr<BackendNode>(new InfEngineNgraphNode(unpool));
|
||||
}
|
||||
#endif // HAVE_DNN_NGRAPH
|
||||
};
|
||||
|
||||
Ptr<MaxUnpoolLayer> MaxUnpoolLayer::create(const LayerParams& params)
|
||||
|
@ -209,7 +209,7 @@ public:
|
||||
#ifdef HAVE_INF_ENGINE
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
|
||||
{
|
||||
return !computeMaxIdx && type != STOCHASTIC && kernel_size.size() > 1 && (kernel_size.size() != 3 || !isArmComputePlugin());
|
||||
return type != STOCHASTIC && kernel_size.size() > 1 && (kernel_size.size() != 3 || !isArmComputePlugin());
|
||||
}
|
||||
#endif
|
||||
if (backendId == DNN_BACKEND_OPENCV)
|
||||
@ -613,9 +613,17 @@ public:
|
||||
return Ptr<BackendNode>(new InfEngineNgraphNode(reduce_sum));
|
||||
}
|
||||
else if (type == MAX) {
|
||||
auto max_pool = std::make_shared<ngraph::op::v1::MaxPool>(ieInpNode, ngraph::Strides(strides),
|
||||
std::shared_ptr<ngraph::Node> max_pool;
|
||||
if (computeMaxIdx) {
|
||||
std::vector<size_t> dilations(kernel_size.size(), 1);
|
||||
max_pool = std::make_shared<ngraph::op::v8::MaxPool>(ieInpNode, ngraph::Strides(strides), ngraph::Strides(dilations),
|
||||
ngraph::Shape(pads_begin), ngraph::Shape(pads_end), ngraph::Shape(kernel_size),
|
||||
rounding_type, pad_type);
|
||||
} else {
|
||||
max_pool = std::make_shared<ngraph::op::v1::MaxPool>(ieInpNode, ngraph::Strides(strides),
|
||||
ngraph::Shape(pads_begin), ngraph::Shape(pads_end), ngraph::Shape(kernel_size),
|
||||
rounding_type, pad_type);
|
||||
}
|
||||
return Ptr<BackendNode>(new InfEngineNgraphNode(max_pool));
|
||||
}
|
||||
else if (type == ROI) {
|
||||
|
@ -410,7 +410,10 @@ public:
|
||||
}
|
||||
attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::SIZES;
|
||||
|
||||
if (alignCorners) {
|
||||
CV_Assert(!halfPixelCenters || !alignCorners);
|
||||
if (halfPixelCenters) {
|
||||
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::HALF_PIXEL;
|
||||
} else if (alignCorners) {
|
||||
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
||||
}
|
||||
|
||||
@ -427,7 +430,10 @@ public:
|
||||
}
|
||||
attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
|
||||
|
||||
if (alignCorners) {
|
||||
CV_Assert(!halfPixelCenters || !alignCorners);
|
||||
if (halfPixelCenters) {
|
||||
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel;
|
||||
} else if (alignCorners) {
|
||||
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners;
|
||||
}
|
||||
|
||||
|
@ -476,13 +476,14 @@ void NetImplOpenVINO::initBackend(const std::vector<LayerPin>& blobsToKeep_)
|
||||
{
|
||||
int lid = ld.inputBlobsId[i].lid;
|
||||
int oid = ld.inputBlobsId[i].oid;
|
||||
if (oid == 0 || lid == 0)
|
||||
continue;
|
||||
|
||||
auto ieInpNode = inputNodes[i].dynamicCast<InfEngineNgraphNode>();
|
||||
const auto& ngraph_input_node = ieInpNode->node;
|
||||
CV_LOG_DEBUG(NULL, "DNN/IE: bind output port " << lid << ":" << oid << " (" << ngraph_input_node->get_friendly_name() << ":" << ngraph_input_node->get_type_info().name << ")");
|
||||
|
||||
if ((oid == 0 && ngraph_input_node->get_output_size() == 1) || lid == 0)
|
||||
continue;
|
||||
|
||||
// Handle parameters from other subnets. Output port is not used in this case
|
||||
#if INF_ENGINE_VER_MAJOR_GT(INF_ENGINE_RELEASE_2020_4)
|
||||
if ((ngraph::op::is_parameter(ngraph_input_node) || ngraph::op::is_constant(ngraph_input_node)) &&
|
||||
|
@ -732,16 +732,9 @@ TEST_P(Test_Caffe_nets, FasterRCNN_vgg16)
|
||||
#endif
|
||||
|
||||
double scoreDiff = 0.0, iouDiff = 0.0;
|
||||
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_EQ(2022010000)
|
||||
// Check 'backward_compatible_check || in_out_elements_equal' failed at core/src/op/reshape.cpp:427:
|
||||
// While validating node 'v1::Reshape bbox_pred_reshape (bbox_pred[0]:f32{1,84}, Constant_265242[0]:i64{4}) -> (f32{?,?,?,?})' with friendly_name 'bbox_pred_reshape':
|
||||
// Requested output shape {1,6300,4,1} is incompatible with input shape {1, 84}
|
||||
#if defined(INF_ENGINE_RELEASE)
|
||||
if (target == DNN_TARGET_MYRIAD)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH, CV_TEST_TAG_DNN_SKIP_IE_VERSION);
|
||||
if (target == DNN_TARGET_OPENCL_FP16)
|
||||
scoreDiff = 0.02;
|
||||
#endif
|
||||
#if defined(INF_ENGINE_RELEASE)
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) {
|
||||
iouDiff = 0.02;
|
||||
if (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16) {
|
||||
|
@ -102,11 +102,14 @@ TEST(Test_Darknet, read_yolo_voc_stream)
|
||||
class Test_Darknet_layers : public DNNTestLayer
|
||||
{
|
||||
public:
|
||||
void testDarknetLayer(const std::string& name, bool hasWeights = false, bool testBatchProcessing = true)
|
||||
void testDarknetLayer(const std::string& name, bool hasWeights = false, bool testBatchProcessing = true,
|
||||
double l1 = 0.0, double lInf = 0.0)
|
||||
{
|
||||
SCOPED_TRACE(name);
|
||||
Mat inp = blobFromNPY(findDataFile("dnn/darknet/" + name + "_in.npy"));
|
||||
Mat ref = blobFromNPY(findDataFile("dnn/darknet/" + name + "_out.npy"));
|
||||
l1 = l1 ? l1 : default_l1;
|
||||
lInf = lInf ? lInf : default_lInf;
|
||||
|
||||
std::string cfg = findDataFile("dnn/darknet/" + name + ".cfg");
|
||||
std::string model = "";
|
||||
@ -120,7 +123,7 @@ public:
|
||||
net.setPreferableTarget(target);
|
||||
net.setInput(inp);
|
||||
Mat out = net.forward();
|
||||
normAssert(out, ref, "", default_l1, default_lInf);
|
||||
normAssert(out, ref, "", l1, lInf);
|
||||
|
||||
if (inp.size[0] == 1 && testBatchProcessing) // test handling of batch size
|
||||
{
|
||||
@ -166,8 +169,8 @@ public:
|
||||
}*/
|
||||
ASSERT_EQ(out2.dims, ref2.dims) << ref.dims;
|
||||
|
||||
normAssert(out2(ranges0), ref2, "", default_l1, default_lInf);
|
||||
normAssert(out2(ranges1), ref2, "", default_l1, default_lInf);
|
||||
normAssert(out2(ranges0), ref2, "", l1, lInf);
|
||||
normAssert(out2(ranges1), ref2, "", l1, lInf);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -1116,7 +1119,12 @@ TEST_P(Test_Darknet_layers, connected)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
|
||||
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_CPU_FP16)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_CPU_FP16);
|
||||
testDarknetLayer("connected", true);
|
||||
double l1 = 0.0;
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL)
|
||||
{
|
||||
l1 = 3e-5;
|
||||
}
|
||||
testDarknetLayer("connected", true, true, l1);
|
||||
}
|
||||
|
||||
TEST_P(Test_Darknet_layers, relu)
|
||||
|
@ -361,22 +361,9 @@ TEST_P(MaxPooling, Accuracy)
|
||||
Backend backendId = get<0>(get<5>(GetParam()));
|
||||
Target targetId = get<1>(get<5>(GetParam()));
|
||||
|
||||
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LE(2018050000)
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && targetId == DNN_TARGET_MYRIAD
|
||||
&& inSize == Size(7, 6) && kernel == Size(3, 2)
|
||||
&& (stride == Size(1, 1) || stride == Size(2, 2))
|
||||
&& (pad == Size(0, 1) || pad == Size(1, 1))
|
||||
)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_VERSION);
|
||||
#endif
|
||||
|
||||
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_EQ(2018050000)
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && targetId == DNN_TARGET_MYRIAD
|
||||
&& (kernel == Size(2, 2) || kernel == Size(3, 2))
|
||||
&& stride == Size(1, 1) && (pad == Size(0, 0) || pad == Size(0, 1))
|
||||
)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_VERSION);
|
||||
#endif
|
||||
// https://github.com/openvinotoolkit/openvino/issues/18731
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && stride != Size(1, 1))
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||
|
||||
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_GE(2019010000)
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && targetId == DNN_TARGET_MYRIAD
|
||||
@ -467,6 +454,11 @@ TEST_P(FullyConnected, Accuracy)
|
||||
{
|
||||
l1 = 0.01;
|
||||
}
|
||||
if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && targetId == DNN_TARGET_OPENCL)
|
||||
{
|
||||
l1 = 5e-3;
|
||||
lInf = 7e-3;
|
||||
}
|
||||
#endif
|
||||
if (targetId == DNN_TARGET_CUDA_FP16)
|
||||
l1 = 0.015;
|
||||
|
@ -215,7 +215,13 @@ TEST_P(Test_Caffe_layers, InnerProduct)
|
||||
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_CPU_FP16)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_CPU_FP16);
|
||||
|
||||
testLayerUsingCaffeModels("layer_inner_product", true);
|
||||
double l1 = 0.0, lInf = 0.0;
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
|
||||
{
|
||||
l1 = 5e-3;
|
||||
lInf = 2e-2;
|
||||
}
|
||||
testLayerUsingCaffeModels("layer_inner_product", true, true, l1, lInf);
|
||||
}
|
||||
|
||||
TEST_P(Test_Caffe_layers, Pooling_max)
|
||||
|
@ -52,7 +52,7 @@ public:
|
||||
}
|
||||
|
||||
void testONNXModels(const String& basename, const Extension ext = npy,
|
||||
const double l1 = 0, const float lInf = 0, const bool useSoftmax = false,
|
||||
double l1 = 0, double lInf = 0, const bool useSoftmax = false,
|
||||
bool checkNoFallbacks = true, int numInps = 1)
|
||||
{
|
||||
String onnxmodel = _tf("models/" + basename + ".onnx", required);
|
||||
@ -102,6 +102,11 @@ public:
|
||||
netSoftmax.setInput(ref);
|
||||
ref = netSoftmax.forward();
|
||||
}
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL)
|
||||
{
|
||||
l1 = std::max(l1, 1.4e-3);
|
||||
lInf = std::max(lInf, 8e-3);
|
||||
}
|
||||
normAssert(ref, out, basename.c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
|
||||
if (checkNoFallbacks)
|
||||
expectNoFallbacksFromIE(net);
|
||||
|
@ -1818,8 +1818,8 @@ TEST_P(Test_TensorFlow_nets, Mask_RCNN)
|
||||
double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD || target == DNN_TARGET_CPU_FP16) ? 0.018 : default_lInf;
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
|
||||
{
|
||||
scoreDiff = std::max(scoreDiff, 0.02);
|
||||
iouDiff = std::max(iouDiff, 0.009);
|
||||
scoreDiff = std::max(scoreDiff, 0.06);
|
||||
iouDiff = std::max(iouDiff, 0.01);
|
||||
}
|
||||
normAssertDetections(refDetections, outDetections, "", /*threshold for zero confidence*/1e-5, scoreDiff, iouDiff);
|
||||
|
||||
|
@ -20,6 +20,14 @@ namespace opencv_test { namespace {
|
||||
using namespace cv;
|
||||
using namespace cv::dnn;
|
||||
|
||||
class Test_TFLite : public DNNTestLayer {
|
||||
public:
|
||||
void testModel(Net& net, const std::string& modelName, const Mat& input, double l1 = 0, double lInf = 0);
|
||||
void testModel(const std::string& modelName, const Mat& input, double l1 = 0, double lInf = 0);
|
||||
void testModel(const std::string& modelName, const Size& inpSize, double l1 = 0, double lInf = 0);
|
||||
void testLayer(const std::string& modelName, double l1 = 0, double lInf = 0);
|
||||
};
|
||||
|
||||
void testInputShapes(const Net& net, const std::vector<Mat>& inps) {
|
||||
std::vector<MatShape> inLayerShapes;
|
||||
std::vector<MatShape> outLayerShapes;
|
||||
@ -31,8 +39,14 @@ void testInputShapes(const Net& net, const std::vector<Mat>& inps) {
|
||||
}
|
||||
}
|
||||
|
||||
void testModel(Net& net, const std::string& modelName, const Mat& input, double l1 = 1e-5, double lInf = 1e-4)
|
||||
void Test_TFLite::testModel(Net& net, const std::string& modelName, const Mat& input, double l1, double lInf)
|
||||
{
|
||||
l1 = l1 ? l1 : default_l1;
|
||||
lInf = lInf ? lInf : default_lInf;
|
||||
|
||||
net.setPreferableBackend(backend);
|
||||
net.setPreferableTarget(target);
|
||||
|
||||
testInputShapes(net, {input});
|
||||
net.setInput(input);
|
||||
|
||||
@ -48,20 +62,20 @@ void testModel(Net& net, const std::string& modelName, const Mat& input, double
|
||||
}
|
||||
}
|
||||
|
||||
void testModel(const std::string& modelName, const Mat& input, double l1 = 1e-5, double lInf = 1e-4)
|
||||
void Test_TFLite::testModel(const std::string& modelName, const Mat& input, double l1, double lInf)
|
||||
{
|
||||
Net net = readNet(findDataFile("dnn/tflite/" + modelName + ".tflite", false));
|
||||
testModel(net, modelName, input, l1, lInf);
|
||||
}
|
||||
|
||||
void testModel(const std::string& modelName, const Size& inpSize, double l1 = 1e-5, double lInf = 1e-4)
|
||||
void Test_TFLite::testModel(const std::string& modelName, const Size& inpSize, double l1, double lInf)
|
||||
{
|
||||
Mat input = imread(findDataFile("cv/shared/lena.png"));
|
||||
input = blobFromImage(input, 1.0 / 255, inpSize, 0, true);
|
||||
testModel(modelName, input, l1, lInf);
|
||||
}
|
||||
|
||||
void testLayer(const std::string& modelName, double l1 = 1e-5, double lInf = 1e-4)
|
||||
void Test_TFLite::testLayer(const std::string& modelName, double l1, double lInf)
|
||||
{
|
||||
Mat inp = blobFromNPY(findDataFile("dnn/tflite/" + modelName + "_inp.npy"));
|
||||
Net net = readNet(findDataFile("dnn/tflite/" + modelName + ".tflite"));
|
||||
@ -69,29 +83,66 @@ void testLayer(const std::string& modelName, double l1 = 1e-5, double lInf = 1e-
|
||||
}
|
||||
|
||||
// https://google.github.io/mediapipe/solutions/face_mesh
|
||||
TEST(Test_TFLite, face_landmark)
|
||||
TEST_P(Test_TFLite, face_landmark)
|
||||
{
|
||||
testModel("face_landmark", Size(192, 192), 2e-5, 2e-4);
|
||||
if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA_FP16);
|
||||
double l1 = 2e-5, lInf = 2e-4;
|
||||
if (target == DNN_TARGET_CPU_FP16 || target == DNN_TARGET_CUDA_FP16 || target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD ||
|
||||
(backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL))
|
||||
{
|
||||
l1 = 0.15;
|
||||
lInf = 0.82;
|
||||
}
|
||||
testModel("face_landmark", Size(192, 192), l1, lInf);
|
||||
}
|
||||
|
||||
// https://google.github.io/mediapipe/solutions/face_detection
|
||||
TEST(Test_TFLite, face_detection_short_range)
|
||||
TEST_P(Test_TFLite, face_detection_short_range)
|
||||
{
|
||||
testModel("face_detection_short_range", Size(128, 128));
|
||||
double l1 = 0, lInf = 2e-4;
|
||||
if (target == DNN_TARGET_CPU_FP16 || target == DNN_TARGET_CUDA_FP16 || target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD ||
|
||||
(backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL))
|
||||
{
|
||||
l1 = 0.04;
|
||||
lInf = 0.8;
|
||||
}
|
||||
testModel("face_detection_short_range", Size(128, 128), l1, lInf);
|
||||
}
|
||||
|
||||
// https://google.github.io/mediapipe/solutions/selfie_segmentation
|
||||
TEST(Test_TFLite, selfie_segmentation)
|
||||
TEST_P(Test_TFLite, selfie_segmentation)
|
||||
{
|
||||
testModel("selfie_segmentation", Size(256, 256));
|
||||
double l1 = 0, lInf = 0;
|
||||
if (target == DNN_TARGET_CPU_FP16 || target == DNN_TARGET_CUDA_FP16 || target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD ||
|
||||
(backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL))
|
||||
{
|
||||
l1 = 0.01;
|
||||
lInf = 0.48;
|
||||
}
|
||||
testModel("selfie_segmentation", Size(256, 256), l1, lInf);
|
||||
}
|
||||
|
||||
TEST(Test_TFLite, max_unpooling)
|
||||
TEST_P(Test_TFLite, max_unpooling)
|
||||
{
|
||||
if (backend == DNN_BACKEND_CUDA)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
|
||||
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target != DNN_TARGET_CPU) {
|
||||
if (target == DNN_TARGET_OPENCL_FP16) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||
if (target == DNN_TARGET_OPENCL) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||
if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||
}
|
||||
|
||||
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
|
||||
|
||||
// Due Max Unpoling is a numerically unstable operation and small difference between frameworks
|
||||
// might lead to positional difference of maximal elements in the tensor, this test checks
|
||||
// behavior of Max Unpooling layer only.
|
||||
Net net = readNet(findDataFile("dnn/tflite/hair_segmentation.tflite", false));
|
||||
net.setPreferableBackend(backend);
|
||||
net.setPreferableTarget(target);
|
||||
|
||||
Mat input = imread(findDataFile("cv/shared/lena.png"));
|
||||
cvtColor(input, input, COLOR_BGR2RGBA);
|
||||
@ -101,7 +152,15 @@ TEST(Test_TFLite, max_unpooling)
|
||||
net.setInput(input);
|
||||
|
||||
std::vector<std::vector<Mat> > outs;
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) {
|
||||
// TODO: seems like a bug with a retrieving intermediate tensors
|
||||
net.forward(outs, {"conv2d_transpose_4", "p_re_lu_1", "max_pooling_with_argmax2d", "conv2d_86", "max_unpooling2d_2"});
|
||||
outs.erase(outs.begin());
|
||||
}
|
||||
else {
|
||||
net.forward(outs, {"p_re_lu_1", "max_pooling_with_argmax2d", "conv2d_86", "max_unpooling2d_2"});
|
||||
}
|
||||
|
||||
ASSERT_EQ(outs.size(), 4);
|
||||
ASSERT_EQ(outs[0].size(), 1);
|
||||
ASSERT_EQ(outs[1].size(), 2);
|
||||
@ -117,6 +176,8 @@ TEST(Test_TFLite, max_unpooling)
|
||||
ASSERT_EQ(poolOut.size, poolIds.size);
|
||||
ASSERT_EQ(poolOut.size, unpoolInp.size);
|
||||
|
||||
ASSERT_EQ(countNonZero(poolInp), poolInp.total());
|
||||
|
||||
for (int c = 0; c < 32; ++c) {
|
||||
float *poolInpData = poolInp.ptr<float>(0, c);
|
||||
float *poolOutData = poolOut.ptr<float>(0, c);
|
||||
@ -135,15 +196,19 @@ TEST(Test_TFLite, max_unpooling)
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(poolInpData[maxIdx], poolOutData[y * 64 + x]) << errMsg;
|
||||
if (backend != DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) {
|
||||
EXPECT_EQ(poolIdsData[y * 64 + x], (float)maxIdx) << errMsg;
|
||||
}
|
||||
EXPECT_EQ(unpoolOutData[maxIdx], unpoolInpData[y * 64 + x]) << errMsg;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Test_TFLite, EfficientDet_int8) {
|
||||
TEST_P(Test_TFLite, EfficientDet_int8) {
|
||||
Net net = readNet(findDataFile("dnn/tflite/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", false));
|
||||
net.setPreferableBackend(backend);
|
||||
net.setPreferableTarget(target);
|
||||
|
||||
Mat img = imread(findDataFile("dnn/dog416.png"));
|
||||
Mat blob = blobFromImage(img, 1.0, Size(320, 320));
|
||||
@ -158,10 +223,18 @@ TEST(Test_TFLite, EfficientDet_int8) {
|
||||
normAssertDetections(ref, out, "", 0.5, 0.05, 0.1);
|
||||
}
|
||||
|
||||
TEST(Test_TFLite, replicate_by_pack) {
|
||||
testLayer("replicate_by_pack");
|
||||
TEST_P(Test_TFLite, replicate_by_pack) {
|
||||
double l1 = 0, lInf = 0;
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL)
|
||||
{
|
||||
l1 = 4e-4;
|
||||
lInf = 2e-3;
|
||||
}
|
||||
testLayer("replicate_by_pack", l1, lInf);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
|
||||
|
||||
}} // namespace
|
||||
|
||||
#endif // OPENCV_TEST_DNN_TFLITE
|
||||
|
Loading…
Reference in New Issue
Block a user