mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 05:06:29 +08:00
Fuse deconvolution layer subgraphs from Keras
This commit is contained in:
parent
7ea5029ae5
commit
d959d7b9f0
@ -419,6 +419,125 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class DeconvolutionValidKerasSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
DeconvolutionValidKerasSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int kernel = addNodeToMatch("Const");
|
||||
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
|
||||
stack = addNodeToMatch("Const");
|
||||
stack_1 = addNodeToMatch("Const");
|
||||
stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
|
||||
stack = addNodeToMatch("Const");
|
||||
stack_1 = addNodeToMatch("Const");
|
||||
stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
|
||||
int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
|
||||
int add = addNodeToMatch("Add", mul, addNodeToMatch("Const"));
|
||||
|
||||
int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
|
||||
int add_1 = addNodeToMatch("Add", mul_1, addNodeToMatch("Const"));
|
||||
int pack = addNodeToMatch("Pack", strided_slice, add, add_1, addNodeToMatch("Const"));
|
||||
addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
|
||||
// Put any unused Const op to the first input.
|
||||
setFusedNode("Conv2DBackpropInput", stack, kernel, input);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
|
||||
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||
{
|
||||
// Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
|
||||
// adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
|
||||
// adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
|
||||
// Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
|
||||
std::string padMode = fusedNode->attr().at("padding").s();
|
||||
CV_Assert(padMode == "VALID");
|
||||
|
||||
const tensorflow::TensorShapeProto& kernelShape =
|
||||
inputNodes[1]->mutable_attr()->at("value").tensor().tensor_shape();
|
||||
|
||||
CV_Assert(kernelShape.dim_size() == 4);
|
||||
const int kernelHeight = kernelShape.dim(0).size();
|
||||
const int kernelWidth = kernelShape.dim(1).size();
|
||||
|
||||
tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
|
||||
outShape->clear_int_val();
|
||||
outShape->add_int_val(-1);
|
||||
outShape->add_int_val(kernelHeight);
|
||||
outShape->add_int_val(kernelWidth);
|
||||
outShape->add_int_val(-1);
|
||||
}
|
||||
};
|
||||
|
||||
class DeconvolutionSameKerasSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
DeconvolutionSameKerasSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int kernel = addNodeToMatch("Const");
|
||||
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
|
||||
stack = addNodeToMatch("Const");
|
||||
stack_1 = addNodeToMatch("Const");
|
||||
stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
|
||||
stack = addNodeToMatch("Const");
|
||||
stack_1 = addNodeToMatch("Const");
|
||||
stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
|
||||
int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
|
||||
|
||||
int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
|
||||
int pack = addNodeToMatch("Pack", strided_slice, mul, mul_1, addNodeToMatch("Const"));
|
||||
addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
|
||||
// Put any unused Const op to the first input.
|
||||
setFusedNode("Conv2DBackpropInput", stack, kernel, input);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
|
||||
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
|
||||
{
|
||||
// Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
|
||||
// adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
|
||||
// adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
|
||||
// Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
|
||||
std::string padMode = fusedNode->attr().at("padding").s();
|
||||
CV_Assert(padMode == "SAME");
|
||||
|
||||
const tensorflow::AttrValue_ListValue& strides = fusedNode->attr().at("strides").list();
|
||||
CV_Assert(strides.i_size() == 4);
|
||||
|
||||
const int strideY = strides.i(1);
|
||||
const int strideX = strides.i(2);
|
||||
|
||||
tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
|
||||
outShape->clear_int_val();
|
||||
outShape->add_int_val(-1);
|
||||
outShape->add_int_val(strideY);
|
||||
outShape->add_int_val(strideX);
|
||||
outShape->add_int_val(-1);
|
||||
}
|
||||
};
|
||||
|
||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
@ -430,6 +549,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ReLU6KerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new ReshapeKerasSubgraph(3)));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new L2NormalizeSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
|
||||
|
||||
int numNodes = net.node_size();
|
||||
std::vector<int> matchedNodesIds;
|
||||
|
@ -1303,8 +1303,8 @@ void TFImporter::populateNet(Net dstNet)
|
||||
const int strideY = layerParams.get<int>("stride_h");
|
||||
const int strideX = layerParams.get<int>("stride_w");
|
||||
Mat outShape = getTensorContent(getConstBlob(layer, value_id, 0));
|
||||
const int outH = outShape.at<int>(2);
|
||||
const int outW = outShape.at<int>(1);
|
||||
const int outH = outShape.at<int>(1);
|
||||
const int outW = outShape.at<int>(2);
|
||||
if (layerParams.get<String>("pad_mode") == "SAME")
|
||||
{
|
||||
layerParams.set("adj_w", (outW - 1) % strideX);
|
||||
|
@ -173,6 +173,8 @@ TEST_P(Test_TensorFlow_layers, deconvolution)
|
||||
runTensorFlowNet("deconvolution_stride_2_same", targetId);
|
||||
runTensorFlowNet("deconvolution_adj_pad_valid", targetId);
|
||||
runTensorFlowNet("deconvolution_adj_pad_same", targetId);
|
||||
runTensorFlowNet("keras_deconv_valid", targetId);
|
||||
runTensorFlowNet("keras_deconv_same", targetId);
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, matmul)
|
||||
|
Loading…
Reference in New Issue
Block a user