diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 7f8e919257..8154064e03 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1166,6 +1166,13 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<ExpandLayer> create(const LayerParams &params);
     };
 
+    class CV_EXPORTS InstanceNormLayer : public Layer {
+    public:
+        float epsilon;
+
+        static Ptr<InstanceNormLayer> create(const LayerParams &params);
+    };
+
 //! @}
 //! @}
 CV__DNN_INLINE_NS_END
diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp
index e2d7cf2ff4..c26b7a1588 100644
--- a/modules/dnn/perf/perf_layer.cpp
+++ b/modules/dnn/perf/perf_layer.cpp
@@ -683,6 +683,62 @@ PERF_TEST_P_(Layer_GatherElements, GatherElements)
     test_layer({2700, 1, 2914}, {2700, 1, 81}, 2);
 }
 
+struct Layer_InstanceNorm : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& x_shape)
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        Mat x(x_shape, CV_32FC1);
+        Mat scale(x_shape[1], 1, CV_32FC1);
+        Mat b(x_shape[1], 1, CV_32FC1);
+
+        randu(x, 0.f, 1.f);
+        randu(scale, 0.f, 1.f);
+        randu(b, 0.f, 1.f);
+
+        Net net;
+        LayerParams lp;
+        lp.type = "InstanceNormalization";
+        lp.name = "testLayer";
+        int id = net.addLayerToPrev(lp.name, lp.type, lp);
+        net.connect(0, 0, id, 0);
+        net.connect(0, 1, id, 1);
+        net.connect(0, 2, id, 2);
+
+        // warmup
+        {
+            std::vector<String> inpNames{"x", "scale", "b"};
+            net.setInputsNames(inpNames);
+            net.setInput(x, inpNames[0]);
+            net.setInput(scale, inpNames[1]);
+            net.setInput(b, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 2;
+    int C = 64;
+    int H = 180;
+    int W = 240;
+};
+
+PERF_TEST_P_(Layer_InstanceNorm, InstanceNorm)
+{
+    test_layer({N, C, H, W});
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 #ifdef HAVE_CUDA
@@ -693,6 +749,7 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_ScatterND, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_GatherElements, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
+INSTANTIATE_TEST_CASE_P(/**/, Layer_InstanceNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 
 typedef TestBaseWithParam<tuple<std::vector<int>, int, bool, tuple<Backend, Target> > > Layer_FullyConnected;
diff --git a/modules/dnn/src/cuda/mvn.cu b/modules/dnn/src/cuda/mvn.cu
index adf997c0b0..d4f0733676 100644
--- a/modules/dnn/src/cuda/mvn.cu
+++ b/modules/dnn/src/cuda/mvn.cu
@@ -66,6 +66,17 @@ namespace raw {
             output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * scale[outer_idx];
         }
     }
+
+    template <class T>
+    __global__ void normalize_mean_variance_channelwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> stdev, size_type inner_size, size_type C) {
+        for (auto idx : grid_stride_range(output.size())) {
+            const index_type outer_idx = idx / inner_size;
+            const index_type c = outer_idx % C;
+            auto s = static_cast<float>(scale[c]) * stdev[outer_idx];
+            auto b = static_cast<float>(bias[c]);
+            output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s + b;
+        }
+    }
 }
 
 template <class T>
@@ -142,4 +153,21 @@ template void normalize_mean_variance(const Stream&, Span<__half>, View<__half>, View<float>, View<float>, std::size_t);
 #endif
 template void normalize_mean_variance(const Stream&, Span<float>, View<float>, View<float>, View<float>, std::size_t);
 
+template <class T>
+void normalize_mean_variance_channelwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> stdev, std::size_t inner_size, std::size_t C)
+{
+    CV_Assert(input.size() == output.size());
+    CV_Assert(input.size() / inner_size == means.size());
+    CV_Assert(means.size() == stdev.size());
+
+    auto kernel = raw::normalize_mean_variance_channelwise<T>;
+    auto policy = make_policy(kernel, output.size(), 0, stream);
+    launch_kernel(kernel, policy, output, input, scale, bias, means, stdev, inner_size, C);
+}
+
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*stdev*/, std::size_t, std::size_t);
+#endif
+template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*stdev*/, std::size_t, std::size_t);
+
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
diff --git a/modules/dnn/src/cuda4dnn/kernels/mvn.hpp b/modules/dnn/src/cuda4dnn/kernels/mvn.hpp
index b5a573e921..ebd7b9f659 100644
--- a/modules/dnn/src/cuda4dnn/kernels/mvn.hpp
+++ b/modules/dnn/src/cuda4dnn/kernels/mvn.hpp
@@ -26,6 +26,9 @@ void normalize_mean(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, std::size_t inner_size);
 template <class T>
 void normalize_mean_variance(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, csl::View<float> scale, std::size_t inner_size);
 
+template <class T>
+void normalize_mean_variance_channelwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> stdev, std::size_t inner_size, std::size_t C);
+
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
 
 #endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP */
diff --git a/modules/dnn/src/cuda4dnn/primitives/instance_norm.hpp b/modules/dnn/src/cuda4dnn/primitives/instance_norm.hpp
new file mode 100644
index 0000000000..0a32e40fc0
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/primitives/instance_norm.hpp
@@ -0,0 +1,86 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INSTANCE_NORM_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INSTANCE_NORM_HPP
+
+#include "../../op_cuda.hpp"
+
+#include "../csl/stream.hpp"
+#include "../csl/span.hpp"
+#include "../csl/tensor.hpp"
+#include "../csl/workspace.hpp"
+
+#include "../kernels/fill_copy.hpp"
+#include "../kernels/mvn.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+    template <class T>
+    class InstanceNormOp final : public CUDABackendNode {
+    public:
+        using wrapper_type = GetCUDABackendWrapperType<T>;
+
+        InstanceNormOp(csl::Stream stream_, float epsilon_, size_t loops)
+            : stream(std::move(stream_)), epsilon(epsilon_) {
+            csl::WorkspaceBuilder builder;
+            builder.require<float>(loops);
+            builder.require<float>(loops);
+            scratch_mem_in_bytes = builder.required_workspace_size();
+        }
+
+        void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+                     const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+                     csl::Workspace& workspace) override {
+            auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
+            auto scale_wrapper = inputs[1].dynamicCast<wrapper_type>();
+            auto bias_wrapper = inputs[2].dynamicCast<wrapper_type>();
+
+            auto input = input_wrapper->getView();
+            auto scale = scale_wrapper->getView();
+            auto bias = bias_wrapper->getView();
+
+            auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
+            auto output = output_wrapper->getSpan();
+
+            auto C = input.get_axis_size(1);
+            auto loops = input.size_range(0, 2);
+            auto norm_size = input.size_range(2, input.rank());
+            if (norm_size == 1) {
+                kernels::fill<T>(stream, output, 0.f);
+                return;
+            } else {
+                auto ws_allocator = csl::WorkspaceAllocator(workspace);
+
+                auto mean = ws_allocator.get_span<float>(loops);
+                kernels::fill(stream, mean, 0.f);
+
+                auto stdev = ws_allocator.get_span<float>(loops);
+                kernels::fill(stream, stdev, 0.f);
+
+                kernels::reduce_mean_sqr_sum(stream, mean, stdev, input, norm_size);
+                kernels::compute_normalization_scale(stream, stdev, mean, stdev, norm_size, epsilon);
+                kernels::normalize_mean_variance_channelwise(stream, output, input, scale, bias, mean, stdev, norm_size, C);
+            }
+        }
+
+        std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }
+
+    private:
+        csl::Stream stream;
+
+        float epsilon;
+
+        std::size_t scratch_mem_in_bytes;
+    };
+
+}}} // cv::dnn::cuda4dnn
+
+#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INSTANCE_NORM_HPP
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index e70d5dad47..961e6e5c9a 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -160,6 +160,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(GatherElements, GatherElementsLayer);
     CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(InstanceNormalization, InstanceNormLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
diff --git a/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp b/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp
index 60b503513f..ab9d8ee0af 100644
--- a/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp
+++ b/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp
@@ -118,10 +118,11 @@ void fastNorm(const Mat &input, const Mat &scale, const Mat &bias, Mat &output,
 void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon) {
     const auto input_shape = shape(input);
+    size_t N = input_shape[0], C = input_shape[1];
     CV_CheckEQ(scale.total(), bias.total(), "fastNormChannel: scale and bias should have the same shape");
shape"); + CV_CheckEQ(scale.total(), C, "fastNormChannel: scale should be a 1d tensor and match the channel of input"); CV_CheckGE(input.dims, 3, "fastNormChannel: input dimension >= 3"); - size_t N = input_shape[0], C = input_shape[1]; size_t loops = N * C, norm_size = static_cast(total(input_shape, 2)); float inv_norm_size = 1.0 / norm_size; @@ -147,9 +148,9 @@ void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &o float inv_stdev = 1.f / mean_square; size_t c = i % C; - float s = scale_data[c], b = bias_data[c]; + float s = scale_data[c] * inv_stdev, b = bias_data[c]; for (size_t j = 0; j < norm_size; j++) { - y[j] = s * (x[j] - mean) * inv_stdev + b; + y[j] = s * (x[j] - mean) + b; } } }; diff --git a/modules/dnn/src/layers/instance_norm_layer.cpp b/modules/dnn/src/layers/instance_norm_layer.cpp new file mode 100644 index 0000000000..fda0efdb94 --- /dev/null +++ b/modules/dnn/src/layers/instance_norm_layer.cpp @@ -0,0 +1,231 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "../precomp.hpp" +#include +#include "./cpu_kernels/fast_norm.hpp" + +// OpenVINO backend +#include "../op_inf_engine.hpp" +#include "../ie_ngraph.hpp" + +// CUDA backend +#include "../op_cuda.hpp" +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/instance_norm.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + +// OpenCL backend +#ifdef HAVE_OPENCL +#include "../ocl4dnn/include/math_functions.hpp" +#include "opencl_kernels_dnn.hpp" +#endif + +namespace cv { namespace dnn { + +// https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization +class InstanceNormLayerImpl CV_FINAL : public InstanceNormLayer { +public: + InstanceNormLayerImpl(const LayerParams ¶ms) { + setParamsFrom(params); + + epsilon = params.get("epsilon", 1e-5); + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE { +#ifdef HAVE_INF_ENGINE + if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) + return true; +#endif + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA; + } + + bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE { + const auto &input = inputs[0]; + const auto &scale = inputs[1]; + const auto &bias = inputs[2]; + CV_CheckGE(input.size(), static_cast(3), "DNN/InstanceNorm: input dimension >= 3 is required"); + + int C = input[1]; + int scale_dim = std::accumulate(scale.begin(), scale.end(), 1, std::multiplies()); + CV_CheckEQ(scale_dim, C, "DNN/InstanceNorm: scale must be a 1d tensor and match the channel of input"); + int bias_dim = std::accumulate(bias.begin(), bias.end(), 1, std::multiplies()); + CV_CheckEQ(bias_dim, C, "DNN/InstanceNorm: bias must be a 1d tensor and match the channel of input"); + + outputs.assign(1, inputs[0]); + return false; + } + + void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget), + forward_ocl(inputs_arr, outputs_arr, internals_arr)) + + if (inputs_arr.depth() == CV_16S) + { + forward_fallback(inputs_arr, outputs_arr, internals_arr); + return; + } + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + const auto &input = 
+        const auto &scale = inputs[1];
+        const auto &bias = inputs[2];
+
+        fastNormChannel(input, scale, bias, outputs[0], epsilon);
+    }
+
+#ifdef HAVE_OPENCL
+    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_) {
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inputs_.getUMatVector(inputs);
+        outputs_.getUMatVector(outputs);
+
+        const auto &input = inputs[0], &scale = inputs[1], &bias = inputs[2];
+        auto &output = outputs[0];
+
+        const auto input_shape = shape(input);
+        size_t N = input_shape[0], C = input_shape[1],
+               loops = N * C, norm_size = static_cast<size_t>(total(input_shape, 2));
+        float inv_norm_size = 1.f / norm_size;
+
+        // no fp16 support
+        if (input.depth() == CV_16S) {
+            return false;
+        }
+
+        String base_opts = format(" -DT=float -DT4=float4 -Dconvert_T=convert_float4");
+
+        // Calculate mean
+        UMat one = UMat::ones(norm_size, 1, CV_32F);
+        UMat mean = UMat(loops, 1, CV_32F);
+        UMat mean_square = UMat(loops, 1, CV_32F);
+        UMat tmp = UMat(loops, norm_size, CV_32F);
+        bool ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
+                                               input, 0, one, 0, 0.f, mean, 0);
+        if (!ret) {
+            return false;
+        }
+        // Calculate mean_square
+        int num_vector = (norm_size % 8 == 0) ? 8 : ((norm_size % 4 == 0) ? 4 : 1);
+        size_t global[] = {loops, static_cast<size_t>(norm_size / num_vector)};
+        String build_opt = format(" -DNUM=%d", num_vector) + base_opts;
+        String mean_square_kernel_name = format("calc_mean%d", num_vector);
+        ocl::Kernel mean_square_kernel(mean_square_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt + " -DKERNEL_MEAN");
+        if (mean_square_kernel.empty()) {
+            return false;
+        }
+        mean_square_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
+        mean_square_kernel.set(1, (int)loops);
+        mean_square_kernel.set(2, (int)norm_size);
+        mean_square_kernel.set(3, ocl::KernelArg::PtrReadOnly(mean));
+        mean_square_kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmp));
+        ret = mean_square_kernel.run(2, global, NULL, false);
+        if (!ret) {
+            return false;
+        }
+        ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
+                                          tmp, 0, one, 0, 0.f, mean_square, 0);
+        if (!ret) {
+            return false;
+        }
+        // Calculate instance norm: output = scale * (x - mean) / sqrt(var + eps) + bias
+        String mvn_kernel_name = format("mvn%d", num_vector);
+        build_opt += " -DNORM_VARIANCE -DFUSE_BATCH_NORM -DKERNEL_MVN";
+        ocl::Kernel mvn_kernel(mvn_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt);
+        if (mvn_kernel.empty()) {
+            return false;
+        }
+        mvn_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
+        mvn_kernel.set(1, (int)loops);
+        mvn_kernel.set(2, (int)norm_size);
+        mvn_kernel.set(3, (float)epsilon);
+        mvn_kernel.set(4, ocl::KernelArg::PtrReadOnly(mean));
+        mvn_kernel.set(5, ocl::KernelArg::PtrReadOnly(mean_square));
+        mvn_kernel.set(6, ocl::KernelArg::PtrReadOnly(scale));
+        mvn_kernel.set(7, ocl::KernelArg::PtrReadOnly(bias));
+        mvn_kernel.set(8, (int)C);
+        mvn_kernel.set(9, (float)0.f);
+        mvn_kernel.set(10, ocl::KernelArg::PtrWriteOnly(output));
+        ret = mvn_kernel.run(2, global, NULL, false);
+        if (!ret) {
+            return false;
+        }
+
+        return true;
+    }
+#endif
+
+#ifdef HAVE_DNN_NGRAPH
+    virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
+                                        const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE {
+        // onnx to openvino conversion: https://github.com/openvinotoolkit/openvino/blob/2023.1.0/src/frontends/onnx/frontend/src/op/instance_norm.cpp
+
+        auto ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
+        const auto &input_shape = ieInpNode.get_shape();
+        std::shared_ptr<ngraph::Node> mvn, result;
+
+        // mvn
+#if INF_ENGINE_VER_MAJOR_LE(INF_ENGINE_RELEASE_2021_2)
+        // https://docs.openvino.ai/2021.4/api/ngraph_python_api/_autosummary/ngraph.opset3.mvn.html?highlight=mvn#ngraph.opset3.mvn
+        bool across_channels = false;
+        bool normalize_variance = true;
+        mvn = std::make_shared<ngraph::op::MVN>(ieInpNode, across_channels, normalize_variance, epsilon);
+#else
+        // https://docs.openvino.ai/2023.1/openvino_docs_ops_normalization_MVN_6.html
+        std::vector<int64_t> axes_v(input_shape.size() - 2);
+        std::iota(axes_v.begin(), axes_v.end(), 2); // {2, 3, ...} for nd input tensor, n>=3
+        auto axes = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{axes_v.size()}, axes_v.data());
+        bool normalize_variance = true;
+        mvn = std::make_shared<ngraph::op::v6::MVN>(ieInpNode, axes, normalize_variance, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
+#endif
+
+        // instance norm = scale * mvn + bias
+        auto scale = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
+        std::vector<int64_t> shared_shape_v(input_shape.size(), 1);
+        shared_shape_v[1] = -1;
+        auto shared_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{shared_shape_v.size()}, shared_shape_v.data());
+        scale = std::make_shared<ngraph::op::v1::Reshape>(scale, shared_shape, true);
+        result = std::make_shared<ngraph::op::v1::Multiply>(mvn, scale);
+        auto bias = nodes[2].dynamicCast<InfEngineNgraphNode>()->node;
+        bias = std::make_shared<ngraph::op::v1::Reshape>(bias, shared_shape, true);
+        result = std::make_shared<ngraph::op::v1::Add>(result, bias);
+
+        return Ptr<BackendNode>(new InfEngineNgraphNode(result));
+    }
+#endif // HAVE_DNN_NGRAPH
+
+#ifdef HAVE_CUDA
+    Ptr<BackendNode> initCUDA(void *context_,
+                              const std::vector<Ptr<BackendWrapper>>& inputs,
+                              const std::vector<Ptr<BackendWrapper>>& outputs) override {
+        auto context = reinterpret_cast<csl::CSLContext*>(context_);
+
+        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+        auto input_shape = input_wrapper->getShape();
+        size_t loops = static_cast<size_t>(total(input_shape, 0, 2));
+
+        return make_cuda_node<cuda4dnn::InstanceNormOp>(preferableTarget, std::move(context->stream), epsilon, loops);
+    }
+#endif // HAVE_CUDA
+
+};
+
+Ptr<InstanceNormLayer> InstanceNormLayer::create(const LayerParams &params) {
+    return Ptr<InstanceNormLayer>(new InstanceNormLayerImpl(params));
+}
+
+}} // cv::dnn
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 4813d118c4..28dd4d9b77 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -1844,44 +1844,43 @@ void ONNXImporter::parseLRN(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
     addLayer(layerParams, node_proto);
 }
 
-void ONNXImporter::parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
-{
-    opencv_onnx::NodeProto node_proto = node_proto_;
-    if (node_proto.input_size() != 3)
-        CV_Error(Error::StsNotImplemented,
-                 "Expected input, scale, bias");
+void ONNXImporter::parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) {
+    int num_inputs = node_proto.input_size();
+    CV_CheckEQ(num_inputs, 3, "DNN/ONNXImporter - InstanceNorm: three inputs are required");
 
-    layerParams.blobs.resize(4);
-    layerParams.blobs[2] = getBlob(node_proto, 1);  // weightData
-    layerParams.blobs[3] = getBlob(node_proto, 2);  // biasData
-    layerParams.set("has_bias", true);
-    layerParams.set("has_weight", true);
+    bool found_input = constBlobs.find(node_proto.input(0)) != constBlobs.end();
+    bool found_scale = constBlobs.find(node_proto.input(1)) != constBlobs.end();
+    bool found_bias = constBlobs.find(node_proto.input(2)) != constBlobs.end();
 
-    // Get number of channels in input
-    int size = layerParams.blobs[2].total();
-    layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
-    layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
+    if (found_input && found_scale && found_bias) {
+        std::vector<Mat> inputs, output;
 
-    LayerParams mvnParams;
-    mvnParams.name = layerParams.name + "/MVN";
-    mvnParams.type = "MVN";
-    mvnParams.set("eps", layerParams.get<float>("epsilon"));
-    layerParams.erase("epsilon");
+        Mat input = getBlob(node_proto, 0);
+        Mat scale = getBlob(node_proto, 1);
+        Mat bias = getBlob(node_proto, 2);
+        inputs.push_back(input);
+        inputs.push_back(scale);
+        inputs.push_back(bias);
 
-    //Create MVN layer
-    int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
-    //Connect to input
-    IterLayerId_t layerId = layer_id.find(node_proto.input(0));
-    CV_Assert(layerId != layer_id.end());
-    dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
-    //Add shape
-    layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
-    outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
+        runLayer(layerParams, inputs, output);
+        addConstant(node_proto.output(0), output[0]);
+    } else {
+        auto add_const_node = [&] (int i) {
+            LayerParams const_params;
+            const_params.name = node_proto.input(i);
+            const_params.type = "Const";
+            Mat blob = getBlob(node_proto, i);
+            const_params.blobs.push_back(blob);
 
-    //Replace Batch Norm's input to MVN
-    node_proto.set_input(0, mvnParams.name);
-    layerParams.type = "BatchNorm";
-    addLayer(layerParams, node_proto);
+            opencv_onnx::NodeProto proto;
+            proto.add_output(const_params.name);
+            addLayer(const_params, proto);
+        };
+        if (found_input && layer_id.find(node_proto.input(0)) == layer_id.end()) { add_const_node(0); }
+        if (found_scale && layer_id.find(node_proto.input(1)) == layer_id.end()) { add_const_node(1); }
+        if (found_bias && layer_id.find(node_proto.input(2)) == layer_id.end()) { add_const_node(2); }
+        addLayer(layerParams, node_proto);
+    }
 }
 
 void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
diff --git a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp
index 8c461b699f..be60c38b86 100644
--- a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp
+++ b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp
@@ -159,8 +159,6 @@
 "test_if",
 "test_if_opt",
 "test_if_seq",
-"test_instancenorm_epsilon",
-"test_instancenorm_example",
 "test_isinf",
 "test_isinf_negative",
 "test_isinf_positive",
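
Notes (not part of the patch):

Both elementwise kernels fold the per-instance normalization factor into the per-channel scale before the inner loop: fastNormChannel computes s = scale_data[c] * inv_stdev once per (n, c) pair and then y[j] = s * (x[j] - mean) + b over the norm_size inner elements, and the CUDA kernel does the same with s = scale[c] * stdev[outer_idx], where the stdev buffer at that point holds the reciprocal 1/sqrt(var + epsilon) produced by compute_normalization_scale. This saves one multiplication per element relative to the removed y[j] = s * (x[j] - mean) * inv_stdev + b form.

For reference, below is a minimal sketch of driving the new layer through the public cv::dnn API. It mirrors the Layer_InstanceNorm perf test above; the layer name "instance_norm", the input names, the shapes, and the explicit epsilon are illustrative choices, not values mandated by the patch.

    #include <opencv2/dnn.hpp>

    using namespace cv;
    using namespace cv::dnn;

    int main() {
        // NCHW input; scale and bias must be 1-D tensors of length C,
        // as enforced by InstanceNormLayerImpl::getMemoryShapes.
        std::vector<int> x_shape{2, 64, 180, 240};
        Mat x(x_shape, CV_32F), scale(64, 1, CV_32F), b(64, 1, CV_32F);
        randu(x, 0.f, 1.f);
        randu(scale, 0.f, 1.f);
        randu(b, 0.f, 1.f);

        Net net;
        LayerParams lp;
        lp.type = "InstanceNormalization";  // factory name registered in init.cpp
        lp.name = "instance_norm";
        lp.set("epsilon", 1e-5f);           // matches the layer's default
        int id = net.addLayerToPrev(lp.name, lp.type, lp);
        net.connect(0, 0, id, 0);  // x
        net.connect(0, 1, id, 1);  // scale
        net.connect(0, 2, id, 2);  // bias

        std::vector<String> inpNames{"x", "scale", "b"};
        net.setInputsNames(inpNames);
        net.setInput(x, inpNames[0]);
        net.setInput(scale, inpNames[1]);
        net.setInput(b, inpNames[2]);

        // Per the ONNX InstanceNormalization definition:
        // y[n,c,...] = scale[c] * (x[n,c,...] - mean[n,c]) / sqrt(var[n,c] + epsilon) + b[c],
        // with mean and var computed over the spatial dims of each (n, c) slice.
        Mat y = net.forward();
        return 0;
    }

With this parser in place, ONNX models containing InstanceNormalization import directly instead of being rewritten into MVN + BatchNorm at import time, which is why the two test_instancenorm_* cases leave the conformance parser denylist.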