Merge pull request #20283 from SamFC10:fix-batchnorm

This commit is contained in:
Alexander Alekhin 2021-06-21 11:27:12 +00:00
commit 9d584475f6

View File

@ -29,6 +29,7 @@ namespace dnn
class BatchNormLayerImpl CV_FINAL : public BatchNormLayer class BatchNormLayerImpl CV_FINAL : public BatchNormLayer
{ {
public: public:
Mat origin_weights, origin_bias;
Mat weights_, bias_; Mat weights_, bias_;
UMat umat_weight, umat_bias; UMat umat_weight, umat_bias;
mutable int dims; mutable int dims;
@ -82,11 +83,11 @@ public:
const float* weightsData = hasWeights ? blobs[weightsBlobIndex].ptr<float>() : 0; const float* weightsData = hasWeights ? blobs[weightsBlobIndex].ptr<float>() : 0;
const float* biasData = hasBias ? blobs[biasBlobIndex].ptr<float>() : 0; const float* biasData = hasBias ? blobs[biasBlobIndex].ptr<float>() : 0;
weights_.create(1, (int)n, CV_32F); origin_weights.create(1, (int)n, CV_32F);
bias_.create(1, (int)n, CV_32F); origin_bias.create(1, (int)n, CV_32F);
float* dstWeightsData = weights_.ptr<float>(); float* dstWeightsData = origin_weights.ptr<float>();
float* dstBiasData = bias_.ptr<float>(); float* dstBiasData = origin_bias.ptr<float>();
for (size_t i = 0; i < n; ++i) for (size_t i = 0; i < n; ++i)
{ {
@ -94,15 +95,12 @@ public:
dstWeightsData[i] = w; dstWeightsData[i] = w;
dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale; dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale;
} }
// We will use blobs to store origin weights and bias to restore them in case of reinitialization.
weights_.copyTo(blobs[0].reshape(1, 1));
bias_.copyTo(blobs[1].reshape(1, 1));
} }
virtual void finalize(InputArrayOfArrays, OutputArrayOfArrays) CV_OVERRIDE virtual void finalize(InputArrayOfArrays, OutputArrayOfArrays) CV_OVERRIDE
{ {
blobs[0].reshape(1, 1).copyTo(weights_); origin_weights.reshape(1, 1).copyTo(weights_);
blobs[1].reshape(1, 1).copyTo(bias_); origin_bias.reshape(1, 1).copyTo(bias_);
} }
void getScaleShift(Mat& scale, Mat& shift) const CV_OVERRIDE void getScaleShift(Mat& scale, Mat& shift) const CV_OVERRIDE