mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 04:52:53 +08:00
Fix Mobilenet v2 from TensorFlow slim
This commit is contained in:
parent
9340fc0c50
commit
9cfd219d70
@ -630,6 +630,21 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SoftMaxSlimSubgraph : public Subgraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SoftMaxSlimSubgraph()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int shape = addNodeToMatch("Const");
|
||||||
|
int shapeOp = addNodeToMatch("Shape", input);
|
||||||
|
int reshape = addNodeToMatch("Reshape", input, shape);
|
||||||
|
int softmax = addNodeToMatch("Softmax", reshape);
|
||||||
|
addNodeToMatch("Reshape", softmax, shapeOp);
|
||||||
|
setFusedNode("Softmax", input);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||||
{
|
{
|
||||||
std::vector<Ptr<Subgraph> > subgraphs;
|
std::vector<Ptr<Subgraph> > subgraphs;
|
||||||
@ -646,6 +661,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
|||||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
|
||||||
|
|
||||||
int numNodes = net.node_size();
|
int numNodes = net.node_size();
|
||||||
std::vector<int> matchedNodesIds;
|
std::vector<int> matchedNodesIds;
|
||||||
|
@ -661,7 +661,10 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
RemoveIdentityOps(netTxt);
|
RemoveIdentityOps(netTxt);
|
||||||
|
|
||||||
if (!netTxt.ByteSize())
|
if (!netTxt.ByteSize())
|
||||||
|
{
|
||||||
simplifySubgraphs(netBin);
|
simplifySubgraphs(netBin);
|
||||||
|
sortByExecutionOrder(netBin);
|
||||||
|
}
|
||||||
|
|
||||||
std::set<String> layers_to_ignore;
|
std::set<String> layers_to_ignore;
|
||||||
|
|
||||||
|
@ -549,6 +549,7 @@ TEST_P(Test_TensorFlow_layers, slice)
|
|||||||
TEST_P(Test_TensorFlow_layers, softmax)
|
TEST_P(Test_TensorFlow_layers, softmax)
|
||||||
{
|
{
|
||||||
runTensorFlowNet("keras_softmax");
|
runTensorFlowNet("keras_softmax");
|
||||||
|
runTensorFlowNet("slim_softmax");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, relu6)
|
TEST_P(Test_TensorFlow_layers, relu6)
|
||||||
|
Loading…
Reference in New Issue
Block a user