mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 11:40:44 +08:00
Merge pull request #17288 from dkurt:dnn_tf_resize_down
This commit is contained in:
commit
c8689d9d0a
@ -19,8 +19,8 @@ namespace cv { namespace dnn {
|
||||
class ResizeLayerImpl : public ResizeLayer
|
||||
{
|
||||
public:
|
||||
ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<int>("zoom_factor_x", params.get<int>("zoom_factor", 0))),
|
||||
zoomFactorHeight(params.get<int>("zoom_factor_y", params.get<int>("zoom_factor", 0))),
|
||||
ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<float>("zoom_factor_x", params.get<float>("zoom_factor", 0))),
|
||||
zoomFactorHeight(params.get<float>("zoom_factor_y", params.get<float>("zoom_factor", 0))),
|
||||
scaleWidth(0), scaleHeight(0)
|
||||
{
|
||||
setParamsFrom(params);
|
||||
@ -223,7 +223,7 @@ public:
|
||||
|
||||
protected:
|
||||
int outWidth, outHeight;
|
||||
const int zoomFactorWidth, zoomFactorHeight;
|
||||
const float zoomFactorWidth, zoomFactorHeight;
|
||||
String interpolation;
|
||||
float scaleWidth, scaleHeight;
|
||||
bool alignCorners;
|
||||
|
@ -495,8 +495,9 @@ public:
|
||||
ResizeBilinearSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shapeSource = addNodeToMatch("");
|
||||
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int shape = addNodeToMatch("Shape", shapeSource);
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
@ -504,7 +505,7 @@ public:
|
||||
int factorY = addNodeToMatch("Const");
|
||||
int mul = addNodeToMatch("Mul", strided_slice, factorY);
|
||||
|
||||
shape = addNodeToMatch("Shape", input);
|
||||
shape = addNodeToMatch("Shape", shapeSource);
|
||||
stack = addNodeToMatch("Const");
|
||||
stack_1 = addNodeToMatch("Const");
|
||||
stack_2 = addNodeToMatch("Const");
|
||||
@ -519,6 +520,51 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// In case of resizing by factor.
|
||||
class ResizeBilinearSubgraphDown : public TFSubgraph
|
||||
{
|
||||
public:
|
||||
ResizeBilinearSubgraphDown()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shapeSource = addNodeToMatch("");
|
||||
|
||||
int shape = addNodeToMatch("Shape", shapeSource);
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
int factorY = addNodeToMatch("Const");
|
||||
int div = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorY);
|
||||
int cast = addNodeToMatch("Cast", div);
|
||||
|
||||
shape = addNodeToMatch("Shape", shapeSource);
|
||||
stack = addNodeToMatch("Const");
|
||||
stack_1 = addNodeToMatch("Const");
|
||||
stack_2 = addNodeToMatch("Const");
|
||||
strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
int factorX = addNodeToMatch("Const");
|
||||
int div_1 = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorX);
|
||||
int cast_1 = addNodeToMatch("Cast", div_1);
|
||||
|
||||
int pack = addNodeToMatch("Pack", cast, cast_1);
|
||||
|
||||
addNodeToMatch("ResizeBilinear", input, pack);
|
||||
setFusedNode("ResizeBilinear", input, factorY, factorX);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
|
||||
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||
{
|
||||
|
||||
for (int i = 1; i < 3; ++i)
|
||||
{
|
||||
tensorflow::TensorProto* factor = inputNodes[i]->mutable_attr()->at("value").mutable_tensor();
|
||||
factor->set_double_val(0, 1.0 / factor->double_val(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// In case of resizing by factor.
|
||||
class UpsamplingKerasSubgraph : public TFSubgraph
|
||||
{
|
||||
@ -702,6 +748,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraphDown()));
|
||||
|
||||
for (int i = 0; i < net.node_size(); ++i)
|
||||
{
|
||||
|
@ -1932,10 +1932,10 @@ void TFImporter::populateNet(Net dstNet)
|
||||
{
|
||||
Mat factorHeight = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
Mat factorWidth = getTensorContent(getConstBlob(layer, value_id, 2));
|
||||
CV_CheckTypeEQ(factorHeight.type(), CV_32SC1, ""); CV_CheckEQ(factorHeight.total(), (size_t)1, "");
|
||||
CV_CheckTypeEQ(factorWidth.type(), CV_32SC1, ""); CV_CheckEQ(factorWidth.total(), (size_t)1, "");
|
||||
layerParams.set("zoom_factor_x", factorWidth.at<int>(0));
|
||||
layerParams.set("zoom_factor_y", factorHeight.at<int>(0));
|
||||
factorHeight.convertTo(factorHeight, CV_32F);
|
||||
factorWidth.convertTo(factorWidth, CV_32F);
|
||||
layerParams.set("zoom_factor_x", factorWidth.at<float>(0));
|
||||
layerParams.set("zoom_factor_y", factorHeight.at<float>(0));
|
||||
}
|
||||
else
|
||||
CV_Assert(layer.input_size() == 2 || layer.input_size() == 3);
|
||||
|
@ -969,6 +969,7 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
|
||||
{
|
||||
runTensorFlowNet("resize_bilinear");
|
||||
runTensorFlowNet("resize_bilinear_factor");
|
||||
runTensorFlowNet("resize_bilinear_down");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, tf2_dense)
|
||||
|
Loading…
Reference in New Issue
Block a user