Merge pull request #11840 from dkurt:dnn_tf_nchw

This commit is contained in:
Vadim Pisarevsky 2018-06-26 19:27:25 +00:00
commit ba1a6ad4cc
2 changed files with 42 additions and 21 deletions

View File

@ -246,16 +246,41 @@ const tensorflow::AttrValue& getLayerAttr(const tensorflow::NodeDef &layer, cons
return layer.attr().at(name);
}
static int getDataLayout(const tensorflow::NodeDef& layer)
{
if (hasLayerAttr(layer, "data_format"))
{
std::string format = getLayerAttr(layer, "data_format").s();
if (format == "NHWC" || format == "channels_last")
return DATA_LAYOUT_NHWC;
else if (format == "NCHW" || format == "channels_first")
return DATA_LAYOUT_NCHW;
else
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
}
return DATA_LAYOUT_UNKNOWN;
}
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
{
if (hasLayerAttr(layer, "strides"))
{
const tensorflow::AttrValue& val = getLayerAttr(layer, "strides");
int dimX, dimY, dimC;
int layout = getDataLayout(layer);
if (layout == DATA_LAYOUT_NCHW)
{
dimC = 1; dimY = 2; dimX = 3;
}
else
{
dimY = 1; dimX = 2; dimC = 3;
}
if (val.list().i_size() != 4 ||
val.list().i(0) != 1 || val.list().i(3) != 1)
val.list().i(0) != 1 || val.list().i(dimC) != 1)
CV_Error(Error::StsError, "Unsupported strides");
layerParams.set("stride_h", static_cast<int>(val.list().i(1)));
layerParams.set("stride_w", static_cast<int>(val.list().i(2)));
layerParams.set("stride_h", static_cast<int>(val.list().i(dimY)));
layerParams.set("stride_w", static_cast<int>(val.list().i(dimX)));
}
}
@ -278,11 +303,21 @@ void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
if (hasLayerAttr(layer, "ksize"))
{
const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize");
int dimX, dimY, dimC;
int layout = getDataLayout(layer);
if (layout == DATA_LAYOUT_NCHW)
{
dimC = 1; dimY = 2; dimX = 3;
}
else
{
dimY = 1; dimX = 2; dimC = 3;
}
if (val.list().i_size() != 4 ||
val.list().i(0) != 1 || val.list().i(3) != 1)
val.list().i(0) != 1 || val.list().i(dimC) != 1)
CV_Error(Error::StsError, "Unsupported ksize");
layerParams.set("kernel_h", static_cast<int>(val.list().i(1)));
layerParams.set("kernel_w", static_cast<int>(val.list().i(2)));
layerParams.set("kernel_h", static_cast<int>(val.list().i(dimY)));
layerParams.set("kernel_w", static_cast<int>(val.list().i(dimX)));
}
else
{
@ -568,21 +603,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
}
}
static int getDataLayout(const tensorflow::NodeDef& layer)
{
if (hasLayerAttr(layer, "data_format"))
{
std::string format = getLayerAttr(layer, "data_format").s();
if (format == "NHWC" || format == "channels_last")
return DATA_LAYOUT_NHWC;
else if (format == "NCHW" || format == "channels_first")
return DATA_LAYOUT_NCHW;
else
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
}
return DATA_LAYOUT_UNKNOWN;
}
static inline std::string getNodeName(const std::string& tensorName)
{
return tensorName.substr(0, tensorName.rfind(':'));

View File

@ -127,6 +127,7 @@ TEST_P(Test_TensorFlow_layers, conv)
runTensorFlowNet("atrous_conv2d_same", targetId);
runTensorFlowNet("depthwise_conv2d", targetId);
runTensorFlowNet("keras_atrous_conv2d_same", targetId);
runTensorFlowNet("conv_pool_nchw", targetId);
}
TEST_P(Test_TensorFlow_layers, padding)