dnn: fix MVN layer (issue 14683)

This commit is contained in:
Dmitry Kurtaev 2019-06-06 19:08:45 +03:00
parent eba696a41e
commit dfdc91f8c9
4 changed files with 77 additions and 26 deletions

View File

@ -118,7 +118,7 @@ public:
{
#ifdef HAVE_INF_ENGINE
if (backendId == DNN_BACKEND_INFERENCE_ENGINE)
return !zeroDev && eps <= 1e-7f;
return !zeroDev && (preferableTarget != DNN_TARGET_MYRIAD || eps <= 1e-7f);
else
#endif // HAVE_INF_ENGINE
return backendId == DNN_BACKEND_OPENCV;
@ -347,7 +347,11 @@ public:
bias = i < shift.cols ? ((float*)shift.data)[i] : bias;
}
cv::meanStdDev(inpRow, mean, (normVariance) ? dev : noArray());
double alpha = (normVariance) ? 1/(eps + dev[0]) : 1;
double alpha = 1;
if (normVariance)
{
alpha = 1 / std::sqrt(eps + dev[0]*dev[0]);
}
double normalizationScale = 1.0;
double normalizationShift = 0.0;
if (fuse_batch_norm)

View File

@ -118,10 +118,10 @@ __kernel void MVN(__global const Dtype* src,
return;
Dtype mean_val = mean[x];
Dtype dev_val = sqrt(dev[x]);
Dtype dev_val = dev[x];
Dtype alpha;
#ifdef NORM_VARIANCE
alpha = 1 / (eps + dev_val);
alpha = 1 / sqrt(eps + dev_val);
#else
alpha = 1;
#endif
@ -275,7 +275,7 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
barrier(CLK_LOCAL_MEM_FENCE);
Dtype4 mean_val = convert_float4(mean[row_gid]);
Dtype4 dev_val = sqrt(work[0] * alpha_val) + (Dtype4)eps;
Dtype4 dev_val = sqrt(work[0] * alpha_val + (Dtype4)eps);
Dtype4 alpha = (Dtype4)1.f / dev_val;
Dtype4 w = (Dtype4)1.f;

View File

@ -70,13 +70,6 @@ public:
{
fusedNodeInputs = inputs_;
fusedNodeOp = op;
nodesToFuse.clear();
for (int i = 0; i < nodes.size(); ++i)
{
if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end() &&
nodes[i] != "Const")
nodesToFuse.push_back(i);
}
}
static int getInputNodeId(const tensorflow::GraphDef& net,
@ -99,15 +92,17 @@ public:
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds)
virtual bool match(const tensorflow::GraphDef& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds)
{
matchedNodesIds.clear();
matchedNodesIds.reserve(nodesToFuse.size());
targetNodesIds.clear();
std::queue<int> nodesToMatch;
std::queue<int> targetNodes;
nodesToMatch.push(nodeId);
targetNodes.push(nodesToFuse.back());
targetNodes.push(nodes.size() - 1);
while (!nodesToMatch.empty())
{
int nodeToMatch = nodesToMatch.front();
@ -142,13 +137,25 @@ public:
return false;
}
matchedNodesIds.push_back(nodeToMatch);
targetNodesIds.push_back(targetNodeId);
}
const int n = matchedNodesIds.size();
std::vector<std::pair<int, int> > elements(n);
for (int i = 0; i < n; ++i)
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]);
std::sort(elements.begin(), elements.end());
for (int i = 0; i < n; ++i)
{
matchedNodesIds[i] = elements[i].first;
targetNodesIds[i] = elements[i].second;
}
std::sort(matchedNodesIds.begin(), matchedNodesIds.end());
return true;
}
// Fuse matched subgraph.
void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds)
void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds,
const std::vector<int>& targetNodesIds)
{
// Extract names of input nodes.
std::vector<std::string> inputsNames(fusedNodeInputs.size());
@ -159,7 +166,7 @@ public:
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
{
const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]);
std::vector<int>& inpIndices = inputs[nodesToFuse[j]];
std::vector<int>& inpIndices = inputs[targetNodesIds[j]];
CV_Assert(node.input_size() == inpIndices.size());
for (int k = 0; k < inpIndices.size(); ++k)
@ -204,7 +211,6 @@ private:
std::vector<std::vector<int> > inputs; // Connections of an every node to it's inputs.
std::string fusedNodeOp; // Operation name of resulting fused node.
std::vector<int> nodesToFuse; // Set of nodes to be fused.
std::vector<int> fusedNodeInputs; // Inputs of fused node.
};
@ -360,9 +366,11 @@ public:
setFusedNode("Relu6", input);
}
virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE
virtual bool match(const tensorflow::GraphDef& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
{
if (!Subgraph::match(net, nodeId, matchedNodesIds))
if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
return false;
Mat maxValue = getTensorContent(net.node(matchedNodesIds.front() + 1).attr().at("value").tensor());
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
@ -394,14 +402,16 @@ public:
setFusedNode("Reshape", ids);
}
virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE
virtual bool match(const tensorflow::GraphDef& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
{
const tensorflow::NodeDef& node = net.node(nodeId);
if (node.input_size() == 0)
return false;
inpName = node.input(0);
return Subgraph::match(net, nodeId, matchedNodesIds);
return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
}
@ -693,6 +703,40 @@ public:
}
};
class KerasMVNSubgraph : public Subgraph
{
public:
KerasMVNSubgraph()
{
int input = addNodeToMatch("");
int mean = addNodeToMatch("Mean", input, addNodeToMatch("Const"));
int grad = addNodeToMatch("StopGradient", mean);
int diff = addNodeToMatch("SquaredDifference", input, grad);
int var = addNodeToMatch("Mean", diff, addNodeToMatch("Const"));
int sub = addNodeToMatch("Sub", input, mean);
int add_y = addNodeToMatch("Const");
int add = addNodeToMatch("Add", var, add_y);
int pow_y = addNodeToMatch("Const");
int powNode = addNodeToMatch("Pow", add, pow_y);
addNodeToMatch("RealDiv", sub, powNode);
setFusedNode("MVN", input, add_y);
}
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
{
tensorflow::AttrValue eps;
Mat epsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
CV_CheckEQ(epsMat.total(), (size_t)1, "");
CV_CheckTypeEQ(epsMat.type(), CV_32FC1, "");
eps.set_f(epsMat.at<float>(0));
fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("eps", eps));
fusedNode->mutable_input()->RemoveLast();
}
};
void simplifySubgraphs(tensorflow::GraphDef& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
@ -712,16 +756,17 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new KerasMVNSubgraph()));
int numNodes = net.node_size();
std::vector<int> matchedNodesIds;
std::vector<int> matchedNodesIds, targetNodesIds;
for (int i = 0; i < numNodes; ++i)
{
for (int j = 0; j < subgraphs.size(); ++j)
{
if (subgraphs[j]->match(net, i, matchedNodesIds))
if (subgraphs[j]->match(net, i, matchedNodesIds, targetNodesIds))
{
subgraphs[j]->replace(net, matchedNodesIds);
subgraphs[j]->replace(net, matchedNodesIds, targetNodesIds);
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
break;
}

View File

@ -190,6 +190,7 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
runTensorFlowNet("mvn_batch_norm");
runTensorFlowNet("mvn_batch_norm_1x1");
runTensorFlowNet("switch_identity");
runTensorFlowNet("keras_batch_norm_training");
}
TEST_P(Test_TensorFlow_layers, batch_norm3D)
@ -259,6 +260,7 @@ TEST_P(Test_TensorFlow_layers, deconvolution)
runTensorFlowNet("deconvolution_adj_pad_same");
runTensorFlowNet("keras_deconv_valid");
runTensorFlowNet("keras_deconv_same");
runTensorFlowNet("keras_deconv_same_v2");
}
TEST_P(Test_TensorFlow_layers, matmul)