Merge pull request #9787 from dkurt:feature_dnn_resize_nearest_neighbor

This commit is contained in:
Vadim Pisarevsky 2017-10-06 13:46:50 +00:00
commit b969d86415
5 changed files with 104 additions and 0 deletions

View File

@ -539,6 +539,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
static Ptr<NormalizeBBoxLayer> create(const LayerParams& params);
};
/**
* @brief Resize input 4-dimensional blob by nearest neghbor strategy.
*
* Layer is used to support TensorFlow's resize_nearest_neighbor op.
*/
class CV_EXPORTS ResizeNearestNeighborLayer : public Layer
{
public:
static Ptr<ResizeNearestNeighborLayer> create(const LayerParams& params);
};
//! @}
//! @}
CV__DNN_EXPERIMENTAL_NS_END

View File

@ -83,6 +83,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Concat, ConcatLayer);
CV_DNN_REGISTER_LAYER_CLASS(Reshape, ReshapeLayer);
CV_DNN_REGISTER_LAYER_CLASS(Flatten, FlattenLayer);
CV_DNN_REGISTER_LAYER_CLASS(ResizeNearestNeighbor, ResizeNearestNeighborLayer);
CV_DNN_REGISTER_LAYER_CLASS(Convolution, ConvolutionLayer);
CV_DNN_REGISTER_LAYER_CLASS(Deconvolution, DeconvolutionLayer);

View File

@ -0,0 +1,71 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <opencv2/imgproc.hpp>
namespace cv { namespace dnn {
class ResizeNearestNeighborLayerImpl : public ResizeNearestNeighborLayer
{
public:
ResizeNearestNeighborLayerImpl(const LayerParams& params)
{
setParamsFrom(params);
CV_Assert(params.has("width"), params.has("height"));
outWidth = params.get<float>("width");
outHeight = params.get<float>("height");
alignCorners = params.get<bool>("align_corners", false);
if (alignCorners)
CV_Error(Error::StsNotImplemented, "Nearest neighborhood resize with align_corners=true is not implemented");
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const
{
CV_Assert(inputs.size() == 1, inputs[0].size() == 4);
outputs.resize(1, inputs[0]);
outputs[0][2] = outHeight;
outputs[0][3] = outWidth;
// We can work in-place (do nothing) if input shape == output shape.
return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
if (outHeight == inputs[0]->size[2] && outWidth == inputs[0]->size[3])
return;
Mat& inp = *inputs[0];
Mat& out = outputs[0];
for (size_t n = 0; n < inputs[0]->size[0]; ++n)
{
for (size_t ch = 0; ch < inputs[0]->size[1]; ++ch)
{
resize(getPlane(inp, n, ch), getPlane(out, n, ch),
Size(outWidth, outHeight), 0, 0, INTER_NEAREST);
}
}
}
private:
int outWidth, outHeight;
bool alignCorners;
};
Ptr<ResizeNearestNeighborLayer> ResizeNearestNeighborLayer::create(const LayerParams& params)
{
return Ptr<ResizeNearestNeighborLayer>(new ResizeNearestNeighborLayerImpl(params));
}
} // namespace dnn
} // namespace cv

View File

@ -1132,6 +1132,22 @@ void TFImporter::populateNet(Net dstNet)
// one input only
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
}
else if (type == "ResizeNearestNeighbor")
{
Mat outSize = getTensorContent(getConstBlob(layer, value_id, 1));
CV_Assert(outSize.type() == CV_32SC1, outSize.total() == 2);
layerParams.set("height", outSize.at<int>(0, 0));
layerParams.set("width", outSize.at<int>(0, 1));
if (hasLayerAttr(layer, "align_corners"))
layerParams.set("align_corners", getLayerAttr(layer, "align_corners").b());
int id = dstNet.addLayer(name, "ResizeNearestNeighbor", layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
type == "Relu" || type == "Elu" || type == "Softmax" ||
type == "Identity" || type == "Relu6")

View File

@ -175,4 +175,9 @@ TEST(Test_TensorFlow, split)
runTensorFlowNet("split_equals");
}
TEST(Test_TensorFlow, resize_nearest_neighbor)
{
runTensorFlowNet("resize_nearest_neighbor");
}
}