mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Nearest neighbor resize layer
This commit is contained in:
parent
5f6ce6f4b0
commit
b9f94c9315
@ -539,6 +539,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
|||||||
static Ptr<NormalizeBBoxLayer> create(const LayerParams& params);
|
static Ptr<NormalizeBBoxLayer> create(const LayerParams& params);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Resize input 4-dimensional blob by nearest neghbor strategy.
|
||||||
|
*
|
||||||
|
* Layer is used to support TensorFlow's resize_nearest_neighbor op.
|
||||||
|
*/
|
||||||
|
class CV_EXPORTS ResizeNearestNeighborLayer : public Layer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static Ptr<ResizeNearestNeighborLayer> create(const LayerParams& params);
|
||||||
|
};
|
||||||
|
|
||||||
//! @}
|
//! @}
|
||||||
//! @}
|
//! @}
|
||||||
CV__DNN_EXPERIMENTAL_NS_END
|
CV__DNN_EXPERIMENTAL_NS_END
|
||||||
|
@ -83,6 +83,7 @@ void initializeLayerFactory()
|
|||||||
CV_DNN_REGISTER_LAYER_CLASS(Concat, ConcatLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Concat, ConcatLayer);
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Reshape, ReshapeLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Reshape, ReshapeLayer);
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Flatten, FlattenLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Flatten, FlattenLayer);
|
||||||
|
CV_DNN_REGISTER_LAYER_CLASS(ResizeNearestNeighbor, ResizeNearestNeighborLayer);
|
||||||
|
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Convolution, ConvolutionLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Convolution, ConvolutionLayer);
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Deconvolution, DeconvolutionLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Deconvolution, DeconvolutionLayer);
|
||||||
|
71
modules/dnn/src/layers/resize_nearest_neighbor_layer.cpp
Normal file
71
modules/dnn/src/layers/resize_nearest_neighbor_layer.cpp
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||||
|
// of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
// Copyright (C) 2017, Intel Corporation, all rights reserved.
|
||||||
|
// Third party copyrights are property of their respective owners.
|
||||||
|
#include "../precomp.hpp"
|
||||||
|
#include "layers_common.hpp"
|
||||||
|
#include <opencv2/imgproc.hpp>
|
||||||
|
|
||||||
|
namespace cv { namespace dnn {
|
||||||
|
|
||||||
|
class ResizeNearestNeighborLayerImpl : public ResizeNearestNeighborLayer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ResizeNearestNeighborLayerImpl(const LayerParams& params)
|
||||||
|
{
|
||||||
|
setParamsFrom(params);
|
||||||
|
CV_Assert(params.has("width"), params.has("height"));
|
||||||
|
outWidth = params.get<float>("width");
|
||||||
|
outHeight = params.get<float>("height");
|
||||||
|
alignCorners = params.get<bool>("align_corners", false);
|
||||||
|
if (alignCorners)
|
||||||
|
CV_Error(Error::StsNotImplemented, "Nearest neighborhood resize with align_corners=true is not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||||
|
const int requiredOutputs,
|
||||||
|
std::vector<MatShape> &outputs,
|
||||||
|
std::vector<MatShape> &internals) const
|
||||||
|
{
|
||||||
|
CV_Assert(inputs.size() == 1, inputs[0].size() == 4);
|
||||||
|
outputs.resize(1, inputs[0]);
|
||||||
|
outputs[0][2] = outHeight;
|
||||||
|
outputs[0][3] = outWidth;
|
||||||
|
// We can work in-place (do nothing) if input shape == output shape.
|
||||||
|
return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
|
||||||
|
{
|
||||||
|
CV_TRACE_FUNCTION();
|
||||||
|
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||||
|
|
||||||
|
if (outHeight == inputs[0]->size[2] && outWidth == inputs[0]->size[3])
|
||||||
|
return;
|
||||||
|
|
||||||
|
Mat& inp = *inputs[0];
|
||||||
|
Mat& out = outputs[0];
|
||||||
|
for (size_t n = 0; n < inputs[0]->size[0]; ++n)
|
||||||
|
{
|
||||||
|
for (size_t ch = 0; ch < inputs[0]->size[1]; ++ch)
|
||||||
|
{
|
||||||
|
resize(getPlane(inp, n, ch), getPlane(out, n, ch),
|
||||||
|
Size(outWidth, outHeight), 0, 0, INTER_NEAREST);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
int outWidth, outHeight;
|
||||||
|
bool alignCorners;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
Ptr<ResizeNearestNeighborLayer> ResizeNearestNeighborLayer::create(const LayerParams& params)
|
||||||
|
{
|
||||||
|
return Ptr<ResizeNearestNeighborLayer>(new ResizeNearestNeighborLayerImpl(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dnn
|
||||||
|
} // namespace cv
|
@ -1132,6 +1132,22 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
// one input only
|
// one input only
|
||||||
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
|
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
|
||||||
}
|
}
|
||||||
|
else if (type == "ResizeNearestNeighbor")
|
||||||
|
{
|
||||||
|
Mat outSize = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||||
|
CV_Assert(outSize.type() == CV_32SC1, outSize.total() == 2);
|
||||||
|
|
||||||
|
layerParams.set("height", outSize.at<int>(0, 0));
|
||||||
|
layerParams.set("width", outSize.at<int>(0, 1));
|
||||||
|
|
||||||
|
if (hasLayerAttr(layer, "align_corners"))
|
||||||
|
layerParams.set("align_corners", getLayerAttr(layer, "align_corners").b());
|
||||||
|
|
||||||
|
int id = dstNet.addLayer(name, "ResizeNearestNeighbor", layerParams);
|
||||||
|
layer_id[name] = id;
|
||||||
|
|
||||||
|
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||||
|
}
|
||||||
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
|
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
|
||||||
type == "Relu" || type == "Elu" || type == "Softmax" ||
|
type == "Relu" || type == "Elu" || type == "Softmax" ||
|
||||||
type == "Identity" || type == "Relu6")
|
type == "Identity" || type == "Relu6")
|
||||||
|
@ -175,4 +175,9 @@ TEST(Test_TensorFlow, split)
|
|||||||
runTensorFlowNet("split_equals");
|
runTensorFlowNet("split_equals");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Test_TensorFlow, resize_nearest_neighbor)
|
||||||
|
{
|
||||||
|
runTensorFlowNet("resize_nearest_neighbor");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user