mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #11890 from dkurt:keras_resize_nearest
This commit is contained in:
commit
ccd2370bb7
@ -571,6 +571,50 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// In case of resizing by factor.
|
||||||
|
class UpsamplingKerasSubgraph : public Subgraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
UpsamplingKerasSubgraph()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int shape = addNodeToMatch("Shape", input);
|
||||||
|
int stack = addNodeToMatch("Const");
|
||||||
|
int stack_1 = addNodeToMatch("Const");
|
||||||
|
int stack_2 = addNodeToMatch("Const");
|
||||||
|
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||||
|
int factors = addNodeToMatch("Const");
|
||||||
|
int mul = addNodeToMatch("Mul", strided_slice, factors);
|
||||||
|
addNodeToMatch("ResizeNearestNeighbor", input, mul);
|
||||||
|
setFusedNode("ResizeNearestNeighbor", input, factors);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
|
||||||
|
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
|
||||||
|
CV_Assert(factorsMat.total() == 2, factorsMat.type() == CV_32SC1);
|
||||||
|
|
||||||
|
// Height scale factor
|
||||||
|
tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
|
||||||
|
factorY->clear_int_val();
|
||||||
|
factorY->clear_tensor_content();
|
||||||
|
factorY->add_int_val(factorsMat.at<int>(0, 0));
|
||||||
|
|
||||||
|
// Width scale factor.
|
||||||
|
tensorflow::NodeDef* factorXNode = net.add_node();
|
||||||
|
factorXNode->set_op("Const");
|
||||||
|
factorXNode->set_name(fusedNode->name() + "/factor_y");
|
||||||
|
|
||||||
|
tensorflow::AttrValue factorX;
|
||||||
|
factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32);
|
||||||
|
factorX.mutable_tensor()->add_int_val(factorsMat.at<int>(0, 1));
|
||||||
|
factorXNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", factorX));
|
||||||
|
|
||||||
|
fusedNode->add_input(factorXNode->name());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||||
{
|
{
|
||||||
std::vector<Ptr<Subgraph> > subgraphs;
|
std::vector<Ptr<Subgraph> > subgraphs;
|
||||||
@ -585,6 +629,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
|||||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
|
||||||
|
|
||||||
int numNodes = net.node_size();
|
int numNodes = net.node_size();
|
||||||
std::vector<int> matchedNodesIds;
|
std::vector<int> matchedNodesIds;
|
||||||
|
@ -403,6 +403,7 @@ TEST(Test_TensorFlow, split)
|
|||||||
TEST(Test_TensorFlow, resize_nearest_neighbor)
|
TEST(Test_TensorFlow, resize_nearest_neighbor)
|
||||||
{
|
{
|
||||||
runTensorFlowNet("resize_nearest_neighbor");
|
runTensorFlowNet("resize_nearest_neighbor");
|
||||||
|
runTensorFlowNet("keras_upsampling2d");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Test_TensorFlow, slice)
|
TEST(Test_TensorFlow, slice)
|
||||||
|
Loading…
Reference in New Issue
Block a user