mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
Merge pull request #11840 from dkurt:dnn_tf_nchw
This commit is contained in:
commit
ba1a6ad4cc
@ -246,16 +246,41 @@ const tensorflow::AttrValue& getLayerAttr(const tensorflow::NodeDef &layer, cons
|
|||||||
return layer.attr().at(name);
|
return layer.attr().at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int getDataLayout(const tensorflow::NodeDef& layer)
|
||||||
|
{
|
||||||
|
if (hasLayerAttr(layer, "data_format"))
|
||||||
|
{
|
||||||
|
std::string format = getLayerAttr(layer, "data_format").s();
|
||||||
|
if (format == "NHWC" || format == "channels_last")
|
||||||
|
return DATA_LAYOUT_NHWC;
|
||||||
|
else if (format == "NCHW" || format == "channels_first")
|
||||||
|
return DATA_LAYOUT_NCHW;
|
||||||
|
else
|
||||||
|
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
|
||||||
|
}
|
||||||
|
return DATA_LAYOUT_UNKNOWN;
|
||||||
|
}
|
||||||
|
|
||||||
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
||||||
{
|
{
|
||||||
if (hasLayerAttr(layer, "strides"))
|
if (hasLayerAttr(layer, "strides"))
|
||||||
{
|
{
|
||||||
const tensorflow::AttrValue& val = getLayerAttr(layer, "strides");
|
const tensorflow::AttrValue& val = getLayerAttr(layer, "strides");
|
||||||
|
int dimX, dimY, dimC;
|
||||||
|
int layout = getDataLayout(layer);
|
||||||
|
if (layout == DATA_LAYOUT_NCHW)
|
||||||
|
{
|
||||||
|
dimC = 1; dimY = 2; dimX = 3;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dimY = 1; dimX = 2; dimC = 3;
|
||||||
|
}
|
||||||
if (val.list().i_size() != 4 ||
|
if (val.list().i_size() != 4 ||
|
||||||
val.list().i(0) != 1 || val.list().i(3) != 1)
|
val.list().i(0) != 1 || val.list().i(dimC) != 1)
|
||||||
CV_Error(Error::StsError, "Unsupported strides");
|
CV_Error(Error::StsError, "Unsupported strides");
|
||||||
layerParams.set("stride_h", static_cast<int>(val.list().i(1)));
|
layerParams.set("stride_h", static_cast<int>(val.list().i(dimY)));
|
||||||
layerParams.set("stride_w", static_cast<int>(val.list().i(2)));
|
layerParams.set("stride_w", static_cast<int>(val.list().i(dimX)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,11 +303,21 @@ void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
|||||||
if (hasLayerAttr(layer, "ksize"))
|
if (hasLayerAttr(layer, "ksize"))
|
||||||
{
|
{
|
||||||
const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize");
|
const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize");
|
||||||
|
int dimX, dimY, dimC;
|
||||||
|
int layout = getDataLayout(layer);
|
||||||
|
if (layout == DATA_LAYOUT_NCHW)
|
||||||
|
{
|
||||||
|
dimC = 1; dimY = 2; dimX = 3;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dimY = 1; dimX = 2; dimC = 3;
|
||||||
|
}
|
||||||
if (val.list().i_size() != 4 ||
|
if (val.list().i_size() != 4 ||
|
||||||
val.list().i(0) != 1 || val.list().i(3) != 1)
|
val.list().i(0) != 1 || val.list().i(dimC) != 1)
|
||||||
CV_Error(Error::StsError, "Unsupported ksize");
|
CV_Error(Error::StsError, "Unsupported ksize");
|
||||||
layerParams.set("kernel_h", static_cast<int>(val.list().i(1)));
|
layerParams.set("kernel_h", static_cast<int>(val.list().i(dimY)));
|
||||||
layerParams.set("kernel_w", static_cast<int>(val.list().i(2)));
|
layerParams.set("kernel_w", static_cast<int>(val.list().i(dimX)));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@ -568,21 +603,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int getDataLayout(const tensorflow::NodeDef& layer)
|
|
||||||
{
|
|
||||||
if (hasLayerAttr(layer, "data_format"))
|
|
||||||
{
|
|
||||||
std::string format = getLayerAttr(layer, "data_format").s();
|
|
||||||
if (format == "NHWC" || format == "channels_last")
|
|
||||||
return DATA_LAYOUT_NHWC;
|
|
||||||
else if (format == "NCHW" || format == "channels_first")
|
|
||||||
return DATA_LAYOUT_NCHW;
|
|
||||||
else
|
|
||||||
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
|
|
||||||
}
|
|
||||||
return DATA_LAYOUT_UNKNOWN;
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline std::string getNodeName(const std::string& tensorName)
|
static inline std::string getNodeName(const std::string& tensorName)
|
||||||
{
|
{
|
||||||
return tensorName.substr(0, tensorName.rfind(':'));
|
return tensorName.substr(0, tensorName.rfind(':'));
|
||||||
|
@ -127,6 +127,7 @@ TEST_P(Test_TensorFlow_layers, conv)
|
|||||||
runTensorFlowNet("atrous_conv2d_same", targetId);
|
runTensorFlowNet("atrous_conv2d_same", targetId);
|
||||||
runTensorFlowNet("depthwise_conv2d", targetId);
|
runTensorFlowNet("depthwise_conv2d", targetId);
|
||||||
runTensorFlowNet("keras_atrous_conv2d_same", targetId);
|
runTensorFlowNet("keras_atrous_conv2d_same", targetId);
|
||||||
|
runTensorFlowNet("conv_pool_nchw", targetId);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, padding)
|
TEST_P(Test_TensorFlow_layers, padding)
|
||||||
|
Loading…
Reference in New Issue
Block a user