mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 20:09:23 +08:00
Merge pull request #11738 from dkurt:dnn_batch_norm_fusion_base
This commit is contained in:
commit
8c4e0dfd13
@ -96,6 +96,46 @@ public:
|
|||||||
shift = bias_;
|
shift = bias_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
Mat w, b;
|
||||||
|
top->getScaleShift(w, b);
|
||||||
|
if (w.empty() && b.empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
const int numChannels = weights_.total();
|
||||||
|
const int numFusedWeights = w.total();
|
||||||
|
const int numFusedBias = b.total();
|
||||||
|
|
||||||
|
if ((numFusedWeights != numChannels && numFusedWeights != 1 && !w.empty()) ||
|
||||||
|
(numFusedBias != numChannels && numFusedBias != 1 && !b.empty()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!w.empty())
|
||||||
|
{
|
||||||
|
w = w.reshape(1, 1);
|
||||||
|
if (numFusedWeights == 1)
|
||||||
|
{
|
||||||
|
multiply(weights_, w.at<float>(0), weights_);
|
||||||
|
multiply(bias_, w.at<float>(0), bias_);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
multiply(weights_, w, weights_);
|
||||||
|
multiply(bias_, w, bias_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!b.empty())
|
||||||
|
{
|
||||||
|
b = b.reshape(1, 1);
|
||||||
|
if (numFusedBias == 1)
|
||||||
|
add(bias_, b.at<float>(0), bias_);
|
||||||
|
else
|
||||||
|
add(bias_, b.reshape(1, 1), bias_);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||||
const int requiredOutputs,
|
const int requiredOutputs,
|
||||||
std::vector<MatShape> &outputs,
|
std::vector<MatShape> &outputs,
|
||||||
|
Loading…
Reference in New Issue
Block a user