mirror of
https://github.com/opencv/opencv.git
synced 2024-11-29 05:29:54 +08:00
add MVNOp
This commit is contained in:
parent
77b01deb80
commit
a3106d424b
145
modules/dnn/src/cuda/mvn.cu
Normal file
145
modules/dnn/src/cuda/mvn.cu
Normal file
@ -0,0 +1,145 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "math.hpp"
|
||||
#include "types.hpp"
|
||||
#include "atomics.hpp"
|
||||
#include "grid_stride_range.hpp"
|
||||
#include "execution.hpp"
|
||||
|
||||
#include "../cuda4dnn/csl/stream.hpp"
|
||||
#include "../cuda4dnn/csl/span.hpp"
|
||||
|
||||
#include <opencv2/core.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
using namespace cv::dnn::cuda4dnn::csl;
|
||||
using namespace cv::dnn::cuda4dnn::csl::device;
|
||||
|
||||
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
|
||||
namespace raw {
|
||||
template <class T>
|
||||
__global__ void reduce_mean(Span<float> means, View<T> input, size_type inner_size) {
|
||||
for (auto idx : grid_stride_range(input.size())) {
|
||||
const index_type outer_idx = idx / inner_size;
|
||||
atomicAdd(&means[outer_idx], static_cast<float>(input[idx]) / inner_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void reduce_mean_sqr_sum(Span<float> means, Span<float> sum_sqrs, View<T> input, size_type inner_size) {
|
||||
for (auto idx : grid_stride_range(input.size())) {
|
||||
const index_type outer_idx = idx / inner_size;
|
||||
auto x = static_cast<float>(input[idx]);
|
||||
atomicAdd(&means[outer_idx], x / inner_size);
|
||||
atomicAdd(&sum_sqrs[outer_idx], x * x);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_normalization_scale(Span<float> scale, View<float> means, View<float> sums_sqr, size_type inner_size, float eps) {
|
||||
for (auto idx : grid_stride_range(scale.size())) {
|
||||
auto mean = means[idx];
|
||||
auto var = sums_sqr[idx] / inner_size - mean * mean;
|
||||
using device::rsqrt;
|
||||
scale[idx] = rsqrt(eps + var);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void normalize_mean(Span<T> output, View<T> input, View<float> means, size_type inner_size) {
|
||||
for (auto idx : grid_stride_range(output.size())) {
|
||||
const index_type outer_idx = idx / inner_size;
|
||||
output[idx] = static_cast<float>(input[idx]) - means[outer_idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void normalize_mean_variance(Span<T> output, View<T> input, View<float> means, View<float> scale, size_type inner_size) {
|
||||
for (auto idx : grid_stride_range(output.size())) {
|
||||
const index_type outer_idx = idx / inner_size;
|
||||
output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * scale[outer_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void reduce_mean(const Stream& stream, Span<float> means, View<T> input, std::size_t inner_size)
|
||||
{
|
||||
CV_Assert(input.size() / inner_size == means.size());
|
||||
|
||||
auto kernel = raw::reduce_mean<T>;
|
||||
auto policy = make_policy(kernel, input.size(), 0, stream);
|
||||
launch_kernel(kernel, policy, means, input, inner_size);
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
||||
template void reduce_mean(const Stream&, Span<float>, View<__half>, std::size_t);
|
||||
#endif
|
||||
template void reduce_mean(const Stream&, Span<float>, View<float>, std::size_t);
|
||||
|
||||
template <class T>
|
||||
void reduce_mean_sqr_sum(const Stream& stream, Span<float> means, Span<float> sum_sqrs, View<T> input, std::size_t inner_size)
|
||||
{
|
||||
CV_Assert(input.size() / inner_size == means.size());
|
||||
CV_Assert(input.size() / inner_size == sum_sqrs.size());
|
||||
|
||||
auto kernel = raw::reduce_mean_sqr_sum<T>;
|
||||
auto policy = make_policy(kernel, input.size(), 0, stream);
|
||||
launch_kernel(kernel, policy, means, sum_sqrs, input, inner_size);
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
||||
template void reduce_mean_sqr_sum(const Stream&, Span<float>, Span<float>, View<__half>, std::size_t);
|
||||
#endif
|
||||
template void reduce_mean_sqr_sum(const Stream&, Span<float>, Span<float>, View<float>, std::size_t);
|
||||
|
||||
void compute_normalization_scale(const Stream& stream, Span<float> scale, View<float> means, View<float> sum_sqrs, std::size_t inner_size, float eps)
|
||||
{
|
||||
CV_Assert(scale.size() == means.size());
|
||||
CV_Assert(scale.size() == sum_sqrs.size());
|
||||
|
||||
auto kernel = raw::compute_normalization_scale;
|
||||
auto policy = make_policy(kernel, scale.size(), 0, stream);
|
||||
launch_kernel(kernel, policy, scale, means, sum_sqrs, inner_size, eps);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void normalize_mean(const Stream& stream, Span<T> output, View<T> input, View<float> means, std::size_t inner_size)
|
||||
{
|
||||
CV_Assert(output.size() == input.size());
|
||||
CV_Assert(input.size() / inner_size == means.size());
|
||||
|
||||
auto kernel = raw::normalize_mean<T>;
|
||||
auto policy = make_policy(kernel, output.size(), 0, stream);
|
||||
launch_kernel(kernel, policy, output, input, means, inner_size);
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
||||
template void normalize_mean(const Stream&, Span<__half>, View<__half>, View<float>, std::size_t);
|
||||
#endif
|
||||
template void normalize_mean(const Stream&, Span<float>, View<float>, View<float>, std::size_t);
|
||||
|
||||
template <class T>
|
||||
void normalize_mean_variance(const Stream& stream, Span<T> output, View<T> input, View<float> means, View<float> scale, std::size_t inner_size)
|
||||
{
|
||||
CV_Assert(input.size() == output.size());
|
||||
CV_Assert(input.size() / inner_size == means.size());
|
||||
CV_Assert(input.size() / inner_size == scale.size());
|
||||
|
||||
auto kernel = raw::normalize_mean_variance<T>;
|
||||
auto policy = make_policy(kernel, output.size(), 0, stream);
|
||||
launch_kernel(kernel, policy, output, input, means, scale, inner_size);
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
||||
template void normalize_mean_variance(const Stream&, Span<__half>, View<__half>, View<float>, View<float>, std::size_t);
|
||||
#endif
|
||||
template void normalize_mean_variance(const Stream&, Span<float>, View<float>, View<float>, View<float>, std::size_t);
|
||||
|
||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
31
modules/dnn/src/cuda4dnn/kernels/mvn.hpp
Normal file
31
modules/dnn/src/cuda4dnn/kernels/mvn.hpp
Normal file
@ -0,0 +1,31 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP
|
||||
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP
|
||||
|
||||
#include "../csl/stream.hpp"
|
||||
#include "../csl/span.hpp"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
|
||||
template <class T>
|
||||
void reduce_mean(const csl::Stream& stream, csl::Span<float> means, csl::View<T> input, std::size_t inner_size);
|
||||
|
||||
template <class T>
|
||||
void reduce_mean_sqr_sum(const csl::Stream& stream, csl::Span<float> means, csl::Span<float> sum_sqrs, csl::View<T> input, std::size_t inner_size);
|
||||
|
||||
void compute_normalization_scale(const csl::Stream& stream, csl::Span<float> scale, csl::View<float> means, csl::View<float> sum_sqrs, std::size_t inner_size, float eps);
|
||||
|
||||
template <class T>
|
||||
void normalize_mean(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, std::size_t inner_size);
|
||||
|
||||
template <class T>
|
||||
void normalize_mean_variance(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, csl::View<float> scale, std::size_t inner_size);
|
||||
|
||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP */
|
134
modules/dnn/src/cuda4dnn/primitives/mvn.hpp
Normal file
134
modules/dnn/src/cuda4dnn/primitives/mvn.hpp
Normal file
@ -0,0 +1,134 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MVN_HPP
|
||||
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MVN_HPP
|
||||
|
||||
#include "../../op_cuda.hpp"
|
||||
|
||||
#include "../csl/stream.hpp"
|
||||
#include "../csl/span.hpp"
|
||||
#include "../csl/tensor.hpp"
|
||||
#include "../csl/workspace.hpp"
|
||||
|
||||
#include "../kernels/fill_copy.hpp"
|
||||
#include "../kernels/mvn.hpp"
|
||||
|
||||
#include <opencv2/core.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
|
||||
struct MVNConfiguration {
|
||||
std::vector<std::vector<std::size_t>> input_shapes;
|
||||
|
||||
/*
|
||||
* [0, split_axis) = outer range
|
||||
* [split_axis, -1] = inner range
|
||||
*
|
||||
* for each location in the outer range, all the values in the inner range are normalized as a group
|
||||
*/
|
||||
std::size_t split_axis;
|
||||
|
||||
/* The group (described above) is centered always. The following parameter controls whether the variance
|
||||
* is also normalized.
|
||||
*/
|
||||
bool normalize_variance;
|
||||
float epsilon;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class MVNOp final : public CUDABackendNode {
|
||||
public:
|
||||
using wrapper_type = GetCUDABackendWrapperType<T>;
|
||||
|
||||
MVNOp(csl::Stream stream_, const MVNConfiguration& config)
|
||||
: stream(std::move(stream_))
|
||||
{
|
||||
split_axis = config.split_axis;
|
||||
normalize_variance = config.normalize_variance;
|
||||
epsilon = config.epsilon;
|
||||
|
||||
std::size_t max_outer_size = 0;
|
||||
const auto& input_shapes = config.input_shapes;
|
||||
for (int i = 0; i < input_shapes.size(); i++)
|
||||
{
|
||||
std::size_t outer_size = 1;
|
||||
for (int j = 0; j < split_axis; j++)
|
||||
outer_size *= input_shapes[i][j];
|
||||
max_outer_size = std::max(max_outer_size, outer_size);
|
||||
}
|
||||
|
||||
csl::WorkspaceBuilder builder;
|
||||
builder.require<float>(max_outer_size);
|
||||
if (normalize_variance)
|
||||
builder.require<float>(max_outer_size);
|
||||
scratch_mem_in_bytes = builder.required_workspace_size();
|
||||
}
|
||||
|
||||
void forward(
|
||||
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
|
||||
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
|
||||
csl::Workspace& workspace) override
|
||||
{
|
||||
CV_Assert(inputs.size() == outputs.size());
|
||||
|
||||
for (int i = 0; i < inputs.size(); i++)
|
||||
{
|
||||
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
|
||||
auto input = input_wrapper->getView();
|
||||
|
||||
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
|
||||
auto output = output_wrapper->getSpan();
|
||||
|
||||
auto outer_size = input.size_range(0, split_axis);
|
||||
auto inner_size = input.size_range(split_axis, input.rank());
|
||||
if (inner_size == 1)
|
||||
{
|
||||
kernels::fill<T>(stream, output, 0.0f);
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ws_allocator = csl::WorkspaceAllocator(workspace);
|
||||
|
||||
auto means = ws_allocator.get_span<float>(outer_size);
|
||||
kernels::fill<float>(stream, means, 0);
|
||||
|
||||
if (normalize_variance)
|
||||
{
|
||||
auto scales = ws_allocator.get_span<float>(outer_size);
|
||||
kernels::fill<float>(stream, scales, 0);
|
||||
|
||||
kernels::reduce_mean_sqr_sum<T>(stream, means, scales, input, inner_size);
|
||||
kernels::compute_normalization_scale(stream, scales, means, scales, inner_size, epsilon);
|
||||
kernels::normalize_mean_variance<T>(stream, output, input, means, scales, inner_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
kernels::reduce_mean<T>(stream, means, input, inner_size);
|
||||
kernels::normalize_mean<T>(stream, output, input, means, inner_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }
|
||||
|
||||
private:
|
||||
csl::Stream stream;
|
||||
|
||||
bool normalize_variance;
|
||||
float epsilon;
|
||||
std::size_t split_axis;
|
||||
|
||||
std::size_t scratch_mem_in_bytes;
|
||||
};
|
||||
|
||||
}}} /* namespace cv::dnn::cuda4dnn */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MVN_HPP */
|
@ -44,6 +44,7 @@
|
||||
#include "layers_common.hpp"
|
||||
#include "../op_inf_engine.hpp"
|
||||
#include "../ie_ngraph.hpp"
|
||||
#include "../op_cuda.hpp"
|
||||
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
@ -52,6 +53,11 @@
|
||||
#include "opencl_kernels_dnn.hpp"
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
#include "../cuda4dnn/primitives/mvn.hpp"
|
||||
using namespace cv::dnn::cuda4dnn;
|
||||
#endif
|
||||
|
||||
namespace cv
|
||||
{
|
||||
namespace dnn
|
||||
@ -127,7 +133,7 @@ public:
|
||||
return true;
|
||||
#endif
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV;
|
||||
return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_CUDA;
|
||||
}
|
||||
}
|
||||
|
||||
@ -399,6 +405,31 @@ public:
|
||||
}
|
||||
#endif // HAVE_DNN_NGRAPH
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
Ptr<BackendNode> initCUDA(
|
||||
void *context_,
|
||||
const std::vector<Ptr<BackendWrapper>>& inputs,
|
||||
const std::vector<Ptr<BackendWrapper>>& outputs
|
||||
) override
|
||||
{
|
||||
auto context = reinterpret_cast<csl::CSLContext*>(context_);
|
||||
|
||||
cuda4dnn::MVNConfiguration config;
|
||||
config.split_axis = acrossChannels ? 1 : 2;
|
||||
config.normalize_variance = normVariance;
|
||||
config.epsilon = eps;
|
||||
config.input_shapes.resize(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); i++)
|
||||
{
|
||||
auto wrapper = inputs[i].dynamicCast<CUDABackendWrapper>();
|
||||
auto shape = wrapper->getShape();
|
||||
config.input_shapes[i].assign(std::begin(shape), std::end(shape));
|
||||
}
|
||||
|
||||
return make_cuda_node<cuda4dnn::MVNOp>(preferableTarget, std::move(context->stream), config);
|
||||
}
|
||||
#endif
|
||||
|
||||
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
|
||||
const std::vector<MatShape> &outputs) const CV_OVERRIDE
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user