mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 19:50:38 +08:00
Merge pull request #16983 from dkurt:dnn_tf_prelu
This commit is contained in:
commit
5da4bb7e88
@ -223,6 +223,26 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class FlattenProdSubgraph : public Subgraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FlattenProdSubgraph()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int shape = addNodeToMatch("Shape", input);
|
||||||
|
int stack = addNodeToMatch("Const");
|
||||||
|
int stack_1 = addNodeToMatch("Const");
|
||||||
|
int stack_2 = addNodeToMatch("Const");
|
||||||
|
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||||
|
int prod = addNodeToMatch("Prod", strided_slice, addNodeToMatch("Const"));
|
||||||
|
int shape_pack = addNodeToMatch("Const");
|
||||||
|
int pack = addNodeToMatch("Pack", shape_pack, prod);
|
||||||
|
addNodeToMatch("Reshape", input, pack);
|
||||||
|
|
||||||
|
setFusedNode("Flatten", input);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// K.layers.Softmax
|
// K.layers.Softmax
|
||||||
class SoftMaxKerasSubgraph : public Subgraph
|
class SoftMaxKerasSubgraph : public Subgraph
|
||||||
{
|
{
|
||||||
@ -629,6 +649,36 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class PReLUSubgraph : public TFSubgraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
PReLUSubgraph(bool negativeScales_) : negativeScales(negativeScales_)
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int scales = addNodeToMatch("Const");
|
||||||
|
int neg = addNodeToMatch("Neg", input);
|
||||||
|
int relu_neg = addNodeToMatch("Relu", neg);
|
||||||
|
int finalScales = negativeScales ? addNodeToMatch("Neg", scales) : scales;
|
||||||
|
int mul = addNodeToMatch("Mul", finalScales, relu_neg);
|
||||||
|
int relu_pos = addNodeToMatch("Relu", input);
|
||||||
|
addNodeToMatch("Add", relu_pos, mul);
|
||||||
|
setFusedNode("PReLU", input, scales);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
|
||||||
|
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
if (!negativeScales)
|
||||||
|
{
|
||||||
|
Mat scales = getTensorContent(inputNodes[1]->attr().at("value").tensor(), /*copy*/false);
|
||||||
|
scales *= -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool negativeScales;
|
||||||
|
};
|
||||||
|
|
||||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||||
{
|
{
|
||||||
std::vector<Ptr<Subgraph> > subgraphs;
|
std::vector<Ptr<Subgraph> > subgraphs;
|
||||||
@ -649,6 +699,16 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
|||||||
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
|
||||||
subgraphs.push_back(Ptr<Subgraph>(new KerasMVNSubgraph()));
|
subgraphs.push_back(Ptr<Subgraph>(new KerasMVNSubgraph()));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
|
||||||
|
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
|
||||||
|
|
||||||
|
for (int i = 0; i < net.node_size(); ++i)
|
||||||
|
{
|
||||||
|
tensorflow::NodeDef* layer = net.mutable_node(i);
|
||||||
|
if (layer->op() == "AddV2")
|
||||||
|
layer->set_op("Add");
|
||||||
|
}
|
||||||
|
|
||||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new TFGraphWrapper(net)), subgraphs);
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new TFGraphWrapper(net)), subgraphs);
|
||||||
}
|
}
|
||||||
|
@ -1231,6 +1231,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
|
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
|
||||||
// keep NCHW layout this way.
|
// keep NCHW layout this way.
|
||||||
int inpLayout = getDataLayout(layer.input(0), data_layouts);
|
int inpLayout = getDataLayout(layer.input(0), data_layouts);
|
||||||
|
std::string type = "Identity";
|
||||||
if (inpLayout == DATA_LAYOUT_NHWC)
|
if (inpLayout == DATA_LAYOUT_NHWC)
|
||||||
{
|
{
|
||||||
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
|
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
|
||||||
@ -1245,6 +1246,15 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
// in OpenCV: NCHW->NCHW
|
// in OpenCV: NCHW->NCHW
|
||||||
data_layouts[name] = DATA_LAYOUT_NHWC;
|
data_layouts[name] = DATA_LAYOUT_NHWC;
|
||||||
}
|
}
|
||||||
|
else if (permData[0] == 0 && permData[1] == 3 && permData[2] == 2 && permData[3] == 1)
|
||||||
|
{
|
||||||
|
// in TensorFlow: NHWC->NCWH
|
||||||
|
// in OpenCV: NCHW->NCWH
|
||||||
|
int permData[] = {0, 1, 3, 2};
|
||||||
|
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
|
||||||
|
data_layouts[name] = DATA_LAYOUT_NCHW; // we keep track NCHW because channels position only matters
|
||||||
|
type = "Permute";
|
||||||
|
}
|
||||||
else
|
else
|
||||||
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
|
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
|
||||||
}
|
}
|
||||||
@ -1265,7 +1275,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
else
|
else
|
||||||
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
|
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
|
||||||
}
|
}
|
||||||
int id = dstNet.addLayer(name, "Identity", layerParams);
|
int id = dstNet.addLayer(name, type, layerParams);
|
||||||
layer_id[name] = id;
|
layer_id[name] = id;
|
||||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||||
}
|
}
|
||||||
|
@ -956,11 +956,25 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
|
|||||||
runTensorFlowNet("resize_bilinear_factor");
|
runTensorFlowNet("resize_bilinear_factor");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, tf2_keras)
|
TEST_P(Test_TensorFlow_layers, tf2_dense)
|
||||||
{
|
{
|
||||||
runTensorFlowNet("tf2_dense");
|
runTensorFlowNet("tf2_dense");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_TensorFlow_layers, tf2_prelu)
|
||||||
|
{
|
||||||
|
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||||
|
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||||
|
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
|
||||||
|
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||||
|
runTensorFlowNet("tf2_prelu");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_TensorFlow_layers, tf2_permute_nhwc_ncwh)
|
||||||
|
{
|
||||||
|
runTensorFlowNet("tf2_permute_nhwc_ncwh");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, squeeze)
|
TEST_P(Test_TensorFlow_layers, squeeze)
|
||||||
{
|
{
|
||||||
#if defined(INF_ENGINE_RELEASE)
|
#if defined(INF_ENGINE_RELEASE)
|
||||||
|
Loading…
Reference in New Issue
Block a user