Merge pull request #14858 from dvd42:instancenorm_onnx

Instancenorm onnx (#14858)

* Onnx unsupported operation handling

* instance norm implementation

* Revert "Onnx unsupported operation handling"

* instance norm layer test

* onnx instancenorm layer
This commit is contained in:
Diego 2019-07-04 20:15:04 +02:00 committed by Alexander Alekhin
parent 3200fe0e53
commit 57fae4a6a1
2 changed files with 45 additions and 0 deletions

View File

@ -546,6 +546,43 @@ void ONNXImporter::populateNet(Net dstNet)
{
replaceLayerParam(layerParams, "size", "local_size");
}
else if (layer_type == "InstanceNormalization")
{
if (node_proto.input_size() != 3)
CV_Error(Error::StsNotImplemented,
"Expected input, scale, bias");
layerParams.blobs.resize(4);
layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1); // weightData
layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2); // biasData
layerParams.set("has_bias", true);
layerParams.set("has_weight", true);
// Get number of channels in input
int size = layerParams.blobs[2].total();
layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
LayerParams mvnParams;
mvnParams.name = layerParams.name + "/MVN";
mvnParams.type = "MVN";
mvnParams.set("eps", layerParams.get<float>("epsilon"));
layerParams.erase("epsilon");
//Create MVN layer
int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
//Connect to input
layerId = layer_id.find(node_proto.input(0));
CV_Assert(layerId != layer_id.end());
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
//Add shape
layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
//Replace Batch Norm's input to MVN
node_proto.set_input(0, mvnParams.name);
layerParams.type = "BatchNorm";
}
else if (layer_type == "BatchNormalization")
{
if (node_proto.input_size() != 5)

View File

@ -76,6 +76,14 @@ public:
}
};
TEST_P(Test_ONNX_layers, InstanceNorm)
{
if (target == DNN_TARGET_MYRIAD)
testONNXModels("instancenorm", npy, 0, 0, false, false);
else
testONNXModels("instancenorm", npy);
}
TEST_P(Test_ONNX_layers, MaxPooling)
{
testONNXModels("maxpooling");