mirror of
https://github.com/opencv/opencv.git
synced 2024-11-29 13:47:32 +08:00
Merge pull request #14858 from dvd42:instancenorm_onnx
Instancenorm onnx (#14858) * Onnx unsupported operation handling * instance norm implementation * Revert "Onnx unsupported operation handling" * instance norm layer test * onnx instancenorm layer
This commit is contained in:
parent
3200fe0e53
commit
57fae4a6a1
@ -546,6 +546,43 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
replaceLayerParam(layerParams, "size", "local_size");
|
||||
}
|
||||
else if (layer_type == "InstanceNormalization")
|
||||
{
|
||||
if (node_proto.input_size() != 3)
|
||||
CV_Error(Error::StsNotImplemented,
|
||||
"Expected input, scale, bias");
|
||||
|
||||
layerParams.blobs.resize(4);
|
||||
layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1); // weightData
|
||||
layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2); // biasData
|
||||
layerParams.set("has_bias", true);
|
||||
layerParams.set("has_weight", true);
|
||||
|
||||
// Get number of channels in input
|
||||
int size = layerParams.blobs[2].total();
|
||||
layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
|
||||
layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
|
||||
|
||||
LayerParams mvnParams;
|
||||
mvnParams.name = layerParams.name + "/MVN";
|
||||
mvnParams.type = "MVN";
|
||||
mvnParams.set("eps", layerParams.get<float>("epsilon"));
|
||||
layerParams.erase("epsilon");
|
||||
|
||||
//Create MVN layer
|
||||
int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
|
||||
//Connect to input
|
||||
layerId = layer_id.find(node_proto.input(0));
|
||||
CV_Assert(layerId != layer_id.end());
|
||||
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
|
||||
//Add shape
|
||||
layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
|
||||
outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
|
||||
|
||||
//Replace Batch Norm's input to MVN
|
||||
node_proto.set_input(0, mvnParams.name);
|
||||
layerParams.type = "BatchNorm";
|
||||
}
|
||||
else if (layer_type == "BatchNormalization")
|
||||
{
|
||||
if (node_proto.input_size() != 5)
|
||||
|
@ -76,6 +76,14 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(Test_ONNX_layers, InstanceNorm)
|
||||
{
|
||||
if (target == DNN_TARGET_MYRIAD)
|
||||
testONNXModels("instancenorm", npy, 0, 0, false, false);
|
||||
else
|
||||
testONNXModels("instancenorm", npy);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, MaxPooling)
|
||||
{
|
||||
testONNXModels("maxpooling");
|
||||
|
Loading…
Reference in New Issue
Block a user