diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp
index b08dd8acb7..6331353140 100644
--- a/modules/dnn/include/opencv2/dnn/dnn.hpp
+++ b/modules/dnn/include/opencv2/dnn/dnn.hpp
@@ -1051,17 +1051,24 @@ CV__DNN_INLINE_NS_BEGIN
     /** @brief Reads a network model stored in Caffe framework's format.
       * @param prototxt path to the .prototxt file with text description of the network architecture.
       * @param caffeModel path to the .caffemodel file with learned network.
+      * @param engine select the DNN engine to be used. With automatic selection the new engine is used.
+      * Note that the new DNN engine does not support non-CPU backends for now.
       * @returns Net object.
       */
-    CV_EXPORTS_W Net readNetFromCaffe(CV_WRAP_FILE_PATH const String &prototxt, CV_WRAP_FILE_PATH const String &caffeModel = String());
+    CV_EXPORTS_W Net readNetFromCaffe(CV_WRAP_FILE_PATH const String &prototxt,
+                                      CV_WRAP_FILE_PATH const String &caffeModel = String(),
+                                      int engine = ENGINE_AUTO);
 
     /** @brief Reads a network model stored in Caffe model in memory.
       * @param bufferProto buffer containing the content of the .prototxt file
       * @param bufferModel buffer containing the content of the .caffemodel file
+      * @param engine select the DNN engine to be used. With automatic selection the new engine is used.
+      * Note that the new DNN engine does not support non-CPU backends for now.
       * @returns Net object.
       */
     CV_EXPORTS_W Net readNetFromCaffe(const std::vector<uchar>& bufferProto,
-                                      const std::vector<uchar>& bufferModel = std::vector<uchar>());
+                                      const std::vector<uchar>& bufferModel = std::vector<uchar>(),
+                                      int engine = ENGINE_AUTO);
 
     /** @brief Reads a network model stored in Caffe model in memory.
       * @details This is an overloaded member function, provided for convenience.
@@ -1070,10 +1077,13 @@ CV__DNN_INLINE_NS_BEGIN
       * @param lenProto length of bufferProto
       * @param bufferModel buffer containing the content of the .caffemodel file
       * @param lenModel length of bufferModel
+      * @param engine select the DNN engine to be used. With automatic selection the new engine is used.
+      * Note that the new DNN engine does not support non-CPU backends for now.
       * @returns Net object.
       */
     CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
-                                    const char *bufferModel = NULL, size_t lenModel = 0);
+                                    const char *bufferModel = NULL, size_t lenModel = 0,
+                                    int engine = ENGINE_AUTO);
 
     /** @brief Reads a network model stored in TensorFlow framework's format.
      * @param model path to the .pb file with binary protobuf description of the network architecture
diff --git a/modules/dnn/misc/objc/gen_dict.json b/modules/dnn/misc/objc/gen_dict.json
index 166a544735..fc426efab5 100644
--- a/modules/dnn/misc/objc/gen_dict.json
+++ b/modules/dnn/misc/objc/gen_dict.json
@@ -1,8 +1,8 @@
 {
     "func_arg_fix" : {
         "Dnn": {
-            "(Net*)readNetFromCaffe:(NSString*)prototxt caffeModel:(NSString*)caffeModel" : { "readNetFromCaffe" : {"name" : "readNetFromCaffeFile"} },
-            "(Net*)readNetFromCaffe:(ByteVector*)bufferProto bufferModel:(ByteVector*)bufferModel" : { "readNetFromCaffe" : {"name" : "readNetFromCaffeBuffer"} },
+            "(Net*)readNetFromCaffe:(NSString*)prototxt caffeModel:(NSString*)caffeModel engine:(int)engine" : { "readNetFromCaffe" : {"name" : "readNetFromCaffeFile"} },
+            "(Net*)readNetFromCaffe:(ByteVector*)bufferProto bufferModel:(ByteVector*)bufferModel engine:(int)engine" : { "readNetFromCaffe" : {"name" : "readNetFromCaffeBuffer"} },
             "(Net*)readNetFromDarknet:(NSString*)cfgFile darknetModel:(NSString*)darknetModel" : { "readNetFromDarknet" : {"name" : "readNetFromDarknetFile"} },
             "(Net*)readNetFromDarknet:(ByteVector*)bufferCfg bufferModel:(ByteVector*)bufferModel" : { "readNetFromDarknet" : {"name" : "readNetFromDarknetBuffer"} },
             "(Net*)readNetFromONNX:(NSString*)onnxFile engine:(int)engine" : { "readNetFromONNX" : {"name" : "readNetFromONNXFile"} },
diff --git a/modules/dnn/misc/python/test/test_dnn.py b/modules/dnn/misc/python/test/test_dnn.py
index d55b810e9b..7343b5b6be 100755
--- a/modules/dnn/misc/python/test/test_dnn.py
+++ b/modules/dnn/misc/python/test/test_dnn.py
@@ -112,7 +112,7 @@ class dnn_test(NewOpenCVTests):
     def checkIETarget(self, backend, target):
         proto = self.find_dnn_file('dnn/layers/layer_convolution.prototxt')
         model = self.find_dnn_file('dnn/layers/layer_convolution.caffemodel')
-        net = cv.dnn.readNet(proto, model)
+        net = cv.dnn.readNet(proto, model, engine=cv.dnn.ENGINE_CLASSIC)
         try:
             net.setPreferableBackend(backend)
             net.setPreferableTarget(target)
@@ -324,6 +324,9 @@ class dnn_test(NewOpenCVTests):
                          testScores, testBoxes, 0.5)
 
     def test_async(self):
+        # bug: https://github.com/opencv/opencv/issues/26376
+        raise unittest.SkipTest("The new dnn engine does not support async inference")
+
         timeout = 10*1000*10**6  # in nanoseconds (10 sec)
         proto = self.find_dnn_file('dnn/layers/layer_convolution.prototxt')
         model = self.find_dnn_file('dnn/layers/layer_convolution.caffemodel')
@@ -337,7 +340,7 @@ class dnn_test(NewOpenCVTests):
 
             printParams(backend, target)
 
-            netSync = cv.dnn.readNet(proto, model)
+            netSync = cv.dnn.readNet(proto, model, engine=cv.dnn.ENGINE_CLASSIC)
             netSync.setPreferableBackend(backend)
             netSync.setPreferableTarget(target)
 
@@ -463,7 +466,7 @@ class dnn_test(NewOpenCVTests):
         for backend, target in self.dnnBackendsAndTargets:
             printParams(backend, target)
 
-            net = cv.dnn.readNet(model)
+            net = cv.dnn.readNet(model, engine=cv.dnn.ENGINE_CLASSIC)
             net.setPreferableBackend(backend)
             net.setPreferableTarget(target)
diff --git a/modules/dnn/src/caffe/caffe_importer.cpp b/modules/dnn/src/caffe/caffe_importer.cpp
index 50e1fbe93f..eab8acbf64 100644
--- a/modules/dnn/src/caffe/caffe_importer.cpp
+++ b/modules/dnn/src/caffe/caffe_importer.cpp
@@ -40,6 +40,7 @@
 //M*/
 
 #include "../precomp.hpp"
+#include "../net_impl.hpp"
 
 #ifdef HAVE_PROTOBUF
 #include
@@ -53,8 +54,10 @@
 #include "caffe_io.hpp"
 #endif
 
+#include
 #include
 
+
 namespace cv {
 namespace dnn {
 CV__DNN_INLINE_NS_BEGIN
@@ -320,6 +323,30 @@ public:
         }
     }
 
+    Ptr<Layer> addLayer(Net& dstNet,
+                        const String& type,
+                        const String& name,
+                        LayerParams& layerParams,
+                        const std::vector<String>& inputs,
+                        const std::vector<String>& outputs)
+    {
+        layerParams.type = type;
+        layerParams.name = name;
+        Ptr<Layer> layer = LayerFactory::createLayerInstance(type, layerParams);
+        if (!layer) {
+            CV_Error(Error::StsError, "Can't create layer " + name + " with type " + type);
+            return nullptr;
+        }
+
+        for (const String& inputName : inputs)
+            layer->inputs.push_back(dstNet.getArg(inputName));
+        for (const String& outputName : outputs)
+            layer->outputs.push_back(dstNet.getArg(outputName));
+        layer->netimpl = dstNet.getImpl();
+        CV_Assert(dstNet.getImpl()->dump_indent == 3);
+        return layer;
+    }
+
     struct BlobNote
     {
         BlobNote(const std::string &_name, int _layerId, int _outNum) :
@@ -332,24 +359,41 @@ public:
     std::vector<BlobNote> addedBlobs;
     std::map<String, int> layerCounter;
 
-    void populateNet(Net dstNet)
+    void populateNet(Net dstNet, bool newEngine)
     {
         CV_TRACE_FUNCTION();
 
         int layersSize = net.layer_size();
         layerCounter.clear();
-        addedBlobs.clear();
-        addedBlobs.reserve(layersSize + 1);
 
-        //setup input layer names
+        // OLD ENGINE
+        if(!newEngine)
+        {
+            addedBlobs.clear();
+            addedBlobs.reserve(layersSize + 1);
+        }
+
         std::vector<String> netInputs(net.input_size());
         std::vector<MatShape> inp_shapes;
+
+        // NEW ENGINE
+        Net::Impl* netImpl = dstNet.getImpl();
+        std::vector<Ptr<Layer>> curr_prog;
+        std::vector<Arg> modelInputs, modelOutputs;
+
         {
             int net_input_size = net.input_size();
             for (int inNum = 0; inNum < net_input_size; inNum++)
             {
-                addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
-                netInputs[inNum] = net.input(inNum);
+                if (newEngine)
+                {
+                    modelInputs.push_back(netImpl->newArg(net.input(inNum), DNN_ARG_INPUT));
+                    netImpl->args.at(modelInputs.back().idx).type = CV_32F;
+                }
+                else
+                {
+                    addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
+                    netInputs[inNum] = net.input(inNum);
+                }
             }
 
             if (net.input_dim_size() > 0)  // deprecated in Caffe proto
@@ -365,7 +409,10 @@ public:
                     shape[1] = net.input_dim(dim+1);
                     shape[2] = net.input_dim(dim+2);
                     shape[3] = net.input_dim(dim+3);
-                    inp_shapes.push_back(shape);
+                    if (newEngine)
+                        netImpl->args.at(modelInputs[inp_id].idx).shape = shape;
+                    else
+                        inp_shapes.push_back(shape);
                 }
             }
             else if (net.input_shape_size() > 0)  // deprecated in Caffe proto
@@ -375,7 +422,10 @@ public:
                 for (int inp_id = 0; inp_id < net_input_shape_size; inp_id++)
                 {
                     MatShape shape = parseBlobShape(net.input_shape(inp_id));
-                    inp_shapes.push_back(shape);
+                    if (newEngine)
+                        netImpl->args.at(modelInputs[inp_id].idx).shape = shape;
+                    else
+                        inp_shapes.push_back(shape);
                 }
             }
             else
@@ -383,11 +433,20 @@ public:
                 for (int inp_id = 0; inp_id < net_input_size; inp_id++)
                 {
                     MatShape shape; // empty
-                    inp_shapes.push_back(shape);
+                    if (newEngine)
+                        netImpl->args.at(modelInputs[inp_id].idx).shape = shape;
+                    else
+                        inp_shapes.push_back(shape);
                 }
             }
         }
 
+        if (newEngine && net.layer(layersSize - 1).type() == "Silence")
+        {
+            CV_LOG_WARNING(NULL, "Caffe parser: Silence layer was ignored");
+            layersSize--;
+        }
+
         for (int li = 0; li < layersSize; li++)
         {
             const caffe::LayerParameter &layer = net.layer(li);
@@ -398,6 +457,12 @@ public:
             extractLayerParams(layer, layerParams);
             extractBinaryLayerParams(layer, layerParams);
 
+            if (newEngine && li == layersSize - 1)
+            {
+                for (int outNum = 0; outNum < layer.top_size(); outNum++)
+                    modelOutputs.push_back(netImpl->newArg(layer.top(outNum), DNN_ARG_OUTPUT));
+            }
+
             int repetitions = layerCounter[name]++;
             if (repetitions)
                 name += String("_") + toString(repetitions);
@@ -406,9 +471,17 @@ public:
             {
                 for (int outNum = 0; outNum < layer.top_size(); outNum++)
                 {
-                    addOutput(layer, 0, outNum);
-                    addedBlobs.back().outNum = netInputs.size();
-                    netInputs.push_back(addedBlobs.back().name);
+                    if (newEngine)
+                    {
+                        modelInputs.push_back(netImpl->newArg(layer.top(outNum), DNN_ARG_INPUT));
+                        netImpl->args.at(modelInputs.back().idx).type = CV_32F;
+                    }
+                    else
+                    {
+                        addOutput(layer, 0, outNum);
+                        addedBlobs.back().outNum = netInputs.size();
+                        netInputs.push_back(addedBlobs.back().name);
+                    }
                 }
                 if (layer.has_input_param())
                 {
@@ -418,7 +491,15 @@ public:
                     for (int inp_id = 0; inp_id < input_shape_size; inp_id++)
                     {
                         MatShape shape = parseBlobShape(inputParameter.shape(inp_id));
-                        inp_shapes.push_back(shape);
+                        if (newEngine)
+                        {
+                            int inputIdx = modelInputs.size() - input_shape_size + inp_id;
+                            netImpl->args.at(modelInputs[inputIdx].idx).shape = shape;
+                        }
+                        else
+                        {
+                            inp_shapes.push_back(shape);
+                        }
                     }
                 }
                 continue;
@@ -437,12 +518,24 @@ public:
                     if (repetitions)
                         mvnName += String("_") + toString(repetitions);
 
-                    int mvnId = dstNet.addLayer(mvnName, "MVN", mvnParams);
-                    addInput(layer.bottom(0), mvnId, 0, dstNet);
-                    addOutput(layer, mvnId, 0);
-                    net.mutable_layer(li)->set_bottom(0, layer.top(0));
-                    layerParams.blobs[0].setTo(0);  // mean
-                    layerParams.blobs[1].setTo(1);  // std
+                    if (newEngine)
+                    {
+                        Ptr<Layer> netLayer = addLayer(
+                            dstNet, "MVN", mvnName, mvnParams,
+                            {layer.bottom(0)},
+                            {layer.top(0)});
+                        curr_prog.push_back(netLayer);
+                        continue;
+                    }
+                    else
+                    {
+                        int mvnId = dstNet.addLayer(mvnName, "MVN", mvnParams);
+                        addInput(layer.bottom(0), mvnId, 0, dstNet);
+                        addOutput(layer, mvnId, 0);
+                        net.mutable_layer(li)->set_bottom(0, layer.top(0));
+                        layerParams.blobs[0].setTo(0);  // mean
+                        layerParams.blobs[1].setTo(1);  // std
+                    }
                 }
             }
             else if (type == "Axpy")
@@ -458,13 +551,34 @@ public:
                 LayerParams scaleParams;
                 scaleParams.set("axis", 1);
                 scaleParams.set("has_bias", false);
-                int scaleId = dstNet.addLayer(scaleName, "Scale", scaleParams);
-                addInput(layer.bottom(2), scaleId, 0, dstNet);
-                addInput(layer.bottom(0), scaleId, 1, dstNet);
-                addOutput(layer, scaleId, 0);
-                net.mutable_layer(li)->set_bottom(0, layer.top(0));
-                net.mutable_layer(li)->mutable_bottom()->RemoveLast();
-                type = "Eltwise";
+
+                if (newEngine)
+                {
+                    std::string intermediateTensor = scaleName + "_intermediate_output";
+                    Ptr<Layer> netLayerScale = addLayer(
+                        dstNet, "Scale", scaleName, scaleParams,
+                        {layer.bottom(2), layer.bottom(0)},
+                        {intermediateTensor});
+                    curr_prog.push_back(netLayerScale);
+
+                    LayerParams eltwiseParams;
+                    Ptr<Layer> netLayerEltwise = addLayer(
+                        dstNet, "Eltwise", name, eltwiseParams,
+                        {intermediateTensor, layer.bottom(1)},
+                        {layer.top(0)});
+                    curr_prog.push_back(netLayerEltwise);
+                    continue;
+                }
+                else
+                {
+                    int scaleId = dstNet.addLayer(scaleName, "Scale", scaleParams);
+                    addInput(layer.bottom(2), scaleId, 0, dstNet);
+                    addInput(layer.bottom(0), scaleId, 1, dstNet);
+                    addOutput(layer, scaleId, 0);
+                    net.mutable_layer(li)->set_bottom(0, layer.top(0));
+                    net.mutable_layer(li)->mutable_bottom()->RemoveLast();
+                    type = "Eltwise";
+                }
             }
             else if (type == "Resample")
             {
@@ -489,9 +603,19 @@ public:
                 CV_Assert(layer.bottom_size() == layer.top_size());
                 for (int i = 0; i < layer.bottom_size(); i++)
                 {
-                    int conv_id = dstNet.addLayer(layer.top(i), type, layerParams);
-                    addInput(layer.bottom(i), conv_id, 0, dstNet);
-                    addedBlobs.push_back(BlobNote(layer.top(i), conv_id, 0));
+                    if (newEngine)
+                    {
+                        Ptr<Layer> netLayer = addLayer(
+                            dstNet, type, layer.top(i), layerParams,
+                            {layer.bottom(i)}, {layer.top(i)});
+                        curr_prog.push_back(netLayer);
+                    }
+                    else
+                    {
+                        int conv_id = dstNet.addLayer(layer.top(i), type, layerParams);
+                        addInput(layer.bottom(i), conv_id, 0, dstNet);
+                        addedBlobs.push_back(BlobNote(layer.top(i), conv_id, 0));
+                    }
                 }
                 continue;
             }
@@ -504,25 +628,77 @@ public:
                 if(!layerParams.has("axis"))
                     layerParams.set("axis", 1);
             }
+            else if ("Proposal" == type)
+            {
+                if (newEngine && layer.top_size() == 1)
+                {
+                    // Add unused optional second output and create the Proposal layer
+                    std::vector<String> layerInputs;
+                    for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
+                        layerInputs.push_back(layer.bottom(inNum));
+                    Ptr<Layer> netLayer = addLayer(
+                        dstNet, type, name, layerParams,
+                        layerInputs, {layer.top(0), name + "___output_scores"});
+                    curr_prog.push_back(netLayer);
+                    continue;
+                }
+            }
+            else if ("Silence" == type)
+            {
+                if (newEngine)
+                {
+                    CV_LOG_WARNING(NULL, "Caffe parser: Silence layer was ignored");
+                    continue;
+                }
+            }
 
-            int id = dstNet.addLayer(name, type, layerParams);
+            if (newEngine)
+            {
+                std::vector<String> layerInputs, layerOutputs;
+                for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
+                    layerInputs.push_back(layer.bottom(inNum));
+                for (int outNum = 0; outNum < layer.top_size(); outNum++)
+                    layerOutputs.push_back(layer.top(outNum));
 
-            for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
-                addInput(layer.bottom(inNum), id, inNum, dstNet);
-
-            for (int outNum = 0; outNum < layer.top_size(); outNum++)
-                addOutput(layer, id, outNum);
+                Ptr<Layer> netLayer = addLayer(
+                    dstNet, type, name, layerParams,
+                    layerInputs, layerOutputs);
+                curr_prog.push_back(netLayer);
+            }
+            else
+            {
+                int id = dstNet.addLayer(name, type, layerParams);
+                for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
+                    addInput(layer.bottom(inNum), id, inNum, dstNet);
+                for (int outNum = 0; outNum < layer.top_size(); outNum++)
+                    addOutput(layer, id, outNum);
+            }
         }
 
-        dstNet.setInputsNames(netInputs);
-        if (inp_shapes.size() > 0)
+        if (newEngine)
         {
-            CV_CheckEQ(inp_shapes.size(), netInputs.size(), "");
-            for (int inp_id = 0; inp_id < inp_shapes.size(); inp_id++)
-                dstNet.setInputShape(netInputs[inp_id], inp_shapes[inp_id]);
-        }
+            Ptr<Graph> curr_graph = netImpl->newGraph(net.name(), modelInputs, true);
+            curr_graph->setOutputs(modelOutputs);
+            curr_graph->setProg(curr_prog);
 
-        addedBlobs.clear();
+            netImpl->mainGraph = curr_graph;
+            netImpl->modelFormat = DNN_MODEL_CAFFE;
+            netImpl->originalLayout = DATA_LAYOUT_NCHW;
+            netImpl->prepareForInference();
+        }
+        else
+        {
+            dstNet.setInputsNames(netInputs);
+
+            if (inp_shapes.size() > 0)
+            {
+                CV_CheckEQ(inp_shapes.size(), netInputs.size(), "");
+                for (int inp_id = 0; inp_id < inp_shapes.size(); inp_id++)
+                    dstNet.setInputShape(netInputs[inp_id], inp_shapes[inp_id]);
+            }
+
+            addedBlobs.clear();
+        }
     }
 
     void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)
@@ -569,45 +745,59 @@
 }
 
-Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)
+Net readNetFromCaffe(const String &prototxt,
+                     const String &caffeModel, /*= String()*/
+                     int engine)
 {
+    static const int engine_forced = (int)utils::getConfigurationParameterSizeT("OPENCV_FORCE_DNN_ENGINE", ENGINE_AUTO);
+    if(engine_forced != ENGINE_AUTO)
+        engine = engine_forced;
+
     CaffeImporter caffeImporter(prototxt.c_str(), caffeModel.c_str());
     Net net;
-    caffeImporter.populateNet(net);
+    caffeImporter.populateNet(net, engine == ENGINE_NEW || engine == ENGINE_AUTO);
     return net;
 }
 
 Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
-                     const char *bufferModel, size_t lenModel)
+                     const char *bufferModel, size_t lenModel,
+                     int engine)
 {
+    static const int engine_forced = (int)utils::getConfigurationParameterSizeT("OPENCV_FORCE_DNN_ENGINE", ENGINE_AUTO);
+    if(engine_forced != ENGINE_AUTO)
+        engine = engine_forced;
+
     CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);
     Net net;
-    caffeImporter.populateNet(net);
+    caffeImporter.populateNet(net, engine == ENGINE_NEW || engine == ENGINE_AUTO);
     return net;
 }
 
-Net readNetFromCaffe(const std::vector<uchar>& bufferProto, const std::vector<uchar>& bufferModel)
+Net readNetFromCaffe(const std::vector<uchar>& bufferProto,
+                     const std::vector<uchar>& bufferModel,
+                     int engine)
 {
     const char* bufferProtoPtr = reinterpret_cast<const char*>(&bufferProto[0]);
     const char* bufferModelPtr = bufferModel.empty() ? NULL :
                                  reinterpret_cast<const char*>(&bufferModel[0]);
     return readNetFromCaffe(bufferProtoPtr, bufferProto.size(),
-                            bufferModelPtr, bufferModel.size());
+                            bufferModelPtr, bufferModel.size(),
+                            engine);
 }
 
 #else  // HAVE_PROTOBUF
 
 #define DNN_PROTOBUF_UNSUPPORTED() CV_Error(Error::StsError, "DNN/Caffe: Build OpenCV with Protobuf to import Caffe models")
 
-Net readNetFromCaffe(const String &, const String &) {
+Net readNetFromCaffe(const String &, const String &, int) {
     DNN_PROTOBUF_UNSUPPORTED();
 }
 
-Net readNetFromCaffe(const char *, size_t, const char *, size_t) {
+Net readNetFromCaffe(const char *, size_t, const char *, size_t, int) {
     DNN_PROTOBUF_UNSUPPORTED();
 }
 
-Net readNetFromCaffe(const std::vector<uchar>&, const std::vector<uchar>&) {
+Net readNetFromCaffe(const std::vector<uchar>&, const std::vector<uchar>&, int) {
     DNN_PROTOBUF_UNSUPPORTED();
 }
diff --git a/modules/dnn/src/dnn_read.cpp b/modules/dnn/src/dnn_read.cpp
index ed42d57942..8aeb10ced8 100644
--- a/modules/dnn/src/dnn_read.cpp
+++ b/modules/dnn/src/dnn_read.cpp
@@ -21,7 +21,7 @@ Net readNet(const String& _model, const String& _config, const String& _framework
     {
         if (modelExt == "prototxt" || configExt == "caffemodel")
             std::swap(model, config);
-        return readNetFromCaffe(config, model);
+        return readNetFromCaffe(config, model, engine);
     }
     if (framework == "tensorflow" || modelExt == "pb" || configExt == "pb" || modelExt == "pbtxt" || configExt == "pbtxt")
     {
@@ -61,7 +61,7 @@ Net readNet(const String& _framework, const std::vector<uchar>& bufferModel,
     if (framework == "onnx")
         return readNetFromONNX(bufferModel, engine);
     else if (framework == "caffe")
-        return readNetFromCaffe(bufferConfig, bufferModel);
+        return readNetFromCaffe(bufferConfig, bufferModel, engine);
     else if (framework == "tensorflow")
         return readNetFromTensorflow(bufferModel, bufferConfig);
     else if (framework == "darknet")
diff --git a/modules/dnn/src/model.cpp b/modules/dnn/src/model.cpp
index 951fa526d1..c36279e464 100644
--- a/modules/dnn/src/model.cpp
+++ b/modules/dnn/src/model.cpp
@@ -114,14 +114,18 @@ public:
         }
 
         Mat blob = dnn::blobFromImageWithParams(frame, param); // [1, 10, 10, 4]
-        net.setInput(blob);
 
-        // Faster-RCNN or R-FCN
-        if (!net.getMainGraph() && net.getLayer(0)->outputNameToIndex("im_info") != -1)
+        if ((net.getMainGraph() && net.haveArg("im_info") && net.argKind(net.getArg("im_info")) == DNN_ARG_INPUT) ||
+            (!net.getMainGraph() && net.getLayer(0)->outputNameToIndex("im_info") != -1))
         {
+            net.setInput(blob, "data");
             Mat imInfo(Matx13f(size.height, size.width, 1.6f));
             net.setInput(imInfo, "im_info");
         }
+        else
+        {
+            net.setInput(blob);
+        }
 
         net.forward(outs, outNames);
     }
@@ -507,7 +511,12 @@ void DetectionModel::detect(InputArray frame, CV_OUT std::vector<int>& classIds,
     int frameWidth = frame.cols();
     int frameHeight = frame.rows();
 
-    if (getNetwork_().getLayer(0)->outputNameToIndex("im_info") != -1)
+    if ((getNetwork_().getMainGraph() &&
+         getNetwork_().haveArg("im_info") &&
+         getNetwork_().argKind(getNetwork_().getArg("im_info")) == DNN_ARG_INPUT)
+        ||
+        (!getNetwork_().getMainGraph() &&
+         getNetwork_().getLayer(0)->outputNameToIndex("im_info") != -1))
     {
         frameWidth = impl->size.width;
         frameHeight = impl->size.height;
diff --git a/modules/dnn/test/test_backends.cpp b/modules/dnn/test/test_backends.cpp
index 65f4b80949..1bcc0c5db8 100644
--- a/modules/dnn/test/test_backends.cpp
+++ b/modules/dnn/test/test_backends.cpp
@@ -41,7 +41,13 @@ public:
         Net netDefault = readNet(weights, proto);
         netDefault.setPreferableBackend(DNN_BACKEND_OPENCV);
         netDefault.setInput(inp);
-        Mat outDefault = netDefault.forward(outputLayer).clone();
+
+        // BUG: https://github.com/opencv/opencv/issues/26349
+        Mat outDefault;
+        if(netDefault.getMainGraph())
+            outDefault = netDefault.forward().clone();
+        else
+            outDefault = netDefault.forward(outputLayer).clone();
 
         net = readNet(weights, proto);
         net.setInput(inp);
@@ -51,7 +57,12 @@ public:
         if (target == DNN_TARGET_CPU_FP16)
             net.enableWinograd(false);
 
-        Mat out = net.forward(outputLayer).clone();
+        // BUG: https://github.com/opencv/opencv/issues/26349
+        Mat out;
+        if(net.getMainGraph())
+            out = net.forward().clone();
+        else
+            out = net.forward(outputLayer).clone();
 
         check(outDefault, out, outputLayer, l1, lInf, detectionConfThresh, "First run");
 
@@ -65,8 +76,17 @@ public:
         }
         netDefault.setInput(inp);
         net.setInput(inp);
-        outDefault = netDefault.forward(outputLayer).clone();
-        out = net.forward(outputLayer).clone();
+
+        if(netDefault.getMainGraph())
+            outDefault = netDefault.forward().clone();
+        else
+            outDefault = netDefault.forward(outputLayer).clone();
+
+        if(net.getMainGraph())
+            out = net.forward().clone();
+        else
+            out = net.forward(outputLayer).clone();
+
         check(outDefault, out, outputLayer, l1, lInf, detectionConfThresh, "Second run");
     }
 
@@ -514,9 +534,8 @@ TEST_P(DNNTestNetwork, FastNeuralStyle_eccv16)
 #if defined(HAVE_INF_ENGINE) && INF_ENGINE_VER_MAJOR_GE(2019010000)
     expectNoFallbacksFromIE(net);
 #endif
-    // BUG: https://github.com/opencv/opencv/issues/26306
-    // Temporarily disabled check for no "fallbacks", since the new engine does not support CUDA yet
-    //expectNoFallbacksFromCUDA(net);
+
+    expectNoFallbacksFromCUDA(net);
 }
 
 INSTANTIATE_TEST_CASE_P(/*nothing*/, DNNTestNetwork, dnnBackendsAndTargets(/* withInferenceEngine = */ true,
diff --git a/modules/dnn/test/test_caffe_importer.cpp b/modules/dnn/test/test_caffe_importer.cpp
index 015185379d..ed8f91f026 100644
--- a/modules/dnn/test/test_caffe_importer.cpp
+++ b/modules/dnn/test/test_caffe_importer.cpp
@@ -230,7 +230,14 @@ TEST_P(Reproducibility_AlexNet, Accuracy)
     ASSERT_TRUE(!sample.empty());
 
     net.setInput(blobFromImage(sample, 1.0f, Size(227, 227), Scalar(), false), "data");
-    Mat out = net.forward("prob");
+
+    Mat out;
+    // BUG: https://github.com/opencv/opencv/issues/26349
+    if (net.getMainGraph())
+        out = net.forward();
+    else
+        out = net.forward("prob");
+
     Mat ref = blobFromNPY(_tf("caffe_alexnet_prob.npy"));
     normAssert(ref, out, "", l1, lInf);
 }
@@ -259,7 +266,13 @@ TEST(Reproducibility_FCN, Accuracy)
     net.getMemoryConsumption(shape(1,3,227,227), CV_32F, layerIds, weights, blobs);
 
     net.setInput(blobFromImage(sample, 1.0f, Size(500, 500), Scalar(), false), "data");
-    Mat out = net.forward("score");
+
+    Mat out;
+    // BUG: https://github.com/opencv/opencv/issues/26349
+    if (net.getMainGraph())
+        out = net.forward();
+    else
+        out = net.forward("score");
 
     Mat refData = imread(_tf("caffe_fcn8s_prob.png"), IMREAD_ANYDEPTH);
     int shape[] = {1, 21, 500, 500};
@@ -292,7 +305,13 @@ TEST(Reproducibility_SSD, Accuracy)
     Mat in_blob = blobFromImage(sample, 1.0f, Size(300, 300), Scalar(), false);
     net.setInput(in_blob, "data");
-    Mat out = net.forward("detection_out");
+
+    // BUG: https://github.com/opencv/opencv/issues/26349
+    Mat out;
+    if(net.getMainGraph())
+        out = net.forward();
+    else
+        out = net.forward("detection_out");
 
     Mat ref = blobFromNPY(_tf("ssd_out.npy"));
     normAssertDetections(ref, out, "", 0.06);
@@ -495,7 +514,13 @@ TEST(Reproducibility_GoogLeNet_fp16, Accuracy)
     ASSERT_TRUE(!inpMats[0].empty() && !inpMats[1].empty());
 
     net.setInput(blobFromImages(inpMats, 1.0f, Size(), Scalar(), false), "data");
-    Mat out = net.forward("prob");
+
+    // BUG: https://github.com/opencv/opencv/issues/26349
+    Mat out;
+    if(net.getMainGraph())
+        out = net.forward();
+    else
+        out = net.forward("prob");
 
     Mat ref = blobFromNPY(_tf("googlenet_prob.npy"));
     normAssert(out, ref, "", l1, lInf);
diff --git a/modules/dnn/test/test_common.hpp b/modules/dnn/test/test_common.hpp
index 0f9244324e..f58aae736b 100644
--- a/modules/dnn/test/test_common.hpp
+++ b/modules/dnn/test/test_common.hpp
@@ -195,6 +195,11 @@ public:
 
     void expectNoFallbacks(Net& net, bool raiseError = true)
     {
+        // The new DNN engine does not support non-CPU backends for now
+        // bug: https://github.com/opencv/opencv/issues/26198
+        if (net.getMainGraph())
+            return;
+
         // Check if all the layers are supported with current backend and target.
         // Some layers might be fused so their timings equal to zero.
         std::vector<double> timings;
diff --git a/modules/dnn/test/test_googlenet.cpp b/modules/dnn/test/test_googlenet.cpp
index f911ff029f..2868cbbf64 100644
--- a/modules/dnn/test/test_googlenet.cpp
+++ b/modules/dnn/test/test_googlenet.cpp
@@ -80,7 +80,13 @@ TEST_P(Reproducibility_GoogLeNet, Batching)
     ASSERT_TRUE(!inpMats[0].empty() && !inpMats[1].empty());
 
     net.setInput(blobFromImages(inpMats, 1.0f, Size(), Scalar(), false), "data");
-    Mat out = net.forward("prob");
+
+    // BUG: https://github.com/opencv/opencv/issues/26349
+    Mat out;
+    if(net.getMainGraph())
+        out = net.forward();
+    else
+        out = net.forward("prob");
 
     Mat ref = blobFromNPY(_tf("googlenet_prob.npy"));
     normAssert(out, ref);
@@ -93,8 +99,9 @@ TEST_P(Reproducibility_GoogLeNet, IntermediateBlobs)
         applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
     if (targetId == DNN_TARGET_CPU_FP16)
         applyTestTag(CV_TEST_TAG_DNN_SKIP_CPU_FP16);
 
+    // BUG: https://github.com/opencv/opencv/issues/26349
     Net net = readNetFromCaffe(findDataFile("dnn/bvlc_googlenet.prototxt"),
-                               findDataFile("dnn/bvlc_googlenet.caffemodel", false));
+                               findDataFile("dnn/bvlc_googlenet.caffemodel", false), ENGINE_CLASSIC);
     net.setPreferableBackend(DNN_BACKEND_OPENCV);
     net.setPreferableTarget(targetId);
 
@@ -126,8 +133,9 @@ TEST_P(Reproducibility_GoogLeNet, SeveralCalls)
         applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
     if (targetId == DNN_TARGET_CPU_FP16)
         applyTestTag(CV_TEST_TAG_DNN_SKIP_CPU_FP16);
 
+    // BUG: https://github.com/opencv/opencv/issues/26349
     Net net = readNetFromCaffe(findDataFile("dnn/bvlc_googlenet.prototxt"),
-                               findDataFile("dnn/bvlc_googlenet.caffemodel", false));
+                               findDataFile("dnn/bvlc_googlenet.caffemodel", false), ENGINE_CLASSIC);
     net.setPreferableBackend(DNN_BACKEND_OPENCV);
     net.setPreferableTarget(targetId);
 
diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp
index 11529e1562..55c2efe33c 100644
--- a/modules/dnn/test/test_layers.cpp
+++ b/modules/dnn/test/test_layers.cpp
@@ -408,7 +408,11 @@ TEST_P(Test_Caffe_layers, Reshape_Split_Slice)
     rng.fill(input, RNG::UNIFORM, -1, 1);
 
     net.setInput(input, "input");
-    Mat output = net.forward("output");
+    Mat output;
+    if (net.getMainGraph())
+        output = net.forward();
+    else
+        output = net.forward("output");
 
     normAssert(input, output, "", default_l1, default_lInf);
 }
@@ -864,7 +868,7 @@ TEST_P(Test_Caffe_layers, FasterRCNN_Proposal)
     std::vector<Mat> outs;
     net.setPreferableBackend(backend);
     net.setPreferableTarget(target);
-    net.forward(outs, "output");
+    net.forward(outs);
 
     for (int i = 0; i < 2; ++i)
     {
diff --git a/modules/dnn/test/test_misc.cpp b/modules/dnn/test/test_misc.cpp
index 652f876939..1b33aaa630 100644
--- a/modules/dnn/test/test_misc.cpp
+++ b/modules/dnn/test/test_misc.cpp
@@ -330,7 +330,10 @@ TEST_P(dump, Regression)
     Net net = readNet(findDataFile("dnn/squeezenet_v1.1.prototxt"),
                       findDataFile("dnn/squeezenet_v1.1.caffemodel", false));
 
-    ASSERT_EQ(net.getLayerInputs(net.getLayerId("fire2/concat")).size(), 2);
+    if (net.getMainGraph())
+        ASSERT_EQ(net.getLayer(net.getLayerId("fire2/concat"))->inputs.size(), 2);
+    else
+        ASSERT_EQ(net.getLayerInputs(net.getLayerId("fire2/concat")).size(), 2);
 
     int size[] = {1, 3, 227, 227};
     Mat input = cv::Mat::ones(4, size, CV_32F);
@@ -602,7 +605,14 @@ TEST(Net, forwardAndRetrieve)
     outNames.push_back("testLayer");
 
     std::vector<std::vector<Mat> > outBlobs;
-    net.forward(outBlobs, outNames);
+    if (net.getMainGraph())
+    {
+        // Issue: https://github.com/opencv/opencv/issues/26349
+        outBlobs.push_back({});
+        net.forward(outBlobs[0]);
+    }
+    else
+        net.forward(outBlobs, outNames);
 
     EXPECT_EQ(outBlobs.size(), 1);
     EXPECT_EQ(outBlobs[0].size(), 2);
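
Usage note (not part of the patch): a minimal sketch of how the new `engine` argument of readNetFromCaffe is intended to be called from user code, assuming an OpenCV build with protobuf support; the model and image paths are hypothetical placeholders. ENGINE_AUTO, ENGINE_NEW and ENGINE_CLASSIC are the cv::dnn engine selectors referenced throughout the diff.

#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>

int main()
{
    using namespace cv;
    using namespace cv::dnn;

    // Default (ENGINE_AUTO) selects the new graph engine, which is CPU-only for now.
    Net netNew = readNetFromCaffe("model.prototxt", "model.caffemodel");

    // Explicitly request the classic engine, e.g. to keep using
    // setPreferableBackend/setPreferableTarget with non-CPU backends.
    Net netClassic = readNetFromCaffe("model.prototxt", "model.caffemodel", ENGINE_CLASSIC);
    netClassic.setPreferableBackend(DNN_BACKEND_OPENCV);
    netClassic.setPreferableTarget(DNN_TARGET_CPU);

    // Run the new-engine network on a single image.
    Mat img = imread("input.jpg");
    Mat blob = blobFromImage(img, 1.0, Size(227, 227), Scalar(), false);
    netNew.setInput(blob);
    Mat out = netNew.forward();
    return 0;
}

The sketch mirrors the behaviour the patch encodes in readNetFromCaffe: ENGINE_AUTO and ENGINE_NEW route through the new importer path (populateNet(net, true)), while ENGINE_CLASSIC keeps the legacy layer-by-layer construction, and the OPENCV_FORCE_DNN_ENGINE environment option can override the argument at runtime.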