Merge pull request #17288 from dkurt:dnn_tf_resize_down

Alexander Alekhin 2020-05-14 07:08:41 +00:00
commit c8689d9d0a
4 changed files with 57 additions and 9 deletions

View File

@@ -19,8 +19,8 @@ namespace cv { namespace dnn {
 class ResizeLayerImpl : public ResizeLayer
 {
 public:
-    ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<int>("zoom_factor_x", params.get<int>("zoom_factor", 0))),
-                                                 zoomFactorHeight(params.get<int>("zoom_factor_y", params.get<int>("zoom_factor", 0))),
+    ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<float>("zoom_factor_x", params.get<float>("zoom_factor", 0))),
+                                                 zoomFactorHeight(params.get<float>("zoom_factor_y", params.get<float>("zoom_factor", 0))),
                                                  scaleWidth(0), scaleHeight(0)
     {
         setParamsFrom(params);
@@ -223,7 +223,7 @@ public:
 protected:
     int outWidth, outHeight;
-    const int zoomFactorWidth, zoomFactorHeight;
+    const float zoomFactorWidth, zoomFactorHeight;
     String interpolation;
     float scaleWidth, scaleHeight;
     bool alignCorners;
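
With the zoom factors stored as float instead of int, a Resize layer can now shrink its input as well as enlarge it. Below is a minimal sketch (not part of the patch) of configuring such a layer through LayerParams; the layer name and the 0.5 factor are illustrative:

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        cv::dnn::LayerParams lp;
        lp.name = "resize_down";             // illustrative name
        lp.type = "Resize";
        lp.set("interpolation", "bilinear");
        lp.set("zoom_factor_x", 0.5f);       // fractional factors are preserved now
        lp.set("zoom_factor_y", 0.5f);       // (an int field would have truncated them)

        // Reading the value back the way ResizeLayerImpl does after this patch:
        std::cout << lp.get<float>("zoom_factor_x") << std::endl;  // 0.5
        return 0;
    }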

View File

@@ -495,8 +495,9 @@ public:
     ResizeBilinearSubgraph()
     {
         int input = addNodeToMatch("");
-        int shape = addNodeToMatch("Shape", input);
+        int shapeSource = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", shapeSource);
         int stack = addNodeToMatch("Const");
         int stack_1 = addNodeToMatch("Const");
         int stack_2 = addNodeToMatch("Const");
@@ -504,7 +505,7 @@ public:
         int factorY = addNodeToMatch("Const");
         int mul = addNodeToMatch("Mul", strided_slice, factorY);
-        shape = addNodeToMatch("Shape", input);
+        shape = addNodeToMatch("Shape", shapeSource);
         stack = addNodeToMatch("Const");
         stack_1 = addNodeToMatch("Const");
         stack_2 = addNodeToMatch("Const");
@@ -519,6 +520,51 @@ public:
     }
 };
+// In case of resizing by factor.
+class ResizeBilinearSubgraphDown : public TFSubgraph
+{
+public:
+    ResizeBilinearSubgraphDown()
+    {
+        int input = addNodeToMatch("");
+        int shapeSource = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", shapeSource);
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+        int factorY = addNodeToMatch("Const");
+        int div = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorY);
+        int cast = addNodeToMatch("Cast", div);
+
+        shape = addNodeToMatch("Shape", shapeSource);
+        stack = addNodeToMatch("Const");
+        stack_1 = addNodeToMatch("Const");
+        stack_2 = addNodeToMatch("Const");
+        strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+        int factorX = addNodeToMatch("Const");
+        int div_1 = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorX);
+        int cast_1 = addNodeToMatch("Cast", div_1);
+
+        int pack = addNodeToMatch("Pack", cast, cast_1);
+
+        addNodeToMatch("ResizeBilinear", input, pack);
+        setFusedNode("ResizeBilinear", input, factorY, factorX);
+    }
+
+    virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
+                          std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
+    {
+        for (int i = 1; i < 3; ++i)
+        {
+            tensorflow::TensorProto* factor = inputNodes[i]->mutable_attr()->at("value").mutable_tensor();
+            factor->set_double_val(0, 1.0 / factor->double_val(0));
+        }
+    }
+};
+
 // In case of resizing by factor.
 class UpsamplingKerasSubgraph : public TFSubgraph
 {
@@ -702,6 +748,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
     subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
     subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
     subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraphDown()));
     for (int i = 0; i < net.node_size(); ++i)
     {
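
The new ResizeBilinearSubgraphDown matches the pattern TensorFlow emits when the target size is computed by dividing the input shape by a constant (Shape, StridedSlice, Cast, RealDiv, Cast, Pack, ResizeBilinear), and its finalize() rewrites each divisor into its reciprocal so it can act as a zoom factor. A small sketch of the arithmetic this fusion relies on; the 480 and 2.0 values are illustrative:

    #include <cassert>

    int main()
    {
        const int inputHeight = 480;
        const double divisor = 2.0;                            // the RealDiv constant in the graph
        const float zoom = static_cast<float>(1.0 / divisor);  // what finalize() writes back

        const int tfOutput  = static_cast<int>(inputHeight / divisor);  // Cast after RealDiv
        const int dnnOutput = static_cast<int>(inputHeight * zoom);     // Resize layer with zoom_factor

        assert(tfOutput == dnnOutput);  // both are 240
        return 0;
    }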

View File

@@ -1932,10 +1932,10 @@ void TFImporter::populateNet(Net dstNet)
         {
             Mat factorHeight = getTensorContent(getConstBlob(layer, value_id, 1));
             Mat factorWidth = getTensorContent(getConstBlob(layer, value_id, 2));
-            CV_CheckTypeEQ(factorHeight.type(), CV_32SC1, ""); CV_CheckEQ(factorHeight.total(), (size_t)1, "");
-            CV_CheckTypeEQ(factorWidth.type(), CV_32SC1, ""); CV_CheckEQ(factorWidth.total(), (size_t)1, "");
-            layerParams.set("zoom_factor_x", factorWidth.at<int>(0));
-            layerParams.set("zoom_factor_y", factorHeight.at<int>(0));
+            factorHeight.convertTo(factorHeight, CV_32F);
+            factorWidth.convertTo(factorWidth, CV_32F);
+            layerParams.set("zoom_factor_x", factorWidth.at<float>(0));
+            layerParams.set("zoom_factor_y", factorHeight.at<float>(0));
         }
         else
             CV_Assert(layer.input_size() == 2 || layer.input_size() == 3);
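
The importer no longer requires CV_32SC1 factor tensors: whatever numeric type the constant has, it is converted to CV_32F and read as float, so both the integer factors of the upscaling subgraph and the fractional reciprocals written by ResizeBilinearSubgraphDown::finalize survive. A small sketch of that conversion with illustrative values:

    #include <opencv2/core.hpp>
    #include <iostream>

    int main()
    {
        cv::Mat factorInt  = (cv::Mat_<int>(1, 1) << 2);       // integer factor, as before the patch
        cv::Mat factorFrac = (cv::Mat_<double>(1, 1) << 0.5);  // reciprocal written for a downscale

        factorInt.convertTo(factorInt, CV_32F);
        factorFrac.convertTo(factorFrac, CV_32F);

        std::cout << factorInt.at<float>(0) << " "             // 2
                  << factorFrac.at<float>(0) << std::endl;     // 0.5
        return 0;
    }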

View File

@@ -969,6 +969,7 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
 {
     runTensorFlowNet("resize_bilinear");
     runTensorFlowNet("resize_bilinear_factor");
+    runTensorFlowNet("resize_bilinear_down");
 }
 
 TEST_P(Test_TensorFlow_layers, tf2_dense)
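
The added test case exercises the fused downscaling path end to end through the usual runTensorFlowNet harness. Outside the test suite, the same path is taken when loading any frozen graph that resizes by a fractional factor; a rough sketch, with a hypothetical model path and input size:

    #include <opencv2/dnn.hpp>

    int main()
    {
        // Hypothetical file: any frozen TF graph containing the Shape/RealDiv/ResizeBilinear pattern.
        cv::dnn::Net net = cv::dnn::readNetFromTensorflow("resize_bilinear_down.pb");

        cv::Mat input({1, 3, 16, 16}, CV_32F, cv::Scalar(0));  // NCHW blob, illustrative size
        net.setInput(input);
        cv::Mat out = net.forward();  // spatial dims reduced by the fused fractional factor
        return 0;
    }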