mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #11890 from dkurt:keras_resize_nearest
This commit is contained in:
commit
ccd2370bb7
@ -571,6 +571,50 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// In case of resizing by factor.
|
||||
class UpsamplingKerasSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
UpsamplingKerasSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
int factors = addNodeToMatch("Const");
|
||||
int mul = addNodeToMatch("Mul", strided_slice, factors);
|
||||
addNodeToMatch("ResizeNearestNeighbor", input, mul);
|
||||
setFusedNode("ResizeNearestNeighbor", input, factors);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
|
||||
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||
{
|
||||
Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
|
||||
CV_Assert(factorsMat.total() == 2, factorsMat.type() == CV_32SC1);
|
||||
|
||||
// Height scale factor
|
||||
tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
|
||||
factorY->clear_int_val();
|
||||
factorY->clear_tensor_content();
|
||||
factorY->add_int_val(factorsMat.at<int>(0, 0));
|
||||
|
||||
// Width scale factor.
|
||||
tensorflow::NodeDef* factorXNode = net.add_node();
|
||||
factorXNode->set_op("Const");
|
||||
factorXNode->set_name(fusedNode->name() + "/factor_y");
|
||||
|
||||
tensorflow::AttrValue factorX;
|
||||
factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32);
|
||||
factorX.mutable_tensor()->add_int_val(factorsMat.at<int>(0, 1));
|
||||
factorXNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", factorX));
|
||||
|
||||
fusedNode->add_input(factorXNode->name());
|
||||
}
|
||||
};
|
||||
|
||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
@ -585,6 +629,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
|
||||
|
||||
int numNodes = net.node_size();
|
||||
std::vector<int> matchedNodesIds;
|
||||
|
@ -403,6 +403,7 @@ TEST(Test_TensorFlow, split)
|
||||
TEST(Test_TensorFlow, resize_nearest_neighbor)
|
||||
{
|
||||
runTensorFlowNet("resize_nearest_neighbor");
|
||||
runTensorFlowNet("keras_upsampling2d");
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, slice)
|
||||
|
Loading…
Reference in New Issue
Block a user