mirror of https://github.com/opencv/opencv.git
Merge pull request #24610 from jimmylaw21:dnn-onnx-add-group-norm-layer
dnn onnx: add group norm layer #24610

Todo:

- [x] speed up by multi-threading
- [x] add perf
- [x] add backend: OpenVINO
- [x] add backend: CUDA
- [x] add backend: OpenCL (no fp16)
- [ ] add backend: CANN

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake

Co-authored-by: fengyuentau <yuantao.feng@opencv.org.cn>
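For reference, a minimal C++ sketch of exercising the new layer directly through the `cv::dnn` API, mirroring the perf test added by this PR; the shapes, layer name and `num_groups` value below are illustrative, not part of the patch:

```cpp
#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <vector>

int main() {
    using namespace cv;
    using namespace cv::dnn;

    // NCHW input with per-channel scale and bias (illustrative shapes).
    std::vector<int> xshape = {1, 4, 8, 8};
    Mat x(xshape, CV_32FC1), scale(4, 1, CV_32FC1), bias(4, 1, CV_32FC1);
    randu(x, 0.f, 1.f); randu(scale, 0.f, 1.f); randu(bias, 0.f, 1.f);

    // Build a one-layer net using the GroupNormalization layer added by this PR.
    Net net;
    LayerParams lp;
    lp.type = "GroupNormalization";
    lp.name = "gn";
    lp.set("num_groups", 2);
    int id = net.addLayerToPrev(lp.name, lp.type, lp);
    net.connect(0, 0, id, 0);   // x
    net.connect(0, 1, id, 1);   // scale
    net.connect(0, 2, id, 2);   // bias

    std::vector<String> names{"x", "scale", "b"};
    net.setInputsNames(names);
    net.setInput(x, names[0]);
    net.setInput(scale, names[1]);
    net.setInput(bias, names[2]);

    Mat y = net.forward();  // normalized output, same shape as x
    return 0;
}
```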
This commit is contained in:
parent 97c418ab86
commit a7fa1e6f4b
@@ -1183,6 +1183,11 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<AttentionLayer> create(const LayerParams &params);
     };
 
+    class CV_EXPORTS GroupNormLayer : public Layer {
+    public:
+        static Ptr<GroupNormLayer> create(const LayerParams &params);
+    };
+
 //! @}
 //! @}
 CV__DNN_INLINE_NS_END
@@ -795,6 +795,66 @@ PERF_TEST_P_(Layer_Attention, VisionTransformer) {
     test_layer({1, 197, 768}, {768, 768, 768}, 12);
 }
 
+struct Layer_GroupNorm : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& x_shape, int num_groups)
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        Mat x(x_shape, CV_32FC1);
+        Mat scale(x_shape[1], 1, CV_32FC1);
+        Mat b(x_shape[1], 1, CV_32FC1);
+
+        randu(x, 0.f, 1.f);
+        randu(scale, 0.f, 1.f);
+        randu(b, 0.f, 1.f);
+
+        Net net;
+        LayerParams lp;
+        lp.type = "GroupNormalization";
+        lp.name = "testLayer";
+        lp.set("num_groups", num_groups);
+
+        int id = net.addLayerToPrev(lp.name, lp.type, lp);
+        net.connect(0, 0, id, 0);
+        net.connect(0, 1, id, 1);
+        net.connect(0, 2, id, 2);
+
+        // warmup
+        {
+            std::vector<String> inpNames{"x", "scale", "b"};
+            net.setInputsNames(inpNames);
+            net.setInput(x, inpNames[0]);
+            net.setInput(scale, inpNames[1]);
+            net.setInput(b, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 2;
+    int C = 64;
+    int H = 180;
+    int W = 240;
+    int num_groups = 16;
+};
+
+PERF_TEST_P_(Layer_GroupNorm, GroupNorm)
+{
+    test_layer({N, C, H, W}, num_groups);
+}
+
+
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 #ifdef HAVE_CUDA
@@ -807,7 +867,7 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make
 INSTANTIATE_TEST_CASE_P(/**/, Layer_GatherElements, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_InstanceNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Attention, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
+INSTANTIATE_TEST_CASE_P(/**/, Layer_GroupNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 
 typedef TestBaseWithParam<tuple<Vec4i, int, bool, tuple<Backend, Target> > > Layer_FullyConnected;
 PERF_TEST_P_(Layer_FullyConnected, fc)
@@ -78,6 +78,18 @@ namespace raw {
     }
 }
 
+template <class T>
+__global__ void normalize_mean_variance_groupwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, size_type inner_size, size_type C, size_type num_groups, size_type group_size) {
+    for (auto idx : grid_stride_range(output.size())) {
+        const index_type outer_idx = idx / inner_size;
+        const index_type c = outer_idx % C;
+        const index_type group_idx = outer_idx / group_size;
+        auto s = static_cast<float>(scale[c]) * inv_stddev[group_idx];
+        auto b = static_cast<float>(bias[c]);
+        output[idx] = (static_cast<float>(input[idx]) - means[group_idx]) * s + b;
+    }
+}
+
 template <class T>
 __global__ void normalize_mean_variance_layernorm(Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, size_type inner_size) {
     for (auto idx : grid_stride_range(output.size())) {
@@ -191,6 +203,24 @@ template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*
 #endif
 template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t);
 
+template <class T>
+void normalize_mean_variance_groupwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, std::size_t inner_size, std::size_t C, std::size_t num_groups, std::size_t group_size)
+{
+    CV_Assert(input.size() == output.size());
+    CV_Assert(input.size() / inner_size == means.size() * group_size);
+    CV_Assert(means.size() == inv_stddev.size());
+
+    auto kernel = raw::normalize_mean_variance_groupwise<T>;
+    auto policy = make_policy(kernel, output.size(), 0, stream);
+    launch_kernel(kernel, policy, output, input, scale, bias, means, inv_stddev, inner_size, C, num_groups, group_size);
+}
+
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void normalize_mean_variance_groupwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t, std::size_t, std::size_t);
+#endif
+template void normalize_mean_variance_groupwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t, std::size_t, std::size_t);
+
+
 template <class T>
 void normalize_mean_variance_layernorm(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, std::size_t inner_size)
 {
@@ -35,6 +35,10 @@ void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> o
 template <class T>
 void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size);
 
+template <class T>
+void normalize_mean_variance_groupwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size, std::size_t C, std::size_t num_groups, std::size_t group_size);
+
+
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
 
 #endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP */
modules/dnn/src/cuda4dnn/primitives/group_norm.hpp (new file, 87 lines)
@@ -0,0 +1,87 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_GROUP_NORM_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_GROUP_NORM_HPP

#include "../../op_cuda.hpp"

#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include "../csl/tensor.hpp"
#include "../csl/workspace.hpp"

#include "../kernels/fill_copy.hpp"
#include "../kernels/mvn.hpp"

#include <opencv2/core.hpp>

#include <cstddef>
#include <vector>
#include <utility>

namespace cv { namespace dnn { namespace cuda4dnn {

template <class T>
class GroupNormOp final : public CUDABackendNode {
public:
    using wrapper_type = GetCUDABackendWrapperType<T>;

    GroupNormOp(csl::Stream stream_, float epsilon_, size_t loops, size_t num_groups)
        : stream(std::move(stream_)), epsilon(epsilon_), num_groups(num_groups) {
        csl::WorkspaceBuilder builder;
        builder.require<float>(loops * num_groups); // mean and stdev for each group
        builder.require<float>(loops * num_groups);
        scratch_mem_in_bytes = builder.required_workspace_size();
    }

    void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
                 const std::vector<cv::Ptr<BackendWrapper>>& outputs,
                 csl::Workspace& workspace) override {
        auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
        auto scale_wrapper = inputs[1].dynamicCast<wrapper_type>();
        auto bias_wrapper = inputs[2].dynamicCast<wrapper_type>();

        auto input = input_wrapper->getView();
        auto scale = scale_wrapper->getView();
        auto bias = bias_wrapper->getView();

        auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
        auto output = output_wrapper->getSpan();

        auto C = input.get_axis_size(1);
        auto loops = input.size_range(0, 2);
        auto norm_size = input.size_range(2, input.rank());
        auto num_groups = this->num_groups;
        auto group_size = C / num_groups;
        if (norm_size == 1) {
            kernels::fill<T>(stream, output, 0.f);
            return;
        } else {
            auto ws_allocator = csl::WorkspaceAllocator(workspace);

            auto mean = ws_allocator.get_span<float>(loops / group_size);
            kernels::fill<float>(stream, mean, 0.f);

            auto stdev = ws_allocator.get_span<float>(loops / group_size);
            kernels::fill<float>(stream, stdev, 0.f);

            kernels::reduce_mean_sqr_sum<T>(stream, mean, stdev, input, norm_size * group_size);
            kernels::compute_normalization_scale(stream, stdev, mean, stdev, norm_size * group_size, epsilon);
            kernels::normalize_mean_variance_groupwise<T>(stream, output, input, scale, bias, mean, stdev, norm_size, C, num_groups, group_size);
        }
    }

    std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }

private:
    csl::Stream stream;
    float epsilon;
    std::size_t num_groups;
    std::size_t scratch_mem_in_bytes;
};

}}} // cv::dnn::cuda4dnn

#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_GROUP_NORM_HPP
@@ -163,6 +163,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);
     CV_DNN_REGISTER_LAYER_CLASS(InstanceNormalization, InstanceNormLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(GroupNormalization, GroupNormLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
@@ -158,4 +158,51 @@ void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &o
     parallel_for_(Range(0, loops), fn, nstripes);
 }
 
+void fastNormGroup(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon, size_t num_groups) {
+    const auto input_shape = shape(input);
+    size_t N = input_shape[0], C = input_shape[1];
+    CV_CheckEQ(scale.total(), bias.total(), "fastNormGroup: scale and bias should have the same shape");
+    CV_CheckEQ(scale.total(), C, "fastNormGroup: scale should be a 1d tensor and match the channel of input");
+    CV_CheckGE(input.dims, 3, "fastNormGroup: input dimension >= 3");
+
+    size_t channels_per_group = C / num_groups;
+    size_t loops = N * num_groups;
+    size_t norm_size = static_cast<size_t>(total(input_shape, 2) * channels_per_group);
+    size_t step = norm_size / channels_per_group;
+    float inv_norm_size = 1.0 / norm_size;
+
+    auto fn = [&](const Range &r) {
+        const auto *input_data = input.ptr<const float>();
+        const auto *scale_data = scale.ptr<const float>();
+        const auto *bias_data = bias.ptr<const float>();
+        auto *output_data = output.ptr<float>();
+
+        for (int i = r.start; i < r.end; i++) {
+            const auto *x = input_data + norm_size * i;
+            auto *y = output_data + norm_size * i;
+
+            float mean = 0.f, mean_square = 0.f;
+            for (int j = 0; j < norm_size; j++) {
+                float v = x[j];
+                mean += v;
+                mean_square += v * v;
+            }
+
+            mean *= inv_norm_size;
+            mean_square = std::sqrt(std::max(0.f, mean_square * inv_norm_size - mean * mean) + epsilon);
+            float inv_stdev = 1.f / mean_square;
+
+            size_t group_idx = i % num_groups * channels_per_group;
+            for (size_t j = 0; j < norm_size; j++) {
+                size_t c = group_idx + (j / step);
+                float s = scale_data[c] * inv_stdev, b = bias_data[c];
+                y[j] = s * (x[j] - mean) + b;
+            }
+        }
+    };
+
+    double nstripes = loops * norm_size * (1 / 1024.0);
+    parallel_for_(Range(0, loops), fn, nstripes);
+}
+
 }} // cv::dnn
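For reference, the quantity `fastNormGroup` computes over each of the `N * num_groups` slices is the standard group-normalization formula; the notation below is mine, not part of the patch. With the `C` channels split into `G = num_groups` groups and `S_g` the set of (channel, spatial) positions in group `g` (so `|S_g|` equals `norm_size` above):

```latex
\mu_{n,g}      = \frac{1}{|S_g|} \sum_{(c,u) \in S_g} x_{n,c,u}, \qquad
\sigma^2_{n,g} = \frac{1}{|S_g|} \sum_{(c,u) \in S_g} \bigl(x_{n,c,u} - \mu_{n,g}\bigr)^2, \qquad
y_{n,c,u}      = \mathrm{scale}_c \,\frac{x_{n,c,u} - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}} + \mathrm{bias}_c
```

where `g` is the group containing channel `c`. The loop above obtains the variance as E[x^2] - mu^2 (clamped at zero) and folds `scale_c` and the reciprocal standard deviation into the single multiplier `s`.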
@@ -21,6 +21,9 @@ void fastNorm(const Mat &input, const Mat &scale, const Mat &bias, Mat &output,
 // Channel-wise Normalization speedup by multi-threading. Scale and bias should have the same shape (C). Input should have dimension >= 3.
 void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon);
 
+// Group-wise Normalization speedup by multi-threading. Scale and bias should have the same shape (C). Input should have dimension >= 3.
+void fastNormGroup(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon, size_t num_groups);
+
 }} // cv::dnn
 
 #endif // OPENCV_DNN_FAST_NORM_HPP
modules/dnn/src/layers/group_norm_layer.cpp (new file, 190 lines)
@@ -0,0 +1,190 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include "./cpu_kernels/fast_norm.hpp"

// CUDA backend
#include "../op_cuda.hpp"
#ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/group_norm.hpp"
using namespace cv::dnn::cuda4dnn;
#endif

// OpenCL backend
#ifdef HAVE_OPENCL
#include "../ocl4dnn/include/math_functions.hpp"
#include "opencl_kernels_dnn.hpp"
#endif

namespace cv {
namespace dnn {

// https://github.com/onnx/onnx/blob/main/docs/Operators.md#GroupNormalization
class GroupNormLayerImpl CV_FINAL : public GroupNormLayer {
public:
    GroupNormLayerImpl(const LayerParams &params) {
        setParamsFrom(params);

        epsilon = params.get<float>("epsilon", 1e-5);
        num_groups = params.get<int>("num_groups");
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE {
        return backendId == DNN_BACKEND_OPENCV ||
               backendId == DNN_BACKEND_CUDA;
    }

    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE {
        const auto &input = inputs[0];
        const auto &scale = inputs[1];
        const auto &bias = inputs[2];
        CV_CheckGE(input.size(), static_cast<size_t>(3), "DNN/GroupNorm: input dimension >= 3 is required");

        int C = input[1];
        int scale_dim = std::accumulate(scale.begin(), scale.end(), 1, std::multiplies<int>());
        CV_CheckEQ(scale_dim, C, "DNN/GroupNorm: scale must be a 1d tensor and match the channel of input");
        int bias_dim = std::accumulate(bias.begin(), bias.end(), 1, std::multiplies<int>());
        CV_CheckEQ(bias_dim, C, "DNN/GroupNorm: bias must be a 1d tensor and match the channel of input");

        outputs.assign(1, inputs[0]);
        return false;
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16S) {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        const auto& input = inputs[0];
        const auto& scale = inputs[1];
        const auto& bias = inputs[2];

        fastNormGroup(input, scale, bias, outputs[0], epsilon, num_groups);
    }

#ifdef HAVE_OPENCL
    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_) {
        std::vector<UMat> inputs;
        std::vector<UMat> outputs;

        inputs_.getUMatVector(inputs);
        outputs_.getUMatVector(outputs);

        const auto &input = inputs[0], &scale = inputs[1], &bias = inputs[2];
        auto &output = outputs[0];

        const auto input_shape = shape(input);
        size_t N = input_shape[0], C = input_shape[1];
        size_t num_groups = this->num_groups;
        size_t channels_per_group = C / num_groups;
        size_t loops = N * num_groups, norm_size = static_cast<size_t>(total(input_shape, 2)) * channels_per_group;
        float inv_norm_size = 1.f / norm_size;

        // no fp16 support
        if (input.depth() == CV_16S) {
            return false;
        }

        String base_opts = format(" -DT=float -DT4=float4 -Dconvert_T=convert_float4");

        // Calculate mean
        UMat one = UMat::ones(norm_size, 1, CV_32F);
        UMat mean = UMat(loops, 1, CV_32F);
        UMat mean_square = UMat(loops, 1, CV_32F);
        UMat tmp = UMat(loops, norm_size, CV_32F);
        bool ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
                                               input, 0, one, 0, 0.f, mean, 0);
        if (!ret) {
            return false;
        }
        // Calculate mean_square
        int num_vector = (norm_size % 8 == 0) ? 8 : ((norm_size % 4 == 0) ? 4 : 1);
        size_t global[] = {loops, static_cast<size_t>(norm_size / num_vector)};
        String build_opt = format(" -DNUM=%d", num_vector) + base_opts;
        String mean_square_kernel_name = format("calc_mean%d", num_vector);
        ocl::Kernel mean_square_kernel(mean_square_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt + " -DKERNEL_MEAN");
        if (mean_square_kernel.empty()) {
            return false;
        }
        mean_square_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
        mean_square_kernel.set(1, (int)loops);
        mean_square_kernel.set(2, (int)norm_size);
        mean_square_kernel.set(3, ocl::KernelArg::PtrReadOnly(mean));
        mean_square_kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmp));
        ret = mean_square_kernel.run(2, global, NULL, false);
        if (!ret) {
            return false;
        }
        ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
                                          tmp, 0, one, 0, 0.f, mean_square, 0);
        if (!ret) {
            return false;
        }
        // Calculate group norm: output = scale * (x - mean) / sqrt(var + eps) + bias
        String mvn_group_kernel_name = format("mvn_group%d", num_vector);
        build_opt += " -DNORM_VARIANCE -DKERNEL_MVN_GROUP";
        ocl::Kernel mvn_group_kernel(mvn_group_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt);
        if (mvn_group_kernel.empty()) {
            return false;
        }
        mvn_group_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
        mvn_group_kernel.set(1, (int)loops);
        mvn_group_kernel.set(2, (int)norm_size);
        mvn_group_kernel.set(3, (float)epsilon);
        mvn_group_kernel.set(4, ocl::KernelArg::PtrReadOnly(mean));
        mvn_group_kernel.set(5, ocl::KernelArg::PtrReadOnly(mean_square));
        mvn_group_kernel.set(6, ocl::KernelArg::PtrReadOnly(scale));
        mvn_group_kernel.set(7, ocl::KernelArg::PtrReadOnly(bias));
        mvn_group_kernel.set(8, (int)C);
        mvn_group_kernel.set(9, (int)num_groups);
        mvn_group_kernel.set(10, (float)0.f);
        mvn_group_kernel.set(11, ocl::KernelArg::PtrWriteOnly(output));
        ret = mvn_group_kernel.run(2, global, NULL, false);
        if (!ret) {
            return false;
        }

        return true;
    }
#endif

#ifdef HAVE_CUDA
    Ptr<BackendNode> initCUDA(void *context_,
                              const std::vector<Ptr<BackendWrapper>>& inputs,
                              const std::vector<Ptr<BackendWrapper>>& outputs) override {
        auto context = reinterpret_cast<csl::CSLContext*>(context_);

        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
        auto input_shape = input_wrapper->getShape();
        size_t N = input_shape[0];
        size_t num_groups = this->num_groups;
        size_t loops = N * num_groups;

        return make_cuda_node<cuda4dnn::GroupNormOp>(preferableTarget, std::move(context->stream), epsilon, loops, num_groups);
    }
#endif // HAVE_CUDA

private:
    float epsilon;
    size_t num_groups;
};

Ptr<GroupNormLayer> GroupNormLayer::create(const LayerParams &params) {
    return Ptr<GroupNormLayer>(new GroupNormLayerImpl(params));
}

}} // cv::dnn
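With the layer implementation above and the factory/importer registrations below, a model containing a GroupNormalization node can be run end to end. A minimal sketch, assuming an ONNX file (the filename and input shape are hypothetical, not part of this patch):

```cpp
#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <vector>

int main() {
    using namespace cv;
    using namespace cv::dnn;

    // Hypothetical ONNX model containing a GroupNormalization node.
    Net net = readNet("model_with_group_norm.onnx");

    // Select the CUDA backend added by this PR; DNN_BACKEND_OPENCV runs the CPU path.
    net.setPreferableBackend(DNN_BACKEND_CUDA);
    net.setPreferableTarget(DNN_TARGET_CUDA);

    Mat blob(std::vector<int>{1, 3, 224, 224}, CV_32F);  // illustrative NCHW input
    randu(blob, 0.f, 1.f);
    net.setInput(blob);

    Mat out = net.forward();
    return 0;
}
```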
@@ -4008,6 +4008,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter;
     dispatch["Tile"] = &ONNXImporter::parseTile;
     dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
+    dispatch["GroupNormalization"] = &ONNXImporter::parseInstanceNormalization;
 
     dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
             dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =
@@ -54,6 +54,7 @@
 #define vec_type Dtype8
 #define CALC_MEAN calc_mean8
 #define MVN mvn8
+#define MVN_GROUP mvn_group8
 #define MEAN_FUSE mean_fuse8
 #define MVN_FUSE mvn_fuse8
 #elif NUM == 4
@@ -62,6 +63,7 @@
 #define vec_type Dtype4
 #define CALC_MEAN calc_mean4
 #define MVN mvn4
+#define MVN_GROUP mvn_group4
 #define MEAN_FUSE mean_fuse4
 #define MVN_FUSE mvn_fuse4
 #elif NUM == 1
@@ -70,6 +72,7 @@
 #define vec_type Dtype
 #define CALC_MEAN calc_mean1
 #define MVN mvn1
+#define MVN_GROUP mvn_group1
 #define MEAN_FUSE mean_fuse1
 #define MVN_FUSE mvn_fuse1
 #endif
@@ -150,6 +153,54 @@ __kernel void MVN(__global const Dtype* src,
     store(dst_vec, dst, index);
 }
 
+#elif defined KERNEL_MVN_GROUP
+
+__kernel void MVN_GROUP(__global const Dtype* src,
+                        const int rows,
+                        const int cols,
+                        const Dtype eps,
+                        __global const Dtype* mean,
+                        __global const Dtype* dev,
+                        __global const Dtype* weight,
+                        __global const Dtype* bias,
+                        const int channels,
+                        const int num_groups,
+                        const float relu_slope,
+                        __global Dtype* dst)
+{
+    int x = get_global_id(0);
+    int y = get_global_id(1) * NUM;
+    int index = x * cols + y;
+
+    if (x >= rows || y >= cols)
+        return;
+
+    int group_size = channels / num_groups;
+    int step = cols / group_size;
+    int channel_index = x % num_groups * group_size + y / step;
+    Dtype mean_val = mean[x];
+    Dtype dev_val = dev[x];
+    Dtype alpha;
+#ifdef NORM_VARIANCE
+    alpha = 1 / sqrt(eps + dev_val);
+#else
+    alpha = 1;
+#endif
+
+    Dtype w = weight[channel_index], b = bias[channel_index];
+
+    vec_type src_vec = load(src, index) - (vec_type)mean_val;
+    vec_type dst_vec = src_vec * alpha;
+    dst_vec = dst_vec * w + (vec_type)b;
+
+#ifdef FUSE_RELU
+    vec_type new_val = dst_vec * relu_slope;
+    dst_vec = select(new_val, dst_vec, dst_vec > (vec_type)0.f);
+#endif
+
+    store(dst_vec, dst, index);
+}
+
 #elif defined KERNEL_MEAN_FUSE
 
 __kernel void MEAN_FUSE(__global const T * A,
@@ -311,6 +311,8 @@ static const TestCase testConformanceConfig[] = {
     {"test_gridsample_nearest", 2, 1},
     {"test_gridsample_reflection_padding", 2, 1},
     {"test_gridsample_zeros_padding", 2, 1},
+    {"test_group_normalization_epsilon", 3, 1},
+    {"test_group_normalization_example", 3, 1},
     {"test_gru_batchwise", 3, 2},
     {"test_gru_defaults", 3, 1},
     {"test_gru_seq_length", 4, 1},
@@ -736,6 +736,10 @@ CASE(test_gridsample_reflection_padding)
 // no filter
 CASE(test_gridsample_zeros_padding)
 // no filter
+CASE(test_group_normalization_epsilon)
+// no filter
+CASE(test_group_normalization_example)
+// no filter
 CASE(test_gru_batchwise)
 // no filter
 CASE(test_gru_defaults)