mirror of
https://github.com/opencv/opencv.git
synced 2025-06-17 15:20:51 +08:00
Add Reshape layer tests
This commit is contained in:
parent
f73eff7517
commit
7ed5d85f25
@ -82,17 +82,26 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
|
|||||||
{
|
{
|
||||||
if (matched)
|
if (matched)
|
||||||
{
|
{
|
||||||
if (i == 0 || total(srcShape, i, srcRange.end) != maskTotal)
|
if (total(srcShape, i, srcRange.end) != maskTotal)
|
||||||
{
|
{
|
||||||
srcRange.start = i + 1;
|
srcRange.start = i + 1;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
else if (i == 0)
|
||||||
|
{
|
||||||
|
srcRange.start = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
matched = total(srcShape, i, srcRange.end) == maskTotal;
|
matched = total(srcShape, i, srcRange.end) == maskTotal;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
while (total(srcShape, srcRange.start, srcRange.end) != maskTotal && srcRange.start > 0)
|
||||||
|
{
|
||||||
|
srcRange.start -= 1;
|
||||||
|
}
|
||||||
CV_Assert(total(srcShape, srcRange.start, srcRange.end) == maskTotal);
|
CV_Assert(total(srcShape, srcRange.start, srcRange.end) == maskTotal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,6 +262,18 @@ static int getDataLayout(const tensorflow::NodeDef& layer)
|
|||||||
return DATA_LAYOUT_UNKNOWN;
|
return DATA_LAYOUT_UNKNOWN;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline std::string getNodeName(const std::string& tensorName)
|
||||||
|
{
|
||||||
|
return tensorName.substr(0, tensorName.rfind(':'));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int getDataLayout(const std::string& layerName,
|
||||||
|
const std::map<String, int>& data_layouts)
|
||||||
|
{
|
||||||
|
std::map<String, int>::const_iterator it = data_layouts.find(getNodeName(layerName));
|
||||||
|
return it != data_layouts.end() ? it->second : DATA_LAYOUT_UNKNOWN;
|
||||||
|
}
|
||||||
|
|
||||||
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
||||||
{
|
{
|
||||||
if (hasLayerAttr(layer, "strides"))
|
if (hasLayerAttr(layer, "strides"))
|
||||||
@ -604,11 +616,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline std::string getNodeName(const std::string& tensorName)
|
|
||||||
{
|
|
||||||
return tensorName.substr(0, tensorName.rfind(':'));
|
|
||||||
}
|
|
||||||
|
|
||||||
// If all inputs of specific layer have the same data layout we can say that
|
// If all inputs of specific layer have the same data layout we can say that
|
||||||
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
|
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
|
||||||
static int predictOutputDataLayout(const tensorflow::GraphDef& net,
|
static int predictOutputDataLayout(const tensorflow::GraphDef& net,
|
||||||
@ -830,7 +837,8 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
// one input only
|
// one input only
|
||||||
connect(layer_id, dstNet, parsePin(input), id, 0);
|
connect(layer_id, dstNet, parsePin(input), id, 0);
|
||||||
|
|
||||||
if (data_layouts[name] == DATA_LAYOUT_UNKNOWN)
|
|
||||||
|
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN)
|
||||||
data_layouts[name] = DATA_LAYOUT_NHWC;
|
data_layouts[name] = DATA_LAYOUT_NHWC;
|
||||||
}
|
}
|
||||||
else if (type == "BiasAdd" || type == "Add")
|
else if (type == "BiasAdd" || type == "Add")
|
||||||
@ -956,7 +964,8 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
Pin inpId = parsePin(layer.input(0));
|
Pin inpId = parsePin(layer.input(0));
|
||||||
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
|
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||||
|
|
||||||
if (newShape.total() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
|
int inpLayout = getDataLayout(layer.input(0), data_layouts);
|
||||||
|
if (newShape.total() != 4 && inpLayout == DATA_LAYOUT_NHWC)
|
||||||
{
|
{
|
||||||
LayerParams permLP;
|
LayerParams permLP;
|
||||||
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
|
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
|
||||||
@ -969,7 +978,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
connect(layer_id, dstNet, inpId, permId, 0);
|
connect(layer_id, dstNet, inpId, permId, 0);
|
||||||
inpId = Pin(permName);
|
inpId = Pin(permName);
|
||||||
}
|
}
|
||||||
else if (newShape.total() == 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
|
else if (newShape.total() == 4 && inpLayout == DATA_LAYOUT_NHWC)
|
||||||
{
|
{
|
||||||
// NHWC->NCHW
|
// NHWC->NCHW
|
||||||
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
|
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
|
||||||
@ -987,7 +996,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
else if (type == "Flatten" || type == "Squeeze")
|
else if (type == "Flatten" || type == "Squeeze")
|
||||||
{
|
{
|
||||||
Pin inpId = parsePin(layer.input(0));
|
Pin inpId = parsePin(layer.input(0));
|
||||||
int inpLayout = data_layouts[layer.input(0)];
|
int inpLayout = getDataLayout(layer.input(0), data_layouts);
|
||||||
if (type == "Squeeze")
|
if (type == "Squeeze")
|
||||||
{
|
{
|
||||||
CV_Assert(hasLayerAttr(layer, "squeeze_dims"));
|
CV_Assert(hasLayerAttr(layer, "squeeze_dims"));
|
||||||
@ -1032,7 +1041,8 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
{
|
{
|
||||||
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
|
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
|
||||||
// keep NCHW layout this way.
|
// keep NCHW layout this way.
|
||||||
if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
|
int inpLayout = getDataLayout(layer.input(0), data_layouts);
|
||||||
|
if (inpLayout == DATA_LAYOUT_NHWC)
|
||||||
{
|
{
|
||||||
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
|
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
|
||||||
{
|
{
|
||||||
@ -1049,7 +1059,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
else
|
else
|
||||||
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
|
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
|
||||||
}
|
}
|
||||||
else if (data_layouts[layer.input(0)] == DATA_LAYOUT_NCHW)
|
else if (inpLayout == DATA_LAYOUT_NCHW)
|
||||||
{
|
{
|
||||||
if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
|
if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
|
||||||
{
|
{
|
||||||
@ -1112,7 +1122,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
|
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
|
||||||
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
|
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
|
||||||
|
|
||||||
if (data_layouts[name] == DATA_LAYOUT_NHWC)
|
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
|
||||||
axis = toNCHW(axis);
|
axis = toNCHW(axis);
|
||||||
layerParams.set("axis", axis);
|
layerParams.set("axis", axis);
|
||||||
|
|
||||||
@ -1197,7 +1207,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
|
CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
|
||||||
sizes.type() == CV_32SC1);
|
sizes.type() == CV_32SC1);
|
||||||
|
|
||||||
if (begins.total() == 4 && data_layouts[name] == DATA_LAYOUT_NHWC)
|
if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
|
||||||
{
|
{
|
||||||
// Swap NHWC parameters' order to NCHW.
|
// Swap NHWC parameters' order to NCHW.
|
||||||
std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
|
std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
|
||||||
@ -1597,7 +1607,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
CV_Assert(reductionIndices.type() == CV_32SC1);
|
CV_Assert(reductionIndices.type() == CV_32SC1);
|
||||||
|
|
||||||
const int numAxes = reductionIndices.total();
|
const int numAxes = reductionIndices.total();
|
||||||
if (data_layouts[name] == DATA_LAYOUT_NHWC)
|
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
|
||||||
for (int i = 0; i < numAxes; ++i)
|
for (int i = 0; i < numAxes; ++i)
|
||||||
reductionIndices.at<int>(i) = toNCHW(reductionIndices.at<int>(i));
|
reductionIndices.at<int>(i) = toNCHW(reductionIndices.at<int>(i));
|
||||||
|
|
||||||
|
@ -592,8 +592,8 @@ struct TorchImporter
|
|||||||
DictValue dimParam = scalarParams.get("size");
|
DictValue dimParam = scalarParams.get("size");
|
||||||
layerParams.set("dim", dimParam);
|
layerParams.set("dim", dimParam);
|
||||||
|
|
||||||
if (scalarParams.has("batchMode") && scalarParams.get<bool>("batchMode"))
|
int axis = (int)scalarParams.get<bool>("batchMode", true);
|
||||||
layerParams.set("axis", 1);
|
layerParams.set("axis", axis);
|
||||||
|
|
||||||
curModule->modules.push_back(newModule);
|
curModule->modules.push_back(newModule);
|
||||||
}
|
}
|
||||||
|
@ -201,6 +201,13 @@ TEST(Layer_Test_Reshape, Accuracy)
|
|||||||
testReshape(MatShape(inp, inp + 4), MatShape(out, out + 2), 0, -1,
|
testReshape(MatShape(inp, inp + 4), MatShape(out, out + 2), 0, -1,
|
||||||
MatShape(mask, mask + 2));
|
MatShape(mask, mask + 2));
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
int inp[] = {1, 2, 3};
|
||||||
|
int out[] = {3, 1, 2};
|
||||||
|
int mask[] = {3, 1, 2};
|
||||||
|
testReshape(MatShape(inp, inp + 3), MatShape(out, out + 3), 0, -1,
|
||||||
|
MatShape(mask, mask + 3));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Layer_Test_BatchNorm, Accuracy)
|
TEST(Layer_Test_BatchNorm, Accuracy)
|
||||||
|
@ -198,6 +198,7 @@ TEST_P(Test_TensorFlow_layers, reshape)
|
|||||||
{
|
{
|
||||||
int targetId = GetParam();
|
int targetId = GetParam();
|
||||||
runTensorFlowNet("shift_reshape_no_reorder", targetId);
|
runTensorFlowNet("shift_reshape_no_reorder", targetId);
|
||||||
|
runTensorFlowNet("reshape_no_reorder", targetId);
|
||||||
runTensorFlowNet("reshape_reduce", targetId);
|
runTensorFlowNet("reshape_reduce", targetId);
|
||||||
runTensorFlowNet("flatten", targetId, true);
|
runTensorFlowNet("flatten", targetId, true);
|
||||||
runTensorFlowNet("unfused_flatten", targetId);
|
runTensorFlowNet("unfused_flatten", targetId);
|
||||||
|
Loading…
Reference in New Issue
Block a user