Merge pull request #24409 from fengyuentau:norm_kernel

dnn: add shared fastNorm kernel for mvn, instance norm and layer norm #24409

Relates https://github.com/opencv/opencv/pull/24378#issuecomment-1756906570

TODO:

- [x] add fastNorm
- [x] refactor layer norm with fastNorm
- [x] refactor mvn with fastNorm
- [ ] add onnx mvn in importer (in a new PR?)
- [ ] refactor instance norm with fastNorm (in another PR https://github.com/opencv/opencv/pull/24378, need to merge this one first though)

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Yuantao Feng 2023-11-01 19:33:57 +08:00 committed by GitHub
parent e202116b56
commit c91af16fa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 220 additions and 166 deletions

View File

@ -1143,7 +1143,7 @@ CV__DNN_INLINE_NS_BEGIN
class CV_EXPORTS LayerNormLayer : public Layer
{
public:
bool hasBias;
CV_DEPRECATED_EXTERNAL bool hasBias; // Deprecated, preserve for compatibility
int axis;
float epsilon;

View File

@ -0,0 +1,160 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../../precomp.hpp"
#include "fast_norm.hpp"
namespace cv { namespace dnn {
void fastNorm(const Mat &input, Mat &output, float epsilon, size_t normalized_axis, bool normalize_variance) {
const auto input_shape = shape(input);
CV_CheckLT(normalized_axis, input_shape.size(), "fastNorm: axis out of range");
size_t loops = static_cast<size_t>(total(input_shape, 0, static_cast<int>(normalized_axis))),
norm_size = static_cast<size_t>(total(input_shape, static_cast<int>(normalized_axis)));
float inv_norm_size = 1.0 / norm_size;
auto fn = [&](const Range &r) {
const auto *input_data = input.ptr<const float>();
auto *output_data = output.ptr<float>();
for (int i = r.start; i < r.end; i++) {
const auto *x = input_data + norm_size * i;
auto *y = output_data + norm_size * i;
float mean = 0.f, mean_square = 0.f;
for (int j = 0; j < norm_size; j++) {
float v = x[j];
mean += v;
mean_square += v * v;
}
mean *= inv_norm_size;
mean_square = std::sqrt(std::max(0.f, mean_square * inv_norm_size - mean * mean) + epsilon);
float inv_stdev = normalize_variance ? 1.f / mean_square : 1.f;
for (size_t j = 0; j < norm_size; j++) {
y[j] = (x[j] - mean) * inv_stdev;
}
}
};
double nstripes = loops * norm_size * (1 / 1024.0);
parallel_for_(Range(0, loops), fn, nstripes);
}
void fastNorm(const Mat &input, const Mat &scale, Mat &output, float epsilon, size_t normalized_axis) {
const auto input_shape = shape(input);
CV_CheckLT(normalized_axis, input_shape.size(), "fastNorm: axis out of range");
size_t loops = static_cast<size_t>(total(input_shape, 0, static_cast<int>(normalized_axis))),
norm_size = static_cast<size_t>(total(input_shape, static_cast<int>(normalized_axis)));
float inv_norm_size = 1.0 / norm_size;
auto fn = [&](const Range &r) {
const auto *input_data = input.ptr<const float>();
const auto *scale_data = scale.ptr<const float>();
auto *output_data = output.ptr<float>();
for (int i = r.start; i < r.end; i++) {
const auto *x = input_data + norm_size * i;
auto *y = output_data + norm_size * i;
float mean = 0.f, mean_square = 0.f;
for (int j = 0; j < norm_size; j++) {
float v = x[j];
mean += v;
mean_square += v * v;
}
mean *= inv_norm_size;
mean_square = std::sqrt(std::max(0.f, mean_square * inv_norm_size - mean * mean) + epsilon);
float inv_stdev = 1.f / mean_square;
for (size_t j = 0; j < norm_size; j++) {
y[j] = scale_data[j] * (x[j] - mean) * inv_stdev;
}
}
};
double nstripes = loops * norm_size * (1 / 1024.0);
parallel_for_(Range(0, loops), fn, nstripes);
}
void fastNorm(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon, size_t normalized_axis) {
const auto input_shape = shape(input);
CV_CheckLT(normalized_axis, input_shape.size(), "fastNorm: axis out of range");
CV_CheckEQ(scale.total(), bias.total(), "fastNorm: scale and bias should have the same shape");
size_t loops = static_cast<size_t>(total(input_shape, 0, static_cast<int>(normalized_axis))),
norm_size = static_cast<size_t>(total(input_shape, static_cast<int>(normalized_axis)));
float inv_norm_size = 1.0 / norm_size;
auto fn = [&](const Range &r) {
const auto *input_data = input.ptr<const float>();
const auto *scale_data = scale.ptr<const float>();
const auto *bias_data = bias.ptr<const float>();
auto *output_data = output.ptr<float>();
for (int i = r.start; i < r.end; i++) {
const auto *x = input_data + norm_size * i;
auto *y = output_data + norm_size * i;
float mean = 0.f, mean_square = 0.f;
for (int j = 0; j < norm_size; j++) {
float v = x[j];
mean += v;
mean_square += v * v;
}
mean *= inv_norm_size;
mean_square = std::sqrt(std::max(0.f, mean_square * inv_norm_size - mean * mean) + epsilon);
float inv_stdev = 1.f / mean_square;
for (size_t j = 0; j < norm_size; j++) {
y[j] = scale_data[j] * (x[j] - mean) * inv_stdev + bias_data[j];
}
}
};
double nstripes = loops * norm_size * (1 / 1024.0);
parallel_for_(Range(0, loops), fn, nstripes);
}
void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon) {
const auto input_shape = shape(input);
CV_CheckEQ(scale.total(), bias.total(), "fastNormChannel: scale and bias should have the same shape");
CV_CheckGE(input.dims, 3, "fastNormChannel: input dimension >= 3");
size_t N = input_shape[0], C = input_shape[1];
size_t loops = N * C,
norm_size = static_cast<size_t>(total(input_shape, 2));
float inv_norm_size = 1.0 / norm_size;
auto fn = [&](const Range &r) {
const auto *input_data = input.ptr<const float>();
const auto *scale_data = scale.ptr<const float>();
const auto *bias_data = bias.ptr<const float>();
auto *output_data = output.ptr<float>();
for (int i = r.start; i < r.end; i++) {
const auto *x = input_data + norm_size * i;
auto *y = output_data + norm_size * i;
float mean = 0.f, mean_square = 0.f;
for (int j = 0; j < norm_size; j++) {
float v = x[j];
mean += v;
mean_square += v * v;
}
mean *= inv_norm_size;
mean_square = std::sqrt(std::max(0.f, mean_square * inv_norm_size - mean * mean) + epsilon);
float inv_stdev = 1.f / mean_square;
size_t c = i % C;
float s = scale_data[c], b = bias_data[c];
for (size_t j = 0; j < norm_size; j++) {
y[j] = s * (x[j] - mean) * inv_stdev + b;
}
}
};
double nstripes = loops * norm_size * (1 / 1024.0);
parallel_for_(Range(0, loops), fn, nstripes);
}
}} // cv::dnn

View File

@ -0,0 +1,26 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_FAST_NORM_HPP
#define OPENCV_DNN_FAST_NORM_HPP
#include <opencv2/dnn/shape_utils.hpp>
namespace cv { namespace dnn {
// Normalization speedup by multi-threading, mainly for Caffe MVN layer which has normalize_variance parameter.
void fastNorm(const Mat &input, Mat &output, float epsilon, size_t normalized_axis = 0, bool normalize_variance = true);
// Normalization speedup by multi-threading with absent bias. Mainly for LayerNormalization.
void fastNorm(const Mat &input, const Mat &scale, Mat &output, float epsilon, size_t normalized_axis = 0);
// Normalization speedup by multi-threading with scale and bias. Mainly for LayerNormalization.
void fastNorm(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon, size_t normalized_axis = 0);
// Channel-wise Normalization speedup by multi-threading. Scale and bias should have the same shape (C). Input should have dimension >= 3.
void fastNormChannel(const Mat &input, const Mat &scale, const Mat &bias, Mat &output, float epsilon);
}} // cv::dnn
#endif // OPENCV_DNN_FAST_NORM_HPP

View File

@ -4,6 +4,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "cpu_kernels/fast_norm.hpp"
namespace cv { namespace dnn {
@ -15,11 +16,8 @@ public:
setParamsFrom(params);
// standard attr
axis = params.get<int>("axis", 0);
axis = params.get<int>("axis", -1);
epsilon = params.get<float>("epsilon", 1e-5);
// opencv attr
hasBias = params.get<bool>("hasBias", false);
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
@ -46,104 +44,25 @@ public:
CV_CheckEQ(x_ndims - axis, w_ndims, "LayerNorm: shape of weight does not match with given axis and shape of input");
for (int i = 0; i < w_ndims; ++i)
CV_CheckEQ(x_shape[axis+i], w_shape[i], "LayerNorm: weight dimensions does not match with input dimensions");
if (hasBias)
if (inputs.size() == static_cast<int>(3))
{
CV_CheckEQ(inputs.size(), (size_t)3, "");
auto b_shape = inputs[2];
CV_CheckEQ(w_shape.size(), b_shape.size(), "LayerNorm: shape of weight does not match with shape of bias");
for (size_t i = 0; i < w_shape.size(); ++i)
CV_CheckEQ(w_shape[i], b_shape[i], "LayerNorm: bias dimensions does not match with weight dimensions");
}
// only one output is needed; Mean & InvStdDev are not needed
// in inference and should beomitted in onnx importer
outputs.assign(1, inputs[0]);
return false;
}
template<bool hasBias>
class LayerNormInvoker : public ParallelLoopBody
{
public:
const Mat& src;
const float* scaleData;
const float* biasData;
Mat& dst;
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
std::vector<Mat> inputs;
inputs_arr.getMatVector(inputs);
float epsilon;
int total;
int normSize;
float invNormSize;
LayerNormInvoker(const Mat& src_, const Mat& scale, const Mat* b, Mat& dst_, int axis, float epsilon_)
: src(src_), scaleData(scale.ptr<float>()), biasData(nullptr), dst(dst_), epsilon(epsilon_)
{
if (hasBias)
{
CV_Assert(b != nullptr);
CV_Assert(b->isContinuous());
biasData = (const float*)b->ptr<float>();
}
auto dstShape = shape(dst);
total = std::accumulate(dstShape.begin(), dstShape.begin() + axis, 1, std::multiplies<int>());
normSize = std::accumulate(dstShape.begin() + axis, dstShape.end(), 1, std::multiplies<int>());
invNormSize = 1.0f / normSize;
}
static void run(const Mat& src, const Mat& scale, const Mat* b, Mat& dst, int axis, float epsilon)
{
CV_Assert(src.isContinuous());
CV_Assert(dst.isContinuous());
CV_CheckTypeEQ(src.type(), CV_32F, "DNN/LayerNorm: only support float32");
CV_CheckTypeEQ(src.type(), dst.type(), "");
CV_Assert(scale.isContinuous());
CV_CheckGE(epsilon, 0.0f, "");
LayerNormInvoker p(src, scale, b, dst, axis, epsilon);
double nstripes = ((size_t)p.total * p.normSize) * (1 / 1024.0);
// double nstripes = ((size_t)p.total) * (1 / 1024.0);
parallel_for_(Range(0, p.total), p, nstripes);
}
void operator()(const Range& r) const CV_OVERRIDE
{
int stripeStart = r.start;
int stripeEnd = r.end;
const float* srcData = src.ptr<float>();
float* dstData = dst.ptr<float>();
for (int ofs = stripeStart; ofs < stripeEnd; ++ofs)
{
const float* first = srcData + ofs * normSize;
float* dstFirst = dstData + ofs * normSize;
float mean = 0;
float meanSquare = 0;
for (int h = 0; h < normSize; ++h)
{
float v = first[h];
mean += v;
meanSquare += v * v;
}
mean *= invNormSize;
meanSquare = std::sqrt(std::max(0.f, meanSquare * invNormSize - mean * mean) + epsilon);
float invMeanSquare = 1.0f / meanSquare;
for (int h = 0; h < normSize; ++h)
{
float v = (first[h] - mean) * invMeanSquare * scaleData[h];
if (hasBias) {
v = v + biasData[h];
}
dstFirst[h] = v;
}
}
}
};
const auto input_shape = shape(inputs[0]);
axis = normalize_axis(axis, static_cast<int>(input_shape.size()));
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
@ -160,10 +79,15 @@ public:
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
if (hasBias) {
LayerNormInvoker<true>::run(inputs[0], inputs[1], &inputs[2], outputs[0], axis, epsilon);
const auto &input = inputs[0];
const auto &scale = inputs[1];
auto &output = outputs[0];
if (inputs.size() == 3) {
const auto &bias = inputs[2];
fastNorm(input, scale, bias, output, epsilon, static_cast<size_t>(axis));
} else {
LayerNormInvoker<false>::run(inputs[0], inputs[1], nullptr, outputs[0], axis, epsilon);
fastNorm(input, scale, output, epsilon, static_cast<size_t>(axis));
}
}
};

View File

@ -46,6 +46,8 @@
#include "../ie_ngraph.hpp"
#include "../op_cuda.hpp"
#include "./cpu_kernels/fast_norm.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#ifdef HAVE_OPENCL
@ -69,9 +71,12 @@ public:
MVNLayerImpl(const LayerParams& params)
{
setParamsFrom(params);
// Caffe params
normVariance = params.get<bool>("normalize_variance", true);
acrossChannels = params.get<bool>("across_channels", false);
eps = params.get<double>("eps", 1e-9);
fuse_batch_norm = false;
fuse_relu = false;
relu_slope = 0.f;
@ -310,73 +315,18 @@ public:
return;
}
std::vector<Mat> inputs, outputs, internals;
inputs_arr.getMatVector(inputs);
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs); // assume only one input
outputs_arr.getMatVector(outputs);
internals_arr.getMatVector(internals);
for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
{
Mat &inpBlob = inputs[inpIdx];
Mat &outBlob = outputs[inpIdx];
const auto &input = inputs[0];
int splitDim = (acrossChannels) ? 1 : 2;
int i, newRows = 1;
for( i = 0; i < splitDim; i++ )
newRows *= inpBlob.size[i];
Mat inpMat = inpBlob.reshape(1, newRows);
Mat outMat = outBlob.reshape(1, newRows);
if ( inpBlob.total() == newRows )
{
// MVN is applied to single values at an every row.
if (shift.empty())
{
outBlob.setTo(0);
}
else
{
for ( i = 0; i < newRows; i++ )
{
outMat.row(i).setTo(((float*)shift.data)[i]);
}
}
return;
}
Scalar mean, dev;
for ( i = 0; i < newRows; i++)
{
Mat inpRow = inpMat.row(i);
Mat outRow = outMat.row(i);
float weight = 1.f;
float bias = 0.f;
if (fuse_batch_norm)
{
weight = i < scale.cols ? ((float*)scale.data)[i] : weight;
bias = i < shift.cols ? ((float*)shift.data)[i] : bias;
}
cv::meanStdDev(inpRow, mean, (normVariance) ? dev : noArray());
double alpha = 1;
if (normVariance)
{
alpha = 1 / std::sqrt(eps + dev[0]*dev[0]);
}
double normalizationScale = 1.0;
double normalizationShift = 0.0;
if (fuse_batch_norm)
{
normalizationScale = alpha * weight;
normalizationShift = -mean[0] * normalizationScale + bias;
}
else
{
normalizationScale = alpha;
normalizationShift = -mean[0] * alpha;
}
inpRow.convertTo(outRow, outRow.type(), normalizationScale, normalizationShift);
}
if (fuse_batch_norm) { // channel-wise scale/bias of shape (C)
CV_CheckTrue(normVariance, "DNN/MVN: not supported");
fastNormChannel(input, scale, shift, outputs[0], eps);
} else {
size_t axis = acrossChannels ? 1 : 2;
fastNorm(input, outputs[0], eps, axis, normVariance);
}
}

View File

@ -3160,12 +3160,6 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
axis = (axis + inputDims) % inputDims;
layerParams.set("axis", axis);
// check if bias existed
bool hasBias = false;
if (node_proto.input_size() > 2)
hasBias = true;
layerParams.set("hasBias", hasBias);
// constants as constant inputs
for (size_t i = 1; i < node_proto.input_size(); i++)
{

View File

@ -570,10 +570,10 @@ TEST_P(Test_Torch_nets, FastNeuralStyle_accuracy)
}
else if (target == DNN_TARGET_CPU_FP16)
{
normAssert(out, refBlob, "", 0.62, 25);
normAssert(out, refBlob, "", 0.64, 25);
}
else
normAssert(out, refBlob, "", 0.5, 1.11);
normAssert(out, refBlob, "", 0.5, 1.16);
}
}