Merge pull request #20453 from rogday:onnx_importer_fix

Split layer dispatch into functions in ONNXImporter

* split layer dispatch into functions

* fixes

* indentation and comment fixes

* fix constness
This commit is contained in:
rogday 2021-07-28 18:06:24 +03:00 committed by GitHub
parent d83901e665
commit cff0168f3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 1724 additions and 1517 deletions

View File

@ -14,6 +14,7 @@ Mutex& getInitializationMutex();
void initializeLayerFactory();
namespace detail {
// Binds a pointer-to-member-function to an object so the result can be called
// directly: CALL_MEMBER_FN(obj, &Type::method)(args...).
// Both arguments are parenthesized so that arbitrary expressions (e.g. *this,
// iter->second) expand with the intended precedence.
#define CALL_MEMBER_FN(instance, memberFnPtr) ((instance).*(memberFnPtr))
struct NetImplBase
{

View File

@ -62,7 +62,7 @@ class ONNXImporter
public:
ONNXImporter(Net& net, const char *onnxFile)
: dstNet(net)
: dstNet(net), dispatch(buildDispatchMap())
{
hasDynamicShapes = false;
CV_Assert(onnxFile);
@ -83,7 +83,7 @@ public:
}
ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer)
: dstNet(net)
: dstNet(net), dispatch(buildDispatchMap())
{
hasDynamicShapes = false;
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing in-memory ONNX model (" << sizeBuffer << " bytes)");
@ -124,6 +124,57 @@ protected:
typedef std::map<std::string, LayerInfo>::iterator IterLayerId_t;
void handleNode(const opencv_onnx::NodeProto& node_proto);
private:
typedef void (ONNXImporter::*ONNXImporterNodeParser)(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
typedef std::map<std::string, ONNXImporterNodeParser> DispatchMap;
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSlice (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSplit (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseBias (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePow (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseNeg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConstant (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLSTM (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseImageScaler (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseClip (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLeakyRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseElu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTanh (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLRN (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseBatchNormalization (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseGemm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMatMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConv (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConvTranspose (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTranspose (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSqueeze (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseFlatten (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseUnsqueeze (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseExpand (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseReshape (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePad (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseShape (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCast (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConstantFill (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseGather (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseResize (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseUpsample (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSoftMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseDetectionOutput (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCustom (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
const DispatchMap dispatch;
static const DispatchMap buildDispatchMap();
};
inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
@ -448,13 +499,11 @@ void ONNXImporter::populateNet()
CV_LOG_DEBUG(NULL, "DNN/ONNX: import completed!");
}
void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
{
opencv_onnx::NodeProto node_proto = node_proto_; // TODO FIXIT
CV_Assert(node_proto.output_size() >= 1);
std::string name = node_proto.output(0);
std::string layer_type = node_proto.op_type();
const std::string& layer_type = node_proto.op_type();
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
@ -468,22 +517,56 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.type = layer_type;
layerParams.set("has_dynamic_shapes", hasDynamicShapes);
if (layer_type == "MaxPool")
DispatchMap::const_iterator iter = dispatch.find(layer_type);
if (iter != dispatch.end())
{
CALL_MEMBER_FN(*this, iter->second)(layerParams, node_proto);
}
else
{
parseCustom(layerParams, node_proto);
}
}
catch (const cv::Exception& e)
{
CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
for (int i = 0; i < node_proto.input_size(); i++)
{
CV_LOG_INFO(NULL, " Input[" << i << "] = '" << node_proto.input(i) << "'");
}
for (int i = 0; i < node_proto.output_size(); i++)
{
CV_LOG_INFO(NULL, " Output[" << i << "] = '" << node_proto.output(i) << "'");
}
CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
}
}
void ONNXImporter::parseMaxPool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Pooling";
layerParams.set("pool", "MAX");
layerParams.set("ceil_mode", layerParams.has("pad_mode"));
}
else if (layer_type == "AveragePool")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseAveragePool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Pooling";
layerParams.set("pool", "AVE");
layerParams.set("ceil_mode", layerParams.has("pad_mode"));
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
}
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
{
addLayer(layerParams, node_proto);
}
// "GlobalAveragePool" "GlobalMaxPool" "ReduceMean" "ReduceSum" "ReduceMax"
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
CV_Assert(node_proto.input_size() == 1);
layerParams.type = "Pooling";
String pool;
@ -635,9 +718,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, layerParams.name);
}
}
else if (layer_type == "Slice")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int axis = 0;
std::vector<int> begin;
std::vector<int> end;
@ -744,9 +829,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, sliced[0]);
return;
}
}
else if (layer_type == "Split")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (layerParams.has("split"))
{
DictValue splits = layerParams.get("split");
@ -765,9 +852,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("num_split", node_proto.output_size());
}
layerParams.type = "Slice";
}
else if (layer_type == "Add" || layer_type == "Sum" || layer_type == "Sub")
{
addLayer(layerParams, node_proto);
}
// "Add" "Sum" "Sub"
void ONNXImporter::parseBias(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
bool isSub = layer_type == "Sub";
CV_CheckEQ(node_proto.input_size(), 2, "");
bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
@ -859,9 +951,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.type = "Scale";
layerParams.set("bias_term", true);
}
}
else if (layer_type == "Pow")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parsePow(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (layer_id.find(node_proto.input(1)) != layer_id.end())
CV_Error(Error::StsNotImplemented, "Unsupported Pow op with variable power");
@ -872,26 +966,33 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
blob.convertTo(blob, CV_32F);
layerParams.type = "Power";
layerParams.set("power", blob.ptr<float>()[0]);
}
else if (layer_type == "Max")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseMax(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Eltwise";
layerParams.set("operation", "max");
}
else if (layer_type == "Neg")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseNeg(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Power";
layerParams.set("scale", -1);
}
else if (layer_type == "Constant")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseConstant(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 0);
CV_Assert(layerParams.blobs.size() == 1);
addConstant(layerParams.name, layerParams.blobs[0]);
return;
}
else if (layer_type == "LSTM")
{
}
void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
LayerParams lstmParams = layerParams;
lstmParams.name += "/lstm";
@ -958,9 +1059,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
node_proto.set_input(0, lstmParams.name); // redirect input to LSTM
node_proto.set_output(0, layerParams.name); // keep origin LSTM's name
}
else if (layer_type == "ImageScaler")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseImageScaler(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
layerParams.erase("scale");
@ -982,42 +1085,58 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("scale", scale);
layerParams.type = "Power";
}
}
else if (layer_type == "Clip")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseClip(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ReLU6";
replaceLayerParam(layerParams, "min", "min_value");
replaceLayerParam(layerParams, "max", "max_value");
addLayer(layerParams, node_proto);
}
}
else if (layer_type == "LeakyRelu")
{
void ONNXImporter::parseLeakyRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ReLU";
replaceLayerParam(layerParams, "alpha", "negative_slope");
}
else if (layer_type == "Relu")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ReLU";
}
else if (layer_type == "Elu")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseElu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ELU";
}
else if (layer_type == "Tanh")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseTanh(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "TanH";
}
else if (layer_type == "PRelu")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parsePRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "PReLU";
layerParams.blobs.push_back(getBlob(node_proto, 1));
}
else if (layer_type == "LRN")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseLRN(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
replaceLayerParam(layerParams, "size", "local_size");
}
else if (layer_type == "InstanceNormalization")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
if (node_proto.input_size() != 3)
CV_Error(Error::StsNotImplemented,
"Expected input, scale, bias");
@ -1052,9 +1171,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
//Replace Batch Norm's input to MVN
node_proto.set_input(0, mvnParams.name);
layerParams.type = "BatchNorm";
}
else if (layer_type == "BatchNormalization")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (node_proto.input_size() != 5)
CV_Error(Error::StsNotImplemented,
"Expected input, scale, bias, mean and var");
@ -1082,9 +1203,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} else {
layerParams.set("has_bias", false);
}
}
else if (layer_type == "Gemm")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "InnerProduct";
Mat weights = getBlob(node_proto, 1);
@ -1115,9 +1238,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
layerParams.set("bias_term", node_proto.input_size() == 3);
}
else if (layer_type == "MatMul")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 2);
layerParams.type = "InnerProduct";
layerParams.set("bias_term", false);
@ -1135,9 +1260,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
secondInpDims = outShapes[node_proto.input(1)].size();
}
layerParams.set("axis", firstInpDims - secondInpDims + 1);
}
else if (layer_type == "Mul" || layer_type == "Div")
{
addLayer(layerParams, node_proto);
}
// "Mul" "Div"
void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
CV_Assert(node_proto.input_size() == 2);
bool isDiv = layer_type == "Div";
@ -1255,9 +1385,12 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
layerParams.type = "Scale";
}
}
else if (layer_type == "Conv")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseConv(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "Convolution";
for (int j = 1; j < node_proto.input_size(); j++) {
@ -1307,9 +1440,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
node_proto.set_input(0, padLp.name);
}
}
}
else if (layer_type == "ConvTranspose")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseConvTranspose(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "Deconvolution";
for (int j = 1; j < node_proto.input_size(); j++) {
@ -1350,9 +1485,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
{
replaceLayerParam(layerParams, "output_padding", "adj");
}
}
else if (layer_type == "Transpose")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseTranspose(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Permute";
replaceLayerParam(layerParams, "perm", "order");
@ -1365,9 +1502,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, transposed[0]);
return;
}
}
else if (layer_type == "Squeeze")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
DictValue axes_dict = layerParams.get("axes");
MatShape inpShape = outShapes[node_proto.input(0)];
@ -1415,9 +1554,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, out);
return;
}
}
else if (layer_type == "Flatten")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_CheckEQ(node_proto.input_size(), 1, "");
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
@ -1430,9 +1571,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, output);
return;
}
}
else if (layer_type == "Unsqueeze")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 1);
DictValue axes = layerParams.get("axes");
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
@ -1478,9 +1621,12 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
}
}
else if (layer_type == "Expand")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 2, "");
const std::string& input0 = node_proto.input(0);
const std::string& input1 = node_proto.input(1);
@ -1602,9 +1748,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
else
CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
}
else if (layer_type == "Reshape")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseReshape(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
if (node_proto.input_size() == 2) {
@ -1636,9 +1784,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
replaceLayerParam(layerParams, "shape", "dim");
}
}
else if (layer_type == "Pad")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parsePad(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Padding";
replaceLayerParam(layerParams, "mode", "type");
if (node_proto.input_size() == 3 || node_proto.input_size() == 2)
@ -1655,9 +1805,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("value", value.ptr<float>()[0]);
}
}
}
else if (layer_type == "Shape")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseShape(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 1);
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
@ -1669,10 +1821,10 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
shapeMat.dims = 1;
addConstant(layerParams.name, shapeMat);
return;
}
else if (layer_type == "Cast")
{
}
void ONNXImporter::parseCast(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat blob = getBlob(node_proto, 0);
@ -1697,9 +1849,12 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
else
layerParams.type = "Identity";
}
else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
{
addLayer(layerParams, node_proto);
}
// "ConstantOfShape" "ConstantFill"
void ONNXImporter::parseConstantFill(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int depth = CV_32F;
float fill_value;
if (!layerParams.blobs.empty())
@ -1718,10 +1873,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
CV_CheckGT(inpShape[i], 0, "");
Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value));
addConstant(layerParams.name, tensor);
return;
}
else if (layer_type == "Gather")
{
}
void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_Assert(node_proto.input_size() == 2);
Mat indexMat = getBlob(node_proto, 1);
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
@ -1796,9 +1952,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams = sliceLp;
}
}
}
else if (layer_type == "Concat")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseConcat(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
bool hasVariableInps = false;
for (int i = 0; i < node_proto.input_size(); ++i)
{
@ -1856,9 +2014,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
}
}
}
else if (layer_type == "Resize")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseResize(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
for (int i = 1; i < node_proto.input_size(); i++)
CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
@ -1903,9 +2063,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
}
replaceLayerParam(layerParams, "mode", "interpolation");
}
else if (layer_type == "Upsample")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseUpsample(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
//fused from Resize Subgraph
if (layerParams.has("coordinate_transformation_mode"))
{
@ -1950,14 +2112,21 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
}
}
replaceLayerParam(layerParams, "mode", "interpolation");
}
else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
{
addLayer(layerParams, node_proto);
}
// "SoftMax" "LogSoftmax"
void ONNXImporter::parseSoftMax(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
const std::string& layer_type = node_proto.op_type();
layerParams.type = "Softmax";
layerParams.set("log_softmax", layer_type == "LogSoftmax");
}
else if (layer_type == "DetectionOutput")
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseDetectionOutput(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 3, "");
if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
{
@ -1974,31 +2143,68 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
node_proto.set_input(2, constParams.name);
}
}
else
{
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseCustom(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
for (int j = 0; j < node_proto.input_size(); j++) {
if (layer_id.find(node_proto.input(j)) == layer_id.end())
layerParams.blobs.push_back(getBlob(node_proto, j));
}
}
addLayer(layerParams, node_proto);
}
catch (const cv::Exception& e)
{
CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
for (int i = 0; i < node_proto.input_size(); i++)
{
CV_LOG_INFO(NULL, " Input[" << i << "] = '" << node_proto.input(i) << "'");
}
for (int i = 0; i < node_proto.output_size(); i++)
{
CV_LOG_INFO(NULL, " Output[" << i << "] = '" << node_proto.output(i) << "'");
}
CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
}
}
// Builds the lookup table that maps a node's op_type string to the
// ONNXImporter member function that parses it. Several op types share one
// parser; each of them gets its own entry pointing at the same function.
// The returned map is used to initialize the `dispatch` member.
const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
{
    DispatchMap table;

    // Pooling and reductions
    table["MaxPool"] = &ONNXImporter::parseMaxPool;
    table["AveragePool"] = &ONNXImporter::parseAveragePool;
    table["GlobalAveragePool"] = &ONNXImporter::parseReduce;
    table["GlobalMaxPool"] = &ONNXImporter::parseReduce;
    table["ReduceMean"] = &ONNXImporter::parseReduce;
    table["ReduceSum"] = &ONNXImporter::parseReduce;
    table["ReduceMax"] = &ONNXImporter::parseReduce;

    table["Slice"] = &ONNXImporter::parseSlice;
    table["Split"] = &ONNXImporter::parseSplit;

    // Element-wise arithmetic
    table["Add"] = &ONNXImporter::parseBias;
    table["Sum"] = &ONNXImporter::parseBias;
    table["Sub"] = &ONNXImporter::parseBias;
    table["Pow"] = &ONNXImporter::parsePow;
    table["Max"] = &ONNXImporter::parseMax;
    table["Neg"] = &ONNXImporter::parseNeg;

    table["Constant"] = &ONNXImporter::parseConstant;
    table["LSTM"] = &ONNXImporter::parseLSTM;
    table["ImageScaler"] = &ONNXImporter::parseImageScaler;

    // Activations
    table["Clip"] = &ONNXImporter::parseClip;
    table["LeakyRelu"] = &ONNXImporter::parseLeakyRelu;
    table["Relu"] = &ONNXImporter::parseRelu;
    table["Elu"] = &ONNXImporter::parseElu;
    table["Tanh"] = &ONNXImporter::parseTanh;
    table["PRelu"] = &ONNXImporter::parsePRelu;

    // Normalization
    table["LRN"] = &ONNXImporter::parseLRN;
    table["InstanceNormalization"] = &ONNXImporter::parseInstanceNormalization;
    table["BatchNormalization"] = &ONNXImporter::parseBatchNormalization;

    // Matrix products and convolutions
    table["Gemm"] = &ONNXImporter::parseGemm;
    table["MatMul"] = &ONNXImporter::parseMatMul;
    table["Mul"] = &ONNXImporter::parseMul;
    table["Div"] = &ONNXImporter::parseMul;
    table["Conv"] = &ONNXImporter::parseConv;
    table["ConvTranspose"] = &ONNXImporter::parseConvTranspose;

    // Shape manipulation
    table["Transpose"] = &ONNXImporter::parseTranspose;
    table["Squeeze"] = &ONNXImporter::parseSqueeze;
    table["Flatten"] = &ONNXImporter::parseFlatten;
    table["Unsqueeze"] = &ONNXImporter::parseUnsqueeze;
    table["Expand"] = &ONNXImporter::parseExpand;
    table["Reshape"] = &ONNXImporter::parseReshape;
    table["Pad"] = &ONNXImporter::parsePad;
    table["Shape"] = &ONNXImporter::parseShape;
    table["Cast"] = &ONNXImporter::parseCast;
    table["ConstantFill"] = &ONNXImporter::parseConstantFill;
    table["ConstantOfShape"] = &ONNXImporter::parseConstantFill;
    table["Gather"] = &ONNXImporter::parseGather;
    table["Concat"] = &ONNXImporter::parseConcat;
    table["Resize"] = &ONNXImporter::parseResize;
    table["Upsample"] = &ONNXImporter::parseUpsample;

    // NOTE(review): this key is spelled "SoftMax" (capital M), while the
    // sibling key is "LogSoftmax"; confirm which spelling the ONNX op_type
    // actually uses before relying on this entry being hit.
    table["SoftMax"] = &ONNXImporter::parseSoftMax;
    table["LogSoftmax"] = &ONNXImporter::parseSoftMax;

    table["DetectionOutput"] = &ONNXImporter::parseDetectionOutput;

    // Only matches a node whose op_type is literally "Custom".
    table["Custom"] = &ONNXImporter::parseCustom;

    return table;
}
Net readNetFromONNX(const String& onnxFile)

View File

@ -2869,7 +2869,7 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer)
DispatchMap::const_iterator iter = dispatch.find(type);
if (iter != dispatch.end())
{
((*this).*(iter->second))(net, layer, layerParams);
CALL_MEMBER_FN(*this, iter->second)(net, layer, layerParams);
}
else
{