Merge pull request #24552 from fengyuentau:layernorm_backends
dnn: add openvino, opencl and cuda backends for layer normalization layer #24552

Merge after https://github.com/opencv/opencv/pull/24544.

Todo:

- [x] openvino
- [x] opencl
- [x] cuda

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is an accuracy test, a performance test and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
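For reference, all three new backends implement the same ONNX LayerNormalization semantics: the input is viewed as [loops, norm_size] around the normalized axis, and each row is normalized independently. A minimal scalar C++ sketch of the math (illustrative only, with hypothetical names; not OpenCV code):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference LayerNormalization: x is viewed as [loops, norm_size]; scale and
// bias each hold norm_size elements; epsilon sits inside the square root.
void layer_norm_ref(const std::vector<float>& x, const std::vector<float>& scale,
                    const std::vector<float>& bias, std::vector<float>& y,
                    std::size_t loops, std::size_t norm_size, float epsilon) {
    for (std::size_t i = 0; i < loops; i++) {
        const float* row = &x[i * norm_size];
        float mean = 0.f, mean_square = 0.f;
        for (std::size_t j = 0; j < norm_size; j++) {
            mean += row[j];
            mean_square += row[j] * row[j];
        }
        mean /= norm_size;
        mean_square /= norm_size;
        // var = E[x^2] - E[x]^2
        const float inv_stddev = 1.f / std::sqrt(mean_square - mean * mean + epsilon);
        for (std::size_t j = 0; j < norm_size; j++)
            y[i * norm_size + j] = (row[j] - mean) * inv_stddev * scale[j] + bias[j];
    }
}
```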
This commit is contained in:
parent fba3c947ef
commit d05fb709f9
@@ -68,15 +68,36 @@ namespace raw {
     }
 
     template <class T>
-    __global__ void normalize_mean_variance_channelwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> stdev, size_type inner_size, size_type C) {
+    __global__ void normalize_mean_variance_channelwise(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, size_type inner_size, size_type C) {
         for (auto idx : grid_stride_range(output.size())) {
             const index_type outer_idx = idx / inner_size;
             const index_type c = outer_idx % C;
-            auto s = static_cast<float>(scale[c]) * stdev[outer_idx];
+            auto s = static_cast<float>(scale[c]) * inv_stddev[outer_idx];
             auto b = static_cast<float>(bias[c]);
             output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s + b;
         }
     }
+
+    template <class T>
+    __global__ void normalize_mean_variance_layernorm(Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, size_type inner_size) {
+        for (auto idx : grid_stride_range(output.size())) {
+            const index_type outer_idx = idx / inner_size;
+            const index_type inner_idx = idx % inner_size;
+            auto s = static_cast<float>(scale[inner_idx]) * inv_stddev[outer_idx];
+            output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s;
+        }
+    }
+
+    template <class T>
+    __global__ void normalize_mean_variance_layernorm_with_bias(Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, size_type inner_size) {
+        for (auto idx : grid_stride_range(output.size())) {
+            const index_type outer_idx = idx / inner_size;
+            const index_type inner_idx = idx % inner_size;
+            auto s = static_cast<float>(scale[inner_idx]) * inv_stddev[outer_idx];
+            auto b = static_cast<float>(bias[inner_idx]);
+            output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * s + b;
+        }
+    }
 }
 
 template <class T>
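Note: grid_stride_range is what lets these kernels run correctly under any launch configuration. A sketch of what it expands to in a raw CUDA kernel (illustrative; the real helper lives alongside OpenCV's CUDA kernel sources):

```cpp
#include <cstddef>

// Grid-stride loop: each thread starts at its global index and strides by the
// total thread count, so all n elements are covered regardless of grid size.
// The body mirrors the fused multiply-add shape of the kernels above.
__global__ void scale_shift(float* y, const float* x, float s, float b, std::size_t n) {
    for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
        y[i] = x[i] * s + b;
    }
}
```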
@@ -154,20 +175,54 @@ template void normalize_mean_variance(const Stream&, Span<__half>, View<__half>,
 template void normalize_mean_variance(const Stream&, Span<float>, View<float>, View<float>, View<float>, std::size_t);
 
 template <class T>
-void normalize_mean_variance_channelwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> stdev, std::size_t inner_size, std::size_t C)
+void normalize_mean_variance_channelwise(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, std::size_t inner_size, std::size_t C)
 {
     CV_Assert(input.size() == output.size());
     CV_Assert(input.size() / inner_size == means.size());
-    CV_Assert(means.size() == stdev.size());
+    CV_Assert(means.size() == inv_stddev.size());
 
     auto kernel = raw::normalize_mean_variance_channelwise<T>;
     auto policy = make_policy(kernel, output.size(), 0, stream);
-    launch_kernel(kernel, policy, output, input, scale, bias, means, stdev, inner_size, C);
+    launch_kernel(kernel, policy, output, input, scale, bias, means, inv_stddev, inner_size, C);
 }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*stdev*/, std::size_t, std::size_t);
+template void normalize_mean_variance_channelwise(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t);
 #endif
-template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*stdev*/, std::size_t, std::size_t);
+template void normalize_mean_variance_channelwise(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t, std::size_t);
+
+template <class T>
+void normalize_mean_variance_layernorm(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<float> means, View<float> inv_stddev, std::size_t inner_size)
+{
+    CV_Assert(input.size() == output.size());
+    CV_Assert(input.size() / inner_size == means.size());
+    CV_Assert(means.size() == inv_stddev.size());
+
+    auto kernel = raw::normalize_mean_variance_layernorm<T>;
+    auto policy = make_policy(kernel, output.size(), 0, stream);
+    launch_kernel(kernel, policy, output, input, scale, means, inv_stddev, inner_size);
+}
+
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void normalize_mean_variance_layernorm(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
+#endif
+template void normalize_mean_variance_layernorm(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
+
+template <class T>
+void normalize_mean_variance_layernorm(const Stream& stream, Span<T> output, View<T> input, View<T> scale, View<T> bias, View<float> means, View<float> inv_stddev, std::size_t inner_size)
+{
+    CV_Assert(input.size() == output.size());
+    CV_Assert(input.size() / inner_size == means.size());
+    CV_Assert(means.size() == inv_stddev.size());
+
+    auto kernel = raw::normalize_mean_variance_layernorm_with_bias<T>;
+    auto policy = make_policy(kernel, output.size(), 0, stream);
+    launch_kernel(kernel, policy, output, input, scale, bias, means, inv_stddev, inner_size);
+}
+
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void normalize_mean_variance_layernorm(const Stream&, Span<__half> /*output*/, View<__half> /*input*/, View<__half> /*scale*/, View<__half> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
+#endif
+template void normalize_mean_variance_layernorm(const Stream&, Span<float> /*output*/, View<float> /*input*/, View<float> /*scale*/, View<float> /*bias*/, View<float> /*means*/, View<float> /*inv_stddev*/, std::size_t);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
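Two notes on the block above. The __half instantiations are guarded by __CUDA_ARCH__ >= 530 because native fp16 arithmetic requires compute capability 5.3 or newer. And storing the reciprocal standard deviation (the stdev to inv_stddev rename) pays off at normalization time: one square root per outer row replaces a divide per element. A scalar sketch of that trade (hypothetical helper, not OpenCV API):

```cpp
#include <cmath>
#include <cstddef>

// Normalize one row from its precomputed moments: the reciprocal is computed
// once, so per-element work is a subtract and a multiply rather than a divide.
void normalize_row(float* row, std::size_t n, float mean, float var, float eps) {
    const float inv_stddev = 1.f / std::sqrt(var + eps);
    for (std::size_t j = 0; j < n; j++)
        row[j] = (row[j] - mean) * inv_stddev;
}
```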
@@ -27,7 +27,13 @@ template <class T>
 void normalize_mean_variance(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, csl::View<float> scale, std::size_t inner_size);
 
 template <class T>
-void normalize_mean_variance_channelwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> stdev, std::size_t inner_size, std::size_t C);
+void normalize_mean_variance_channelwise(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size, std::size_t C);
+
+template <class T>
+void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size);
+
+template <class T>
+void normalize_mean_variance_layernorm(const csl::Stream &stream, csl::Span<T> output, csl::View<T> input, csl::View<T> scale, csl::View<T> bias, csl::View<float> means, csl::View<float> inv_stddev, std::size_t inner_size);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
modules/dnn/src/cuda4dnn/primitives/layer_norm.hpp (new file, 93 lines)
@@ -0,0 +1,93 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LAYER_NORM_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LAYER_NORM_HPP
+
+#include "../../op_cuda.hpp"
+
+#include "../csl/stream.hpp"
+#include "../csl/span.hpp"
+#include "../csl/tensor.hpp"
+#include "../csl/workspace.hpp"
+
+#include "../kernels/fill_copy.hpp"
+#include "../kernels/mvn.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+template <class T>
+class LayerNormOp final : public CUDABackendNode {
+public:
+    using wrapper_type = GetCUDABackendWrapperType<T>;
+
+    LayerNormOp(csl::Stream stream_, int normalized_axis, float epsilon_, size_t loops)
+        : stream(std::move(stream_)), epsilon(epsilon_) {
+        CV_CheckGE(normalized_axis, 0, "LayerNorm/CUDA: axis needs to be normalized");
+        axis = static_cast<size_t>(normalized_axis);
+
+        csl::WorkspaceBuilder builder;
+        builder.require<float>(loops);
+        builder.require<float>(loops);
+        scratch_mem_in_bytes = builder.required_workspace_size();
+    }
+
+    void forward(const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+                 const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+                 csl::Workspace& workspace) override {
+        auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
+        auto scale_wrapper = inputs[1].dynamicCast<wrapper_type>();
+
+        auto input = input_wrapper->getView();
+        auto scale = scale_wrapper->getView();
+
+        auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
+        auto output = output_wrapper->getSpan();
+
+        auto loops = input.size_range(0, axis);
+        auto norm_size = input.size_range(axis, input.rank());
+        if (norm_size == 1) {
+            kernels::fill<T>(stream, output, 0.f);
+            return;
+        } else {
+            auto ws_allocator = csl::WorkspaceAllocator(workspace);
+
+            auto mean = ws_allocator.get_span<float>(loops);
+            kernels::fill<float>(stream, mean, 0.f);
+
+            auto inv_stddev = ws_allocator.get_span<float>(loops);
+            kernels::fill<float>(stream, inv_stddev, 0.f);
+
+            kernels::reduce_mean_sqr_sum<T>(stream, mean, inv_stddev, input, norm_size);
+            kernels::compute_normalization_scale(stream, inv_stddev, mean, inv_stddev, norm_size, epsilon);
+            if (inputs.size() == 3) {
+                auto bias_wrapper = inputs[2].dynamicCast<wrapper_type>();
+                auto bias = bias_wrapper->getView();
+                kernels::normalize_mean_variance_layernorm<T>(stream, output, input, scale, bias, mean, inv_stddev, norm_size);
+            } else {
+                kernels::normalize_mean_variance_layernorm<T>(stream, output, input, scale, mean, inv_stddev, norm_size);
+            }
+        }
+    }
+
+    std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }
+
+private:
+    csl::Stream stream;
+
+    float epsilon;
+    size_t axis;
+
+    std::size_t scratch_mem_in_bytes;
+};
+
+}}} // cv::dnn::cuda4dnn
+
+#endif // OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LAYER_NORM_HPP
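The primitive sizes its scratch space in the constructor (two float buffers of `loops` elements, one each for mean and inv_stddev) and carves the same spans back out of the shared workspace in forward. How `loops` and `norm_size` fall out of `axis`, as a self-contained sketch with illustrative values:

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

int main() {
    // For input shape [2, 3, 4] and axis = 1, normalization runs over each
    // 3*4 = 12-element slice; there are 2 such slices.
    std::vector<std::size_t> shape{2, 3, 4};
    std::size_t axis = 1;
    std::size_t loops = std::accumulate(shape.begin(), shape.begin() + axis,
                                        std::size_t{1}, std::multiplies<std::size_t>());
    std::size_t norm_size = std::accumulate(shape.begin() + axis, shape.end(),
                                            std::size_t{1}, std::multiplies<std::size_t>());
    // loops == 2 and norm_size == 12; the workspace holds 2 * loops floats.
    return (loops == 2 && norm_size == 12) ? 0 : 1;
}
```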
@@ -9,8 +9,26 @@
 // CANN backend
 #include "../op_cann.hpp"
 
+// OpenVINO backend
+#include "../op_inf_engine.hpp"
+#include "../ie_ngraph.hpp"
+
+// CUDA backend
+#include "../op_cuda.hpp"
+#ifdef HAVE_CUDA
+#include "../cuda4dnn/primitives/layer_norm.hpp"
+using namespace cv::dnn::cuda4dnn;
+#endif
+
+// OpenCL backend
+#ifdef HAVE_OPENCL
+#include "../ocl4dnn/include/math_functions.hpp"
+#include "opencl_kernels_dnn.hpp"
+#endif
+
 namespace cv { namespace dnn {
 
 // https://github.com/onnx/onnx/blob/main/docs/Operators.md#LayerNormalization
 class LayerNormLayerImpl CV_FINAL : public LayerNormLayer
 {
 public:
@@ -25,7 +43,12 @@ public:
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
+#ifdef HAVE_INF_ENGINE
+        if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+            return true;
+#endif
         return backendId == DNN_BACKEND_OPENCV ||
+               backendId == DNN_BACKEND_CUDA ||
                (backendId == DNN_BACKEND_CANN && axis != -1); // axis=-1 not supported due to 1d mat shape problem
     }
 
@@ -73,6 +96,9 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
+                   forward_ocl(inputs_arr, outputs_arr, internals_arr))
+
         if (inputs_arr.depth() == CV_16S)
         {
             forward_fallback(inputs_arr, outputs_arr, internals_arr);
@@ -95,6 +121,91 @@ public:
         }
     }
 
+#ifdef HAVE_OPENCL
+    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_) {
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inputs_.getUMatVector(inputs);
+        outputs_.getUMatVector(outputs);
+
+        const auto &input = inputs[0], &scale = inputs[1]; // &bias = inputs[2]; // bias is optional
+        auto &output = outputs[0];
+
+        const auto input_shape = shape(input);
+        size_t loops = static_cast<size_t>(total(input_shape, 0, axis)),
+               norm_size = static_cast<size_t>(total(input_shape, axis));
+        float inv_norm_size = 1.f / norm_size;
+
+        const auto &bias = inputs.size() == 3 ? inputs[2] : UMat::zeros(norm_size, 1, CV_32F);
+
+        // no fp16 support
+        if (input.depth() == CV_16S) {
+            return false;
+        }
+
+        String base_opts = format(" -DT=float -DT4=float4 -Dconvert_T=convert_float4");
+
+        // Calculate mean
+        UMat one = UMat::ones(norm_size, 1, CV_32F);
+        UMat mean = UMat(loops, 1, CV_32F);
+        UMat mean_square = UMat(loops, 1, CV_32F);
+        UMat tmp = UMat(loops, norm_size, CV_32F);
+        bool ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
+                                               input, 0, one, 0, 0.f, mean, 0);
+        if (!ret) {
+            return false;
+        }
+        // Calculate mean_square
+        int num_vector = (norm_size % 8 == 0) ? 8 : ((norm_size % 4 == 0) ? 4 : 1);
+        size_t global[] = {loops, static_cast<size_t>(norm_size / num_vector)};
+        String build_opt = format(" -DNUM=%d", num_vector) + base_opts;
+        String mean_square_kernel_name = format("calc_mean%d", num_vector);
+        ocl::Kernel mean_square_kernel(mean_square_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt + " -DKERNEL_MEAN");
+        if (mean_square_kernel.empty()) {
+            return false;
+        }
+        mean_square_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
+        mean_square_kernel.set(1, (int)loops);
+        mean_square_kernel.set(2, (int)norm_size);
+        mean_square_kernel.set(3, ocl::KernelArg::PtrReadOnly(mean));
+        mean_square_kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmp));
+        ret = mean_square_kernel.run(2, global, NULL, false);
+        if (!ret) {
+            return false;
+        }
+        ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, loops, norm_size, inv_norm_size,
+                                          tmp, 0, one, 0, 0.f, mean_square, 0);
+        if (!ret) {
+            return false;
+        }
+        // Calculate layer norm: output = scale * (x - mean) / sqrt(var + eps) + bias
+        String mvn_kernel_name = format("mvn%d", num_vector);
+        build_opt += " -DNORM_VARIANCE -DLAYER_NORM -DKERNEL_MVN";
+        ocl::Kernel mvn_kernel(mvn_kernel_name.c_str(), ocl::dnn::mvn_oclsrc, build_opt);
+        if (mvn_kernel.empty()) {
+            return false;
+        }
+        mvn_kernel.set(0, ocl::KernelArg::PtrReadOnly(input));
+        mvn_kernel.set(1, (int)loops);
+        mvn_kernel.set(2, (int)norm_size);
+        mvn_kernel.set(3, (float)epsilon);
+        mvn_kernel.set(4, ocl::KernelArg::PtrReadOnly(mean));
+        mvn_kernel.set(5, ocl::KernelArg::PtrReadOnly(mean_square));
+        mvn_kernel.set(6, ocl::KernelArg::PtrReadOnly(scale));
+        mvn_kernel.set(7, ocl::KernelArg::PtrReadOnly(bias));
+        mvn_kernel.set(8, (int)1);
+        mvn_kernel.set(9, (float)0.f);
+        mvn_kernel.set(10, ocl::KernelArg::PtrWriteOnly(output));
+        ret = mvn_kernel.run(2, global, NULL, false);
+        if (!ret) {
+            return false;
+        }
+
+        return true;
+    }
+#endif
+
 #ifdef HAVE_CANN
     virtual Ptr<BackendNode> initCann(const std::vector<Ptr<BackendWrapper> > &inputs,
                                       const std::vector<Ptr<BackendWrapper> > &outputs,
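The OpenCL path reduces both row moments with GEMV: viewing the input as a loops x norm_size matrix X and `one` as an all-ones vector, a GEMV with alpha = 1/norm_size yields the row means, and a second GEMV over the squared intermediate yields the second moment, from which the variance follows. A CPU sketch of the idea (illustrative; the actual kernels stage the intermediate slightly differently):

```cpp
#include <cstddef>
#include <vector>

// Row-wise moments via ones-vector matrix-vector products, unrolled on the
// CPU: mean = (1/n) * X * one and E[x^2] = (1/n) * (X .* X) * one, so
// var = E[x^2] - mean^2 for each row.
void row_moments(const std::vector<float>& X, std::size_t loops, std::size_t norm_size,
                 std::vector<float>& mean, std::vector<float>& mean_square) {
    const float inv_n = 1.f / norm_size;
    mean.assign(loops, 0.f);
    mean_square.assign(loops, 0.f);
    for (std::size_t i = 0; i < loops; i++) {
        for (std::size_t j = 0; j < norm_size; j++) {
            const float v = X[i * norm_size + j];
            mean[i] += v;
            mean_square[i] += v * v;
        }
        mean[i] *= inv_n;
        mean_square[i] *= inv_n;
    }
}
```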
@@ -147,6 +258,67 @@ public:
     }
 #endif // HAVE_CANN
 
+#ifdef HAVE_DNN_NGRAPH
+    virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
+                                        const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE {
+        auto ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
+        const auto &input_shape = ieInpNode.get_shape();
+        std::shared_ptr<ngraph::Node> mvn, result;
+
+        // mvn
+#if INF_ENGINE_VER_MAJOR_LE(INF_ENGINE_RELEASE_2021_2)
+        // https://docs.openvino.ai/2021.4/api/ngraph_python_api/_autosummary/ngraph.opset3.mvn.html?highlight=mvn#ngraph.opset3.mvn
+        bool across_channels = false;
+        bool normalize_variance = true;
+        mvn = std::make_shared<ngraph::op::MVN>(ieInpNode, across_channels, normalize_variance, epsilon);
+#else
+        // https://docs.openvino.ai/2023.1/openvino_docs_ops_normalization_MVN_6.html
+        std::vector<int64_t> axes_v(input_shape.size() - axis);
+        std::iota(axes_v.begin(), axes_v.end(), axis);
+        auto axes = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{axes_v.size()}, axes_v.data());
+        bool normalize_variance = true;
+        mvn = std::make_shared<ngraph::op::v6::MVN>(ieInpNode, axes, normalize_variance, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
+#endif
+
+        // layer norm = scale * mvn + bias
+        auto scale = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
+        ngraph::Output<ngraph::Node> bias;
+        if (nodes.size() == 3) {
+            bias = nodes[2].dynamicCast<InfEngineNgraphNode>()->node;
+        }
+        if (axis == -1 || axis == input_shape.size() - 1) { // special case for 1D tensor (2D mat)
+            std::vector<int64_t> shared_shape_v(input_shape.size(), 1);
+            shared_shape_v.back() = -1;
+            auto shared_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{shared_shape_v.size()}, shared_shape_v.data());
+            scale = std::make_shared<ngraph::op::v1::Reshape>(scale, shared_shape, true);
+            if (nodes.size() == 3) {
+                bias = std::make_shared<ngraph::op::v1::Reshape>(bias, shared_shape, true);
+            }
+        }
+
+        result = std::make_shared<ngraph::op::v1::Multiply>(mvn, scale);
+        if (nodes.size() == 3) {
+            result = std::make_shared<ngraph::op::v1::Add>(result, bias);
+        }
+
+        return Ptr<BackendNode>(new InfEngineNgraphNode(result));
+    }
+#endif // HAVE_DNN_NGRAPH
+
+#ifdef HAVE_CUDA
+    Ptr<BackendNode> initCUDA(void *context_,
+                              const std::vector<Ptr<BackendWrapper>>& inputs,
+                              const std::vector<Ptr<BackendWrapper>>& outputs) override {
+        auto context = reinterpret_cast<csl::CSLContext*>(context_);
+
+        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+        auto input_shape = input_wrapper->getShape();
+        size_t loops = static_cast<size_t>(total(input_shape, 0, axis));
+
+        return make_cuda_node<cuda4dnn::LayerNormOp>(preferableTarget, std::move(context->stream), axis, epsilon, loops);
+    }
+#endif // HAVE_CUDA
+
 };
 
 Ptr<LayerNormLayer> LayerNormLayer::create(const LayerParams& params)
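Since OpenVINO exposes no single LayerNormalization op here, the node is built as MVN over the trailing axes followed by elementwise Multiply and Add. The subtle step is the broadcast reshape for the last-axis case: the 1-D scale and bias are reshaped to [1, ..., 1, -1] so they broadcast along the trailing axis of the MVN output. A sketch of the shape being constructed (hypothetical helper, not part of the patch):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Build the broadcast-compatible shape used in the last-axis special case:
// all ones except the trailing dimension, which -1 asks the Reshape to infer.
std::vector<int64_t> shared_shape(std::size_t rank) {
    std::vector<int64_t> v(rank, 1);
    v.back() = -1;  // e.g. rank 4 -> {1, 1, 1, -1}
    return v;
}
```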
@@ -126,12 +126,18 @@ __kernel void MVN(__global const Dtype* src,
     alpha = 1;
 #endif
 
+#ifdef LAYER_NORM
+    vec_type w = load(bnorm_weight, y), b = load(bnorm_bias, y);
+#else
+
     Dtype w = 1.f, b = 0.f;
 #ifdef FUSE_BATCH_NORM
     w = bnorm_weight[x % channels];
     b = bnorm_bias[x % channels];
 #endif
+
+#endif // LAYER_NORM
 
     vec_type src_vec = load(src, index) - (vec_type)mean_val;
     vec_type dst_vec = src_vec * alpha;
     dst_vec = dst_vec * w + (vec_type)b;
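The LAYER_NORM branch differs from FUSE_BATCH_NORM in where the weights come from: batch-norm style weights repeat per channel along the row, while layer-norm weights vary along the normalized axis itself, so the kernel loads a vector of them at the in-row offset. The same contrast in scalar form (hypothetical helpers, not the kernel's actual indexing):

```cpp
#include <cstddef>

// FUSE_BATCH_NORM: one weight/bias per channel, repeated along the row.
float apply_bn(float v, const float* w, const float* b, std::size_t x, std::size_t channels) {
    return v * w[x % channels] + b[x % channels];
}

// LAYER_NORM: weight/bias indexed by position within the normalized row.
float apply_ln(float v, const float* w, const float* b, std::size_t j) {
    return v * w[j] + b[j];
}
```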
@@ -793,81 +793,43 @@ CASE(test_isinf_positive)
 CASE(test_isnan)
     // no filter
 CASE(test_layer_normalization_2d_axis0)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_2d_axis1)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_2d_axis_negative_1)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_2d_axis_negative_2)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_3d_axis0_epsilon)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_3d_axis1_epsilon)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_3d_axis2_epsilon)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_3d_axis_negative_1_epsilon)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_3d_axis_negative_2_epsilon)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_3d_axis_negative_3_epsilon)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis0)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis1)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis2)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis3)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis_negative_1)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis_negative_2)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis_negative_3)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_4d_axis_negative_4)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_layer_normalization_default_axis)
-#if SKIP_SET_1
-    SKIP_NON_CPU;
-#endif
+    // no filter
 CASE(test_leakyrelu)
     // no filter
 CASE(test_leakyrelu_default)