diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 41fe0df70f..3301f20fde 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1183,6 +1183,11 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<AttentionLayer> create(const LayerParams &params);
     };

+    class CV_EXPORTS GroupNormLayer : public Layer {
+    public:
+        static Ptr<GroupNormLayer> create(const LayerParams &params);
+    };
+
 //! @}
 //! @}
 CV__DNN_INLINE_NS_END
diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp
index 66b5ad62c2..946e29ccb4 100644
--- a/modules/dnn/perf/perf_layer.cpp
+++ b/modules/dnn/perf/perf_layer.cpp
@@ -795,6 +795,66 @@ PERF_TEST_P_(Layer_Attention, VisionTransformer) {
     test_layer({1, 197, 768}, {768, 768, 768}, 12);
 }

+struct Layer_GroupNorm : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& x_shape, int num_groups)
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        Mat x(x_shape, CV_32FC1);
+        Mat scale(x_shape[1], 1, CV_32FC1);
+        Mat b(x_shape[1], 1, CV_32FC1);
+
+        randu(x, 0.f, 1.f);
+        randu(scale, 0.f, 1.f);
+        randu(b, 0.f, 1.f);
+
+        Net net;
+        LayerParams lp;
+        lp.type = "GroupNormalization";
+        lp.name = "testLayer";
+        lp.set("num_groups", num_groups);
+
+        int id = net.addLayerToPrev(lp.name, lp.type, lp);
+        net.connect(0, 0, id, 0);
+        net.connect(0, 1, id, 1);
+        net.connect(0, 2, id, 2);
+
+        // warmup
+        {
+            std::vector<std::string> inpNames{"x", "scale", "b"};
+            net.setInputsNames(inpNames);
+            net.setInput(x, inpNames[0]);
+            net.setInput(scale, inpNames[1]);
+            net.setInput(b, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 2;
+    int C = 64;
+    int H = 180;
+    int W = 240;
+    int num_groups = 16;
+};
+
+PERF_TEST_P_(Layer_GroupNorm, GroupNorm)
+{
+    test_layer({N, C, H, W}, num_groups);
+}
+
+
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 #ifdef HAVE_CUDA
@@ -807,7 +867,7 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make
 INSTANTIATE_TEST_CASE_P(/**/, Layer_GatherElements, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_InstanceNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Attention, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
-
+INSTANTIATE_TEST_CASE_P(/**/, Layer_GroupNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));

 typedef TestBaseWithParam<tuple<Vec3i, int, bool, tuple<Backend, Target> > > Layer_FullyConnected;
 PERF_TEST_P_(Layer_FullyConnected, fc)
diff --git a/modules/dnn/src/cuda/mvn.cu b/modules/dnn/src/cuda/mvn.cu
index 0accc499a2..d6db7c4fb4 100644
--- a/modules/dnn/src/cuda/mvn.cu
+++ b/modules/dnn/src/cuda/mvn.cu
@@ -78,6 +78,18 @@ namespace raw {
         }
     }

+    template <class T>
+    __global__ void normalize_mean_variance_groupwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, size_type inner_size, size_type C, size_type num_groups, size_type group_size) {
+        for (auto idx : grid_stride_range(output.size())) {
+            const index_type outer_idx = idx / inner_size;
+            const index_type c = outer_idx % C;
+            const index_type group_idx = outer_idx / group_size;
+            auto s = static_cast<float>(scale[c]) * inv_stddev[group_idx];
+            auto b = static_cast<float>(bias[c]);
+            output[idx] = (static_cast<float>(input[idx]) - means[group_idx]) * s + b;
+        }
+    }
+
     template <class T>
     __global__ void normalize_mean_variance_layernorm(Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, size_type inner_size) {
         for (auto idx : grid_stride_range(output.size())) {
@@ -191,6 +203,24 @@ template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*
 #endif
 template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t);

+template <class T>
+void normalize_mean_variance_groupwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, std::size_t inner_size, std::size_t C, std::size_t num_groups, std::size_t group_size)
+{
+    CV_Assert(input.size() == output.size());
+    CV_Assert(input.size() / inner_size == means.size() * group_size);
+    CV_Assert(means.size() == inv_stddev.size());
+
+    auto kernel = raw::normalize_mean_variance_groupwise<T>;
+    auto policy = make_policy(kernel, output.size(), 0, stream);
+    launch_kernel(kernel, policy, output, input, scale, bias, means, inv_stddev, inner_size, C, num_groups, group_size);
+}
+
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void normalize_mean_variance_groupwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t, std::size_t, std::size_t);
+#endif
+template void normalize_mean_variance_groupwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t, std::size_t, std::size_t);
+
+
 template <class T>
 void normalize_mean_variance_layernorm(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, std::size_t inner_size)
 {
diff --git a/modules/dnn/src/cuda4dnn/kernels/mvn.hpp b/modules/dnn/src/cuda4dnn/kernels/mvn.hpp
index 6cddeb22bb..a09dafb76d 100644
--- a/modules/dnn/src/cuda4dnn/kernels/mvn.hpp
+++ b/modules/dnn/src/cuda4dnn/kernels/mvn.hpp
@@ -35,6 +35,10 @@ void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> o
 template <class T>
 void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size);

+template <class T>
+void normalize_mean_variance_groupwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size, std::size_t C, std::size_t num_groups, std::size_t group_size);
+
+
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */

 #endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP */
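Reviewer note (not part of the patch): the index math in `normalize_mean_variance_groupwise` may be easier to check in plain host code. The sketch below is an illustrative scalar reference under the same NCHW assumptions (inner_size = H*W, group_size = C / num_groups, one mean/inv_stddev value per (sample, group)); all names are mine, not OpenCV API.

#include <cstddef>
#include <vector>

// Scalar reference of the groupwise kernel: recover the channel and the (n, group)
// pair from the flat element index, then apply per-group stats and per-channel affine.
static void group_norm_apply_reference(const std::vector<float>& x, std::vector<float>& y,
                                        const std::vector<float>& scale, const std::vector<float>& bias,
                                        const std::vector<float>& means, const std::vector<float>& inv_stddev,
                                        std::size_t inner_size, std::size_t C, std::size_t group_size)
{
    for (std::size_t idx = 0; idx < x.size(); idx++) {
        std::size_t outer_idx = idx / inner_size;        // n * C + c
        std::size_t c = outer_idx % C;                   // channel index
        std::size_t group_idx = outer_idx / group_size;  // n * num_groups + g, since group_size divides C
        float s = scale[c] * inv_stddev[group_idx];
        y[idx] = (x[idx] - means[group_idx]) * s + bias[c];
    }
}

This also makes the workspace assertion above concrete: means/inv_stddev hold one value per (sample, group), hence input.size() / inner_size == means.size() * group_size.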
diff --git a/modules/dnn/src/cuda4dnn/primitives/group_norm.hpp b/modules/dnn/src/cuda4dnn/primitives/group_norm.hpp
new file mode 100644
index 0000000000..bb3e162a33
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/primitives/group_norm.hpp
@@ -0,0 +1,87 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_GROUP_NORM_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_GROUP_NORM_HPP
+
+#include "../../op_cuda.hpp"
+
+#include "../csl/stream.hpp"
+#include "../csl/span.hpp"
+#include "../csl/tensor.hpp"
+#include "../csl/workspace.hpp"
+
+#include "../kernels/fill_copy.hpp"
+#include "../kernels/mvn.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+    template <class T>
+    class GroupNormOp final : public CUDABackendNode {
+    public:
+        using wrapper_type = GetCUDABackendWrapperType<T>;
+
+        GroupNormOp(csl::Stream stream_, float epsilon_, size_t loops, size_t num_groups)
+            : stream(std::move(stream_)), epsilon(epsilon_), num_groups(num_groups) {
+            csl::WorkspaceBuilder builder;
+            builder.require<float>(loops * num_groups); // mean and stdev for each group
+            builder.require<float>(loops * num_groups);
+            scratch_mem_in_bytes = builder.required_workspace_size();
+        }
+
+        void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+                     const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+                     csl::Workspace& workspace) override {
+            auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
+            auto scale_wrapper = inputs[1].dynamicCast<wrapper_type>();
+            auto bias_wrapper = inputs[2].dynamicCast<wrapper_type>();
+
+            auto input = input_wrapper->getView();
+            auto scale = scale_wrapper->getView();
+            auto bias = bias_wrapper->getView();
+
+            auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
+            auto output = output_wrapper->getSpan();
+
+            auto C = input.get_axis_size(1);
+            auto loops = input.size_range(0, 2);
+            auto norm_size = input.size_range(2, input.rank());
+            auto num_groups = this->num_groups;
+            auto group_size = C / num_groups;
+            if (norm_size == 1) {
+                kernels::fill<T>(stream, output, 0.f);
+                return;
+            } else {
+                auto ws_allocator = csl::WorkspaceAllocator(workspace);
+
+                auto mean = ws_allocator.get_span<float>(loops / group_size);
+                kernels::fill<float>(stream, mean, 0.f);
+
+                auto stdev = ws_allocator.get_span<float>(loops / group_size);
+                kernels::fill<float>(stream, stdev, 0.f);
+
+                kernels::reduce_mean_sqr_sum<T>(stream, mean, stdev, input, norm_size * group_size);
+                kernels::compute_normalization_scale(stream, stdev, mean, stdev, norm_size * group_size, epsilon);
+                kernels::normalize_mean_variance_groupwise<T>(stream, output, input, scale, bias, mean, stdev, norm_size, C, num_groups, group_size);
+            }
+        }
+
+        std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }
+
+    private:
+        csl::Stream stream;
+        float epsilon;
+        std::size_t num_groups;
+        std::size_t scratch_mem_in_bytes;
+    };
+
+}}} // cv::dnn::cuda4dnn
+
+#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_GROUP_NORM_HPP
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index 9b433dac50..2170aafc4b 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -163,6 +163,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);
     CV_DNN_REGISTER_LAYER_CLASS(InstanceNormalization, InstanceNormLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(GroupNormalization, GroupNormLayer);

     CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
diff --git a/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp b/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp
index ab9d8ee0af..35f354ed29 100644
--- a/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp
+++ b/modules/dnn/src/layers/cpu_kernels/fast_norm.cpp
@@ -158,4 +158,51 @@ void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &o
     parallel_for_(Range(0, loops), fn, nstripes);
 }

+void fastNormGroup(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon, size_t num_groups) {
+    const auto input_shape = shape(input);
+    size_t N = input_shape[0], C = input_shape[1];
+    CV_CheckEQ(scale.total(), bias.total(), "fastNormGroup: scale and bias should have the same shape");
+    CV_CheckEQ(scale.total(), C, "fastNormGroup: scale should be a 1d tensor and match the channel of input");
+    CV_CheckGE(input.dims, 3, "fastNormGroup: input dimension >= 3");
+
+    size_t channels_per_group = C / num_groups;
+    size_t loops = N * num_groups;
+    size_t norm_size = static_cast<size_t>(total(input_shape, 2) * channels_per_group);
+    size_t step = norm_size / channels_per_group;
+    float inv_norm_size = 1.0 / norm_size;
+
+    auto fn = [&](const Range &r) {
+        const auto *input_data = input.ptr<const float>();
+        const auto *scale_data = scale.ptr<const float>();
+        const auto *bias_data = bias.ptr<const float>();
+        auto *output_data = output.ptr<float>();
+
+        for (int i = r.start; i < r.end; i++) {
+            const auto *x = input_data + norm_size * i;
+            auto *y = output_data + norm_size * i;
+
+            float mean = 0.f, mean_square = 0.f;
+            for (int j = 0; j < norm_size; j++) {
+                float v = x[j];
+                mean += v;
+                mean_square += v * v;
+            }
+
+            mean *= inv_norm_size;
+            mean_square = std::sqrt(std::max(0.f, mean_square * inv_norm_size - mean * mean) + epsilon);
+            float inv_stdev = 1.f / mean_square;
+
+            size_t group_idx = i % num_groups * channels_per_group;
+            for (size_t j = 0; j < norm_size; j++) {
+                size_t c = group_idx + (j / step);
+                float s = scale_data[c] * inv_stdev, b = bias_data[c];
+                y[j] = s * (x[j] - mean) + b;
+            }
+        }
+    };
+
+    double nstripes = loops * norm_size * (1 / 1024.0);
+    parallel_for_(Range(0, loops), fn, nstripes);
+}
+
 }} // cv::dnn
diff --git a/modules/dnn/src/layers/cpu_kernels/fast_norm.hpp b/modules/dnn/src/layers/cpu_kernels/fast_norm.hpp
index 61316542d3..72cbdad0a7 100644
--- a/modules/dnn/src/layers/cpu_kernels/fast_norm.hpp
+++ b/modules/dnn/src/layers/cpu_kernels/fast_norm.hpp
@@ -21,6 +21,9 @@ void fastNorm(const Mat &input, const Mat &scale, const Mat &bias, Mat &output,
 // Channel-wise Normalization speedup by multi-threading. Scale and bias should have the same shape (C). Input should have dimension >= 3.
 void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon);

+// Group-wise Normalization speedup by multi-threading. Scale and bias should have the same shape (C). Input should have dimension >= 3.
+void fastNormGroup(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon, size_t num_groups);
+
 }} // cv::dnn

 #endif // OPENCV_DNN_FAST_NORM_HPP
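Reviewer note (not part of the patch): to cross-check the reduction extents in fastNormGroup, here is a naive scalar version of the same computation. Statistics are taken over (C / num_groups) * H * W contiguous elements per (sample, group), while scale and bias remain per-channel. Function and parameter names are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstddef>

// x, y: N*C*HW floats in NCHW order; scale, bias: C floats.
static void naive_group_norm(const float* x, float* y, const float* scale, const float* bias,
                             std::size_t N, std::size_t C, std::size_t HW,
                             std::size_t num_groups, float epsilon)
{
    const std::size_t cpg = C / num_groups;     // channels per group
    const std::size_t norm_size = cpg * HW;     // reduction size per (n, g), as in fastNormGroup
    for (std::size_t n = 0; n < N; n++) {
        for (std::size_t g = 0; g < num_groups; g++) {
            const float* xg = x + (n * C + g * cpg) * HW;
            float* yg = y + (n * C + g * cpg) * HW;
            float mean = 0.f, sqmean = 0.f;
            for (std::size_t j = 0; j < norm_size; j++) {
                mean += xg[j];
                sqmean += xg[j] * xg[j];
            }
            mean /= norm_size;
            sqmean /= norm_size;
            const float inv_stdev = 1.f / std::sqrt(std::max(0.f, sqmean - mean * mean) + epsilon);
            for (std::size_t j = 0; j < norm_size; j++) {
                const std::size_t c = g * cpg + j / HW;  // channel of this element (j / step in the kernel)
                yg[j] = scale[c] * inv_stdev * (xg[j] - mean) + bias[c];
            }
        }
    }
}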
diff --git a/modules/dnn/src/layers/group_norm_layer.cpp b/modules/dnn/src/layers/group_norm_layer.cpp
new file mode 100644
index 0000000000..006e8fe7f8
--- /dev/null
+++ b/modules/dnn/src/layers/group_norm_layer.cpp
@@ -0,0 +1,190 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include <opencv2/dnn/shape_utils.hpp>
+#include "./cpu_kernels/fast_norm.hpp"
+
+// CUDA backend
+#include "../op_cuda.hpp"
+#ifdef HAVE_CUDA
+#include "../cuda4dnn/primitives/group_norm.hpp"
+using namespace cv::dnn::cuda4dnn;
+#endif
+
+// OpenCL backend
+#ifdef HAVE_OPENCL
+#include "../ocl4dnn/include/math_functions.hpp"
+#include "opencl_kernels_dnn.hpp"
+#endif
+
+namespace cv {
+namespace dnn {
+
+// https://github.com/onnx/onnx/blob/main/docs/Operators.md#GroupNormalization
+class GroupNormLayerImpl CV_FINAL : public GroupNormLayer {
+public:
+    GroupNormLayerImpl(const LayerParams &params) {
+        setParamsFrom(params);
+
+        epsilon = params.get<float>("epsilon", 1e-5);
+        num_groups = params.get<int>("num_groups");
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE {
+        return backendId == DNN_BACKEND_OPENCV ||
+               backendId == DNN_BACKEND_CUDA;
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE {
+        const auto &input = inputs[0];
+        const auto &scale = inputs[1];
+        const auto &bias = inputs[2];
+        CV_CheckGE(input.size(), static_cast<size_t>(3), "DNN/GroupNorm: input dimension >= 3 is required");
+
+        int C = input[1];
+        int scale_dim = std::accumulate(scale.begin(), scale.end(), 1, std::multiplies<int>());
+        CV_CheckEQ(scale_dim, C, "DNN/GroupNorm: scale must be a 1d tensor and match the channel of input");
+        int bias_dim = std::accumulate(bias.begin(), bias.end(), 1, std::multiplies<int>());
+        CV_CheckEQ(bias_dim, C, "DNN/GroupNorm: bias must be a 1d tensor and match the channel of input");
+
+        outputs.assign(1, inputs[0]);
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
+                   forward_ocl(inputs_arr, outputs_arr, internals_arr))
+
+        if (inputs_arr.depth() == CV_16S) {
+            forward_fallback(inputs_arr, outputs_arr, internals_arr);
+            return;
+        }
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const auto& input = inputs[0];
+        const auto& scale = inputs[1];
+        const auto& bias = inputs[2];
+
+        fastNormGroup(input, scale, bias, outputs[0], epsilon, num_groups);
+    }
+
+#ifdef HAVE_OPENCL
+    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_) {
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inputs_.getUMatVector(inputs);
+        outputs_.getUMatVector(outputs);
+
+        const auto &input = inputs[0], &scale = inputs[1], &bias = inputs[2];
+        auto &output = outputs[0];
+
+        const auto input_shape = shape(input);
+        size_t N = input_shape[0], C = input_shape[1];
+        size_t num_groups = this->num_groups;
+        size_t channels_per_group = C / num_groups;
+        size_t loops = N * num_groups, norm_size = static_cast<size_t>(total(input_shape, 2)) * channels_per_group;
+        float inv_norm_size = 1.f / norm_size;
+
+        // no fp16 support
+        if (input.depth() == CV_16S) {
+            return false;
+        }
+
+        String base_opts = format(" -DT=float -DT4=float4 -Dconvert_T=convert_float4");
+
+        // Calculate mean
+        UMat one = UMat::ones(norm_size, 1, CV_32F);
+        UMat mean = UMat(loops, 1, CV_32F);
+        UMat mean_square = UMat(loops, 1, CV_32F);
+        UMat tmp = UMat(loops, norm_size, CV_32F);
+        bool ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
+                                               input, 0, one, 0, 0.f, mean, 0);
+        if (!ret) {
+            return false;
+        }
+        // Calculate mean_square
+        int num_vector = (norm_size % 8 == 0) ? 8 : ((norm_size % 4 == 0) ? 4 : 1);
+        size_t global[] = {loops, static_cast<size_t>(norm_size / num_vector)};
+        String build_opt = format(" -DNUM=%d", num_vector) + base_opts;
+        String mean_square_kernel_name = format("calc_mean%d", num_vector);
+        ocl::Kernel mean_square_kernel(mean_square_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt + " -DKERNEL_MEAN");
+        if (mean_square_kernel.empty()) {
+            return false;
+        }
+        mean_square_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
+        mean_square_kernel.set(1, (int)loops);
+        mean_square_kernel.set(2, (int)norm_size);
+        mean_square_kernel.set(3, ocl::KernelArg::PtrReadOnly(mean));
+        mean_square_kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmp));
+        ret = mean_square_kernel.run(2, global, NULL, false);
+        if (!ret) {
+            return false;
+        }
+        ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
+                                          tmp, 0, one, 0, 0.f, mean_square, 0);
+        if (!ret) {
+            return false;
+        }
+        // Calculate group norm: output = scale * (x - mean) / sqrt(var + eps) + bias
+        String mvn_group_kernel_name = format("mvn_group%d", num_vector);
+        build_opt += " -DNORM_VARIANCE -DKERNEL_MVN_GROUP";
+        ocl::Kernel mvn_group_kernel(mvn_group_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt);
+        if (mvn_group_kernel.empty()) {
+            return false;
+        }
+        mvn_group_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
+        mvn_group_kernel.set(1, (int)loops);
+        mvn_group_kernel.set(2, (int)norm_size);
+        mvn_group_kernel.set(3, (float)epsilon);
+        mvn_group_kernel.set(4, ocl::KernelArg::PtrReadOnly(mean));
+        mvn_group_kernel.set(5, ocl::KernelArg::PtrReadOnly(mean_square));
+        mvn_group_kernel.set(6, ocl::KernelArg::PtrReadOnly(scale));
+        mvn_group_kernel.set(7, ocl::KernelArg::PtrReadOnly(bias));
+        mvn_group_kernel.set(8, (int)C);
+        mvn_group_kernel.set(9, (int)num_groups);
+        mvn_group_kernel.set(10, (float)0.f);
+        mvn_group_kernel.set(11, ocl::KernelArg::PtrWriteOnly(output));
+        ret = mvn_group_kernel.run(2, global, NULL, false);
+        if (!ret) {
+            return false;
+        }
+
+        return true;
+    }
+#endif
+
+#ifdef HAVE_CUDA
+    Ptr<BackendNode> initCUDA(void *context_,
+                              const std::vector<Ptr<BackendWrapper>>& inputs,
+                              const std::vector<Ptr<BackendWrapper>>& outputs) override {
+        auto context = reinterpret_cast<csl::CSLContext*>(context_);
+
+        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+        auto input_shape = input_wrapper->getShape();
+        size_t N = input_shape[0];
+        size_t num_groups = this->num_groups;
+        size_t loops = N * num_groups;
+
+        return make_cuda_node<cuda4dnn::GroupNormOp>(preferableTarget, std::move(context->stream), epsilon, loops, num_groups);
+    }
+#endif // HAVE_CUDA
+
+private:
+    float epsilon;
+    size_t num_groups;
+};
+
+Ptr<GroupNormLayer> GroupNormLayer::create(const LayerParams &params) {
+    return Ptr<GroupNormLayer>(new GroupNormLayerImpl(params));
+}
+
+}} // cv::dnn
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index f0b33d111b..a6acc6e800 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -4008,6 +4008,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter;
     dispatch["Tile"] = &ONNXImporter::parseTile;
     dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
+    dispatch["GroupNormalization"] = &ONNXImporter::parseInstanceNormalization;
     dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
             dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =
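Reviewer note (not part of the patch): the OpenCL path in group_norm_layer.cpp derives per-row statistics without a dedicated reduction kernel. One GEMV against a ones vector averages each row into mean, the existing calc_mean kernel (KERNEL_MEAN in mvn.cl) then writes the squared deviation from that mean into tmp, and a second GEMV averages tmp into the variance consumed by MVN_GROUP below. A rough scalar equivalent, with illustrative names and assuming rows = N * num_groups and cols = norm_size:

#include <cstddef>
#include <vector>

static void gemv_style_stats(const std::vector<float>& input, std::size_t rows, std::size_t cols,
                             std::vector<float>& mean, std::vector<float>& variance)
{
    mean.assign(rows, 0.f);
    variance.assign(rows, 0.f);
    for (std::size_t r = 0; r < rows; r++) {
        for (std::size_t j = 0; j < cols; j++)
            mean[r] += input[r * cols + j];
        mean[r] /= cols;                           // GEMV #1: alpha = 1/norm_size, x = ones
        for (std::size_t j = 0; j < cols; j++) {
            const float d = input[r * cols + j] - mean[r];
            variance[r] += d * d;                  // KERNEL_MEAN writes d*d into tmp
        }
        variance[r] /= cols;                       // GEMV #2 averages tmp into the variance
    }
}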
diff --git a/modules/dnn/src/opencl/mvn.cl b/modules/dnn/src/opencl/mvn.cl
index 7353ed8b82..053749b483 100644
--- a/modules/dnn/src/opencl/mvn.cl
+++ b/modules/dnn/src/opencl/mvn.cl
@@ -54,6 +54,7 @@
     #define vec_type Dtype8
     #define CALC_MEAN calc_mean8
     #define MVN mvn8
+    #define MVN_GROUP mvn_group8
     #define MEAN_FUSE mean_fuse8
     #define MVN_FUSE mvn_fuse8
 #elif NUM == 4
@@ -62,6 +63,7 @@
     #define vec_type Dtype4
     #define CALC_MEAN calc_mean4
     #define MVN mvn4
+    #define MVN_GROUP mvn_group4
     #define MEAN_FUSE mean_fuse4
     #define MVN_FUSE mvn_fuse4
 #elif NUM == 1
@@ -70,6 +72,7 @@
     #define vec_type Dtype
     #define CALC_MEAN calc_mean1
     #define MVN mvn1
+    #define MVN_GROUP mvn_group1
     #define MEAN_FUSE mean_fuse1
     #define MVN_FUSE mvn_fuse1
 #endif
@@ -150,6 +153,54 @@ __kernel void MVN(__global const Dtype* src,
     store(dst_vec, dst, index);
 }

+#elif defined KERNEL_MVN_GROUP
+
+__kernel void MVN_GROUP(__global const Dtype* src,
+                        const int rows,
+                        const int cols,
+                        const Dtype eps,
+                        __global const Dtype* mean,
+                        __global const Dtype* dev,
+                        __global const Dtype* weight,
+                        __global const Dtype* bias,
+                        const int channels,
+                        const int num_groups,
+                        const float relu_slope,
+                        __global Dtype* dst)
+{
+    int x = get_global_id(0);
+    int y = get_global_id(1) * NUM;
+    int index = x * cols + y;
+
+    if (x >= rows || y >= cols)
+        return;
+
+    int group_size = channels / num_groups;
+    int step = cols / group_size;
+    int channel_index = x % num_groups * group_size + y / step;
+    Dtype mean_val = mean[x];
+    Dtype dev_val = dev[x];
+    Dtype alpha;
+#ifdef NORM_VARIANCE
+    alpha = 1 / sqrt(eps + dev_val);
+#else
+    alpha = 1;
+#endif
+
+    Dtype w = weight[channel_index], b = bias[channel_index];
+
+    vec_type src_vec = load(src, index) - (vec_type)mean_val;
+    vec_type dst_vec = src_vec * alpha;
+    dst_vec = dst_vec * w + (vec_type)b;
+
+#ifdef FUSE_RELU
+    vec_type new_val = dst_vec * relu_slope;
+    dst_vec = select(new_val, dst_vec, dst_vec > (vec_type)0.f);
+#endif
+
+    store(dst_vec, dst, index);
+}
+
 #elif defined KERNEL_MEAN_FUSE

 __kernel void MEAN_FUSE(__global const T * A,
diff --git a/modules/dnn/test/test_onnx_conformance.cpp b/modules/dnn/test/test_onnx_conformance.cpp
index 5b783722c4..1ca3f2f75b 100644
--- a/modules/dnn/test/test_onnx_conformance.cpp
+++ b/modules/dnn/test/test_onnx_conformance.cpp
@@ -311,6 +311,8 @@ static const TestCase testConformanceConfig[] = {
     {"test_gridsample_nearest", 2, 1},
     {"test_gridsample_reflection_padding", 2, 1},
     {"test_gridsample_zeros_padding", 2, 1},
+    {"test_group_normalization_epsilon", 3, 1},
+    {"test_group_normalization_example", 3, 1},
     {"test_gru_batchwise", 3, 2},
     {"test_gru_defaults", 3, 1},
     {"test_gru_seq_length", 4, 1},
diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp
index 199bfdcd18..291ea30e92 100644
--- a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp
+++ b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp
@@ -736,6 +736,10 @@ CASE(test_gridsample_reflection_padding)
     // no filter
 CASE(test_gridsample_zeros_padding)
     // no filter
+CASE(test_group_normalization_epsilon)
+    // no filter
+CASE(test_group_normalization_example)
+    // no filter
 CASE(test_gru_batchwise)
     // no filter
 CASE(test_gru_defaults)
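Reviewer note (not part of the patch): a minimal usage sketch of the new layer from user code, mirroring the perf test above. The layer takes three inputs (data, per-channel scale, per-channel bias); shapes and values are the perf-test defaults, everything else is standard cv::dnn API.

#include <opencv2/dnn.hpp>

int main()
{
    using namespace cv;
    using namespace cv::dnn;

    // NCHW input with C = 64 channels split into 16 groups.
    Mat x({2, 64, 180, 240}, CV_32FC1), scale(64, 1, CV_32FC1), bias(64, 1, CV_32FC1);
    randu(x, 0.f, 1.f); randu(scale, 0.f, 1.f); randu(bias, 0.f, 1.f);

    Net net;
    LayerParams lp;
    lp.type = "GroupNormalization";
    lp.name = "gn";
    lp.set("num_groups", 16);
    int id = net.addLayerToPrev(lp.name, lp.type, lp);
    net.connect(0, 0, id, 0);
    net.connect(0, 1, id, 1);
    net.connect(0, 2, id, 2);

    std::vector<String> names{"x", "scale", "bias"};
    net.setInputsNames(names);
    net.setInput(x, names[0]);
    net.setInput(scale, names[1]);
    net.setInput(bias, names[2]);
    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);

    Mat out = net.forward();  // same shape as x
    return 0;
}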