mirror of
https://github.com/opencv/opencv.git
synced 2024-11-29 05:29:54 +08:00
Merge pull request #17288 from dkurt:dnn_tf_resize_down
This commit is contained in:
commit
c8689d9d0a
@ -19,8 +19,8 @@ namespace cv { namespace dnn {
|
|||||||
class ResizeLayerImpl : public ResizeLayer
|
class ResizeLayerImpl : public ResizeLayer
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<int>("zoom_factor_x", params.get<int>("zoom_factor", 0))),
|
ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<float>("zoom_factor_x", params.get<float>("zoom_factor", 0))),
|
||||||
zoomFactorHeight(params.get<int>("zoom_factor_y", params.get<int>("zoom_factor", 0))),
|
zoomFactorHeight(params.get<float>("zoom_factor_y", params.get<float>("zoom_factor", 0))),
|
||||||
scaleWidth(0), scaleHeight(0)
|
scaleWidth(0), scaleHeight(0)
|
||||||
{
|
{
|
||||||
setParamsFrom(params);
|
setParamsFrom(params);
|
||||||
@ -223,7 +223,7 @@ public:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
int outWidth, outHeight;
|
int outWidth, outHeight;
|
||||||
const int zoomFactorWidth, zoomFactorHeight;
|
const float zoomFactorWidth, zoomFactorHeight;
|
||||||
String interpolation;
|
String interpolation;
|
||||||
float scaleWidth, scaleHeight;
|
float scaleWidth, scaleHeight;
|
||||||
bool alignCorners;
|
bool alignCorners;
|
||||||
|
@ -495,8 +495,9 @@ public:
|
|||||||
ResizeBilinearSubgraph()
|
ResizeBilinearSubgraph()
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
|
int shapeSource = addNodeToMatch("");
|
||||||
|
|
||||||
int shape = addNodeToMatch("Shape", input);
|
int shape = addNodeToMatch("Shape", shapeSource);
|
||||||
int stack = addNodeToMatch("Const");
|
int stack = addNodeToMatch("Const");
|
||||||
int stack_1 = addNodeToMatch("Const");
|
int stack_1 = addNodeToMatch("Const");
|
||||||
int stack_2 = addNodeToMatch("Const");
|
int stack_2 = addNodeToMatch("Const");
|
||||||
@ -504,7 +505,7 @@ public:
|
|||||||
int factorY = addNodeToMatch("Const");
|
int factorY = addNodeToMatch("Const");
|
||||||
int mul = addNodeToMatch("Mul", strided_slice, factorY);
|
int mul = addNodeToMatch("Mul", strided_slice, factorY);
|
||||||
|
|
||||||
shape = addNodeToMatch("Shape", input);
|
shape = addNodeToMatch("Shape", shapeSource);
|
||||||
stack = addNodeToMatch("Const");
|
stack = addNodeToMatch("Const");
|
||||||
stack_1 = addNodeToMatch("Const");
|
stack_1 = addNodeToMatch("Const");
|
||||||
stack_2 = addNodeToMatch("Const");
|
stack_2 = addNodeToMatch("Const");
|
||||||
@ -519,6 +520,51 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// In case of resizing by factor.
|
||||||
|
class ResizeBilinearSubgraphDown : public TFSubgraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ResizeBilinearSubgraphDown()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int shapeSource = addNodeToMatch("");
|
||||||
|
|
||||||
|
int shape = addNodeToMatch("Shape", shapeSource);
|
||||||
|
int stack = addNodeToMatch("Const");
|
||||||
|
int stack_1 = addNodeToMatch("Const");
|
||||||
|
int stack_2 = addNodeToMatch("Const");
|
||||||
|
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||||
|
int factorY = addNodeToMatch("Const");
|
||||||
|
int div = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorY);
|
||||||
|
int cast = addNodeToMatch("Cast", div);
|
||||||
|
|
||||||
|
shape = addNodeToMatch("Shape", shapeSource);
|
||||||
|
stack = addNodeToMatch("Const");
|
||||||
|
stack_1 = addNodeToMatch("Const");
|
||||||
|
stack_2 = addNodeToMatch("Const");
|
||||||
|
strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||||
|
int factorX = addNodeToMatch("Const");
|
||||||
|
int div_1 = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorX);
|
||||||
|
int cast_1 = addNodeToMatch("Cast", div_1);
|
||||||
|
|
||||||
|
int pack = addNodeToMatch("Pack", cast, cast_1);
|
||||||
|
|
||||||
|
addNodeToMatch("ResizeBilinear", input, pack);
|
||||||
|
setFusedNode("ResizeBilinear", input, factorY, factorX);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
|
||||||
|
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
|
||||||
|
for (int i = 1; i < 3; ++i)
|
||||||
|
{
|
||||||
|
tensorflow::TensorProto* factor = inputNodes[i]->mutable_attr()->at("value").mutable_tensor();
|
||||||
|
factor->set_double_val(0, 1.0 / factor->double_val(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// In case of resizing by factor.
|
// In case of resizing by factor.
|
||||||
class UpsamplingKerasSubgraph : public TFSubgraph
|
class UpsamplingKerasSubgraph : public TFSubgraph
|
||||||
{
|
{
|
||||||
@ -702,6 +748,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
|||||||
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraphDown()));
|
||||||
|
|
||||||
for (int i = 0; i < net.node_size(); ++i)
|
for (int i = 0; i < net.node_size(); ++i)
|
||||||
{
|
{
|
||||||
|
@ -1932,10 +1932,10 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
{
|
{
|
||||||
Mat factorHeight = getTensorContent(getConstBlob(layer, value_id, 1));
|
Mat factorHeight = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||||
Mat factorWidth = getTensorContent(getConstBlob(layer, value_id, 2));
|
Mat factorWidth = getTensorContent(getConstBlob(layer, value_id, 2));
|
||||||
CV_CheckTypeEQ(factorHeight.type(), CV_32SC1, ""); CV_CheckEQ(factorHeight.total(), (size_t)1, "");
|
factorHeight.convertTo(factorHeight, CV_32F);
|
||||||
CV_CheckTypeEQ(factorWidth.type(), CV_32SC1, ""); CV_CheckEQ(factorWidth.total(), (size_t)1, "");
|
factorWidth.convertTo(factorWidth, CV_32F);
|
||||||
layerParams.set("zoom_factor_x", factorWidth.at<int>(0));
|
layerParams.set("zoom_factor_x", factorWidth.at<float>(0));
|
||||||
layerParams.set("zoom_factor_y", factorHeight.at<int>(0));
|
layerParams.set("zoom_factor_y", factorHeight.at<float>(0));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
CV_Assert(layer.input_size() == 2 || layer.input_size() == 3);
|
CV_Assert(layer.input_size() == 2 || layer.input_size() == 3);
|
||||||
|
@ -969,6 +969,7 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
|
|||||||
{
|
{
|
||||||
runTensorFlowNet("resize_bilinear");
|
runTensorFlowNet("resize_bilinear");
|
||||||
runTensorFlowNet("resize_bilinear_factor");
|
runTensorFlowNet("resize_bilinear_factor");
|
||||||
|
runTensorFlowNet("resize_bilinear_down");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, tf2_dense)
|
TEST_P(Test_TensorFlow_layers, tf2_dense)
|
||||||
|
Loading…
Reference in New Issue
Block a user