mirror of
https://github.com/opencv/opencv.git
synced 2025-06-27 23:11:57 +08:00
Support for MobileNetV3-SSD from TensorFlow
This commit is contained in:
parent
969cc3dd95
commit
b927ce18b2
@ -458,6 +458,7 @@ private:
|
||||
tensorflow::GraphDef netTxt;
|
||||
|
||||
std::vector<String> netInputsNames;
|
||||
std::vector<MatShape> netInputShapes;
|
||||
};
|
||||
|
||||
TFImporter::TFImporter(const char *model, const char *config)
|
||||
@ -1401,6 +1402,27 @@ void TFImporter::populateNet(Net dstNet)
|
||||
netInputsNames.push_back(name);
|
||||
layer_id[name] = 0;
|
||||
}
|
||||
if (hasLayerAttr(layer, "shape"))
|
||||
{
|
||||
const tensorflow::TensorShapeProto& shape = getLayerAttr(layer, "shape").shape();
|
||||
MatShape dims(shape.dim_size());
|
||||
for (int i = 0; i < dims.size(); ++i)
|
||||
dims[i] = shape.dim(i).size();
|
||||
if (dims.size() == 4 && predictedLayout == DATA_LAYOUT_NHWC)
|
||||
{
|
||||
std::swap(dims[1], dims[3]); // NHWC->NCWH
|
||||
std::swap(dims[2], dims[3]); // NCWH->NCHW
|
||||
if (dims[0] == -1) // It's OK to have undetermined batch size
|
||||
dims[0] = 1;
|
||||
}
|
||||
bool hasNeg = false;
|
||||
for (int i = 0; i < dims.size() && !hasNeg; ++i)
|
||||
{
|
||||
hasNeg = dims[i] < 0;
|
||||
}
|
||||
if (!hasNeg)
|
||||
netInputShapes.push_back(dims);
|
||||
}
|
||||
}
|
||||
else if (type == "Split") {
|
||||
// TODO: determining axis index remapping by input dimensions order of input blob
|
||||
@ -1579,9 +1601,42 @@ void TFImporter::populateNet(Net dstNet)
|
||||
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Check if all the inputs have the same shape.
|
||||
bool equalInpShapes = true;
|
||||
MatShape outShape0;
|
||||
for (int ii = 0; ii < layer.input_size() && !netInputShapes.empty(); ii++)
|
||||
{
|
||||
Pin pin = parsePin(layer.input(ii));
|
||||
int inpId = layer_id.find(pin.name)->second;
|
||||
|
||||
// Get input shape
|
||||
MatShape outShape;
|
||||
std::vector<MatShape> inpShapes, outShapes;
|
||||
dstNet.getLayerShapes(netInputShapes, inpId, inpShapes, outShapes);
|
||||
CV_CheckGT(static_cast<int>(outShapes.size()), pin.blobIndex, "");
|
||||
outShape = outShapes[pin.blobIndex];
|
||||
|
||||
if (ii == 0)
|
||||
{
|
||||
outShape0 = outShape;
|
||||
}
|
||||
else if (outShape != outShape0)
|
||||
{
|
||||
equalInpShapes = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int id;
|
||||
if (equalInpShapes || netInputShapes.empty())
|
||||
{
|
||||
layerParams.set("operation", "prod");
|
||||
int id = dstNet.addLayer(name, "Eltwise", layerParams);
|
||||
id = dstNet.addLayer(name, "Eltwise", layerParams);
|
||||
}
|
||||
else
|
||||
id = dstNet.addLayer(name, "Scale", layerParams);
|
||||
|
||||
layer_id[name] = id;
|
||||
|
||||
for (int ii = 0; ii < layer.input_size(); ii++)
|
||||
|
@ -181,6 +181,13 @@ TEST_P(Test_TensorFlow_layers, eltwise)
|
||||
runTensorFlowNet("eltwise_sub");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, channel_broadcast)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
runTensorFlowNet("channel_broadcast");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, pad_and_concat)
|
||||
{
|
||||
runTensorFlowNet("pad_and_concat");
|
||||
|
@ -64,7 +64,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
||||
# Nodes that should be kept.
|
||||
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'AddV2', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
||||
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
|
||||
'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3']
|
||||
'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3', 'Mean']
|
||||
|
||||
# Node with which prefixes should be removed
|
||||
prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Concatenate/', 'Postprocessor/', 'Preprocessor/map')
|
||||
@ -235,7 +235,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
||||
# Connect input node to the first layer
|
||||
assert(graph_def.node[0].op == 'Placeholder')
|
||||
# assert(graph_def.node[1].op == 'Conv2D')
|
||||
weights = graph_def.node[1].input[0]
|
||||
weights = graph_def.node[1].input[-1]
|
||||
for i in range(len(graph_def.node[1].input)):
|
||||
graph_def.node[1].input.pop()
|
||||
graph_def.node[1].input.append(graph_def.node[0].name)
|
||||
|
Loading…
Reference in New Issue
Block a user