mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 20:09:23 +08:00
Fix Mobilenet v2 from TensorFlow slim
This commit is contained in:
parent
9340fc0c50
commit
9cfd219d70
@ -630,6 +630,21 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class SoftMaxSlimSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
SoftMaxSlimSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Const");
|
||||
int shapeOp = addNodeToMatch("Shape", input);
|
||||
int reshape = addNodeToMatch("Reshape", input, shape);
|
||||
int softmax = addNodeToMatch("Softmax", reshape);
|
||||
addNodeToMatch("Reshape", softmax, shapeOp);
|
||||
setFusedNode("Softmax", input);
|
||||
}
|
||||
};
|
||||
|
||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
@ -646,6 +661,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
|
||||
|
||||
int numNodes = net.node_size();
|
||||
std::vector<int> matchedNodesIds;
|
||||
|
@ -661,7 +661,10 @@ void TFImporter::populateNet(Net dstNet)
|
||||
RemoveIdentityOps(netTxt);
|
||||
|
||||
if (!netTxt.ByteSize())
|
||||
{
|
||||
simplifySubgraphs(netBin);
|
||||
sortByExecutionOrder(netBin);
|
||||
}
|
||||
|
||||
std::set<String> layers_to_ignore;
|
||||
|
||||
|
@ -549,6 +549,7 @@ TEST_P(Test_TensorFlow_layers, slice)
|
||||
TEST_P(Test_TensorFlow_layers, softmax)
|
||||
{
|
||||
runTensorFlowNet("keras_softmax");
|
||||
runTensorFlowNet("slim_softmax");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, relu6)
|
||||
|
Loading…
Reference in New Issue
Block a user