mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 21:20:18 +08:00
Merge pull request #11840 from dkurt:dnn_tf_nchw
This commit is contained in:
commit
ba1a6ad4cc
@ -246,16 +246,41 @@ const tensorflow::AttrValue& getLayerAttr(const tensorflow::NodeDef &layer, cons
|
||||
return layer.attr().at(name);
|
||||
}
|
||||
|
||||
static int getDataLayout(const tensorflow::NodeDef& layer)
|
||||
{
|
||||
if (hasLayerAttr(layer, "data_format"))
|
||||
{
|
||||
std::string format = getLayerAttr(layer, "data_format").s();
|
||||
if (format == "NHWC" || format == "channels_last")
|
||||
return DATA_LAYOUT_NHWC;
|
||||
else if (format == "NCHW" || format == "channels_first")
|
||||
return DATA_LAYOUT_NCHW;
|
||||
else
|
||||
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
|
||||
}
|
||||
return DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
|
||||
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
||||
{
|
||||
if (hasLayerAttr(layer, "strides"))
|
||||
{
|
||||
const tensorflow::AttrValue& val = getLayerAttr(layer, "strides");
|
||||
int dimX, dimY, dimC;
|
||||
int layout = getDataLayout(layer);
|
||||
if (layout == DATA_LAYOUT_NCHW)
|
||||
{
|
||||
dimC = 1; dimY = 2; dimX = 3;
|
||||
}
|
||||
else
|
||||
{
|
||||
dimY = 1; dimX = 2; dimC = 3;
|
||||
}
|
||||
if (val.list().i_size() != 4 ||
|
||||
val.list().i(0) != 1 || val.list().i(3) != 1)
|
||||
val.list().i(0) != 1 || val.list().i(dimC) != 1)
|
||||
CV_Error(Error::StsError, "Unsupported strides");
|
||||
layerParams.set("stride_h", static_cast<int>(val.list().i(1)));
|
||||
layerParams.set("stride_w", static_cast<int>(val.list().i(2)));
|
||||
layerParams.set("stride_h", static_cast<int>(val.list().i(dimY)));
|
||||
layerParams.set("stride_w", static_cast<int>(val.list().i(dimX)));
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,11 +303,21 @@ void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
||||
if (hasLayerAttr(layer, "ksize"))
|
||||
{
|
||||
const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize");
|
||||
int dimX, dimY, dimC;
|
||||
int layout = getDataLayout(layer);
|
||||
if (layout == DATA_LAYOUT_NCHW)
|
||||
{
|
||||
dimC = 1; dimY = 2; dimX = 3;
|
||||
}
|
||||
else
|
||||
{
|
||||
dimY = 1; dimX = 2; dimC = 3;
|
||||
}
|
||||
if (val.list().i_size() != 4 ||
|
||||
val.list().i(0) != 1 || val.list().i(3) != 1)
|
||||
val.list().i(0) != 1 || val.list().i(dimC) != 1)
|
||||
CV_Error(Error::StsError, "Unsupported ksize");
|
||||
layerParams.set("kernel_h", static_cast<int>(val.list().i(1)));
|
||||
layerParams.set("kernel_w", static_cast<int>(val.list().i(2)));
|
||||
layerParams.set("kernel_h", static_cast<int>(val.list().i(dimY)));
|
||||
layerParams.set("kernel_w", static_cast<int>(val.list().i(dimX)));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -568,21 +603,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
|
||||
}
|
||||
}
|
||||
|
||||
static int getDataLayout(const tensorflow::NodeDef& layer)
|
||||
{
|
||||
if (hasLayerAttr(layer, "data_format"))
|
||||
{
|
||||
std::string format = getLayerAttr(layer, "data_format").s();
|
||||
if (format == "NHWC" || format == "channels_last")
|
||||
return DATA_LAYOUT_NHWC;
|
||||
else if (format == "NCHW" || format == "channels_first")
|
||||
return DATA_LAYOUT_NCHW;
|
||||
else
|
||||
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
|
||||
}
|
||||
return DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
|
||||
static inline std::string getNodeName(const std::string& tensorName)
|
||||
{
|
||||
return tensorName.substr(0, tensorName.rfind(':'));
|
||||
|
@ -127,6 +127,7 @@ TEST_P(Test_TensorFlow_layers, conv)
|
||||
runTensorFlowNet("atrous_conv2d_same", targetId);
|
||||
runTensorFlowNet("depthwise_conv2d", targetId);
|
||||
runTensorFlowNet("keras_atrous_conv2d_same", targetId);
|
||||
runTensorFlowNet("conv_pool_nchw", targetId);
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, padding)
|
||||
|
Loading…
Reference in New Issue
Block a user