mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Fix Identity Switch from Keras
This commit is contained in:
parent
e07ffe902e
commit
081d9bc73f
@ -601,7 +601,7 @@ public:
|
||||
class UpsamplingKerasSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
UpsamplingKerasSubgraph()
|
||||
UpsamplingKerasSubgraph(const std::string& type)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
@ -611,8 +611,8 @@ public:
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
int factors = addNodeToMatch("Const");
|
||||
int mul = addNodeToMatch("Mul", strided_slice, factors);
|
||||
addNodeToMatch("ResizeNearestNeighbor", input, mul);
|
||||
setFusedNode("ResizeNearestNeighbor", input, factors);
|
||||
addNodeToMatch(type, input, mul);
|
||||
setFusedNode(type, input, factors);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
|
||||
@ -707,7 +707,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeNearestNeighbor")));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeBilinear")));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
|
||||
@ -752,6 +753,8 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
|
||||
tensorflow::NodeDef* layer = net.mutable_node(li);
|
||||
for (int input_id = 0; input_id < layer->input_size(); input_id++) {
|
||||
String input_op_name = layer->input(input_id);
|
||||
input_op_name = input_op_name.substr(input_op_name.find('^') + 1,
|
||||
input_op_name.rfind(':'));
|
||||
IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
|
||||
|
||||
if (it != identity_ops.end()) {
|
||||
|
@ -186,6 +186,7 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
|
||||
runTensorFlowNet("unfused_batch_norm_no_gamma");
|
||||
runTensorFlowNet("mvn_batch_norm");
|
||||
runTensorFlowNet("mvn_batch_norm_1x1");
|
||||
runTensorFlowNet("switch_identity");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, batch_norm3D)
|
||||
|
Loading…
Reference in New Issue
Block a user