diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index b9fb0be624..5164fbeccb 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -546,6 +546,43 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             replaceLayerParam(layerParams, "size", "local_size");
         }
+        else if (layer_type == "InstanceNormalization")
+        {
+            if (node_proto.input_size() != 3)
+                CV_Error(Error::StsNotImplemented,
+                         "Expected input, scale, bias");
+
+            layerParams.blobs.resize(4);
+            layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1); // weightData
+            layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2); // biasData
+            layerParams.set("has_bias", true);
+            layerParams.set("has_weight", true);
+
+            // Get number of channels in input
+            int size = layerParams.blobs[2].total();
+            layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
+            layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
+
+            LayerParams mvnParams;
+            mvnParams.name = layerParams.name + "/MVN";
+            mvnParams.type = "MVN";
+            mvnParams.set("eps", layerParams.get<float>("epsilon"));
+            layerParams.erase("epsilon");
+
+            //Create MVN layer
+            int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
+            //Connect to input
+            layerId = layer_id.find(node_proto.input(0));
+            CV_Assert(layerId != layer_id.end());
+            dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
+            //Add shape
+            layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
+            outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
+
+            //Replace Batch Norm's input to MVN
+            node_proto.set_input(0, mvnParams.name);
+            layerParams.type = "BatchNorm";
+        }
         else if (layer_type == "BatchNormalization")
         {
             if (node_proto.input_size() != 5)
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index 186239494f..fc38a2378c 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -76,6 +76,14 @@ public:
     }
 };
 
+TEST_P(Test_ONNX_layers, InstanceNorm)
+{
+    if (target == DNN_TARGET_MYRIAD)
+        testONNXModels("instancenorm", npy, 0, 0, false, false);
+    else
+        testONNXModels("instancenorm", npy);
+}
+
 TEST_P(Test_ONNX_layers, MaxPooling)
 {
     testONNXModels("maxpooling");
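
Note (not part of the patch): ONNX InstanceNormalization computes y = scale * (x - mean) / sqrt(variance + epsilon) + B, with mean and variance taken over the spatial dimensions of each channel of each sample. The patch realizes this as an MVN layer, which performs the zero-mean/unit-variance step using the node's epsilon, followed by a BatchNorm layer whose stored mean and std blobs are fixed to zeros and ones, so it only applies the per-channel scale and B. Below is a minimal usage sketch of the new import path; the file name instancenorm.onnx and the input dimensions are hypothetical and must match whatever model is actually loaded.

    // Sketch: load an ONNX graph containing an InstanceNormalization node
    // and run a forward pass through the imported MVN + BatchNorm pair.
    #include <opencv2/core.hpp>
    #include <opencv2/dnn.hpp>

    int main()
    {
        cv::dnn::Net net = cv::dnn::readNetFromONNX("instancenorm.onnx");

        // Build a 1x3x5x5 NCHW input blob from a random 3-channel image;
        // the channel count has to match the model's scale/bias length.
        cv::Mat img(5, 5, CV_8UC3);
        cv::randu(img, cv::Scalar::all(0), cv::Scalar::all(255));
        cv::Mat blob = cv::dnn::blobFromImage(img, 1.0 / 255.0);

        net.setInput(blob);
        cv::Mat out = net.forward();
        CV_Assert(out.size == blob.size); // instance norm preserves the input shape
        return 0;
    }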