mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Merge pull request #14860 from vonchenplus:ocv_maxpoolgrad
This commit is contained in:
commit
e00b0f6f47
@ -43,12 +43,18 @@ public:
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(inputs.size() == 2);
|
||||
CV_Assert(inputs.size() == 2 || inputs.size() == 3);
|
||||
CV_Assert(total(inputs[0]) == total(inputs[1]));
|
||||
|
||||
MatShape outShape = inputs[0];
|
||||
outShape[2] = (outShape[2] - 1) * poolStride.height + poolKernel.height - 2 * poolPad.height;
|
||||
outShape[3] = (outShape[3] - 1) * poolStride.width + poolKernel.width - 2 * poolPad.width;
|
||||
MatShape outShape;
|
||||
if (inputs.size() == 2)
|
||||
{
|
||||
outShape = inputs[0];
|
||||
outShape[2] = (outShape[2] - 1) * poolStride.height + poolKernel.height - 2 * poolPad.height;
|
||||
outShape[3] = (outShape[3] - 1) * poolStride.width + poolKernel.width - 2 * poolPad.width;
|
||||
}
|
||||
else
|
||||
outShape = inputs[2];
|
||||
|
||||
outputs.clear();
|
||||
outputs.push_back(outShape);
|
||||
@ -71,7 +77,7 @@ public:
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
CV_Assert(inputs.size() == 2);
|
||||
CV_Assert(inputs.size() == 2 || inputs.size() == 3);
|
||||
Mat& input = inputs[0];
|
||||
Mat& indices = inputs[1];
|
||||
|
||||
|
@ -1370,6 +1370,24 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
|
||||
}
|
||||
else if (type == "MaxPoolGrad")
|
||||
{
|
||||
CV_Assert(layer.input_size() == 3);
|
||||
|
||||
layerParams.set("pool_k_h", 0);
|
||||
layerParams.set("pool_k_w", 0);
|
||||
layerParams.set("pool_stride_h", 0);
|
||||
layerParams.set("pool_stride_w", 0);
|
||||
layerParams.set("pool_pad_h", 0);
|
||||
layerParams.set("pool_pad_w", 0);
|
||||
|
||||
int id = dstNet.addLayer(name, "MaxUnpool", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0);
|
||||
connect(layer_id, dstNet, parsePin(layer.input(1) + ":1"), id, 1);
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 2);
|
||||
}
|
||||
else if (type == "Placeholder")
|
||||
{
|
||||
if (!hasLayerAttr(layer, "dtype") ||
|
||||
|
@ -218,6 +218,13 @@ TEST_P(Test_TensorFlow_layers, pooling)
|
||||
runTensorFlowNet("reduce_mean"); // an average pooling over all spatial dimensions.
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, max_pool_grad)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
|
||||
runTensorFlowNet("max_pool_grad");
|
||||
}
|
||||
|
||||
// TODO: fix tests and replace to pooling
|
||||
TEST_P(Test_TensorFlow_layers, ave_pool_same)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user