mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 19:50:38 +08:00
Merge pull request #14858 from dvd42:instancenorm_onnx
Instancenorm onnx (#14858) * Onnx unsupported operation handling * instance norm implementation * Revert "Onnx unsupported operation handling" * instance norm layer test * onnx instancenorm layer
This commit is contained in:
parent
3200fe0e53
commit
57fae4a6a1
@ -546,6 +546,43 @@ void ONNXImporter::populateNet(Net dstNet)
|
|||||||
{
|
{
|
||||||
replaceLayerParam(layerParams, "size", "local_size");
|
replaceLayerParam(layerParams, "size", "local_size");
|
||||||
}
|
}
|
||||||
|
else if (layer_type == "InstanceNormalization")
|
||||||
|
{
|
||||||
|
if (node_proto.input_size() != 3)
|
||||||
|
CV_Error(Error::StsNotImplemented,
|
||||||
|
"Expected input, scale, bias");
|
||||||
|
|
||||||
|
layerParams.blobs.resize(4);
|
||||||
|
layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1); // weightData
|
||||||
|
layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2); // biasData
|
||||||
|
layerParams.set("has_bias", true);
|
||||||
|
layerParams.set("has_weight", true);
|
||||||
|
|
||||||
|
// Get number of channels in input
|
||||||
|
int size = layerParams.blobs[2].total();
|
||||||
|
layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
|
||||||
|
layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
|
||||||
|
|
||||||
|
LayerParams mvnParams;
|
||||||
|
mvnParams.name = layerParams.name + "/MVN";
|
||||||
|
mvnParams.type = "MVN";
|
||||||
|
mvnParams.set("eps", layerParams.get<float>("epsilon"));
|
||||||
|
layerParams.erase("epsilon");
|
||||||
|
|
||||||
|
//Create MVN layer
|
||||||
|
int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
|
||||||
|
//Connect to input
|
||||||
|
layerId = layer_id.find(node_proto.input(0));
|
||||||
|
CV_Assert(layerId != layer_id.end());
|
||||||
|
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
|
||||||
|
//Add shape
|
||||||
|
layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
|
||||||
|
outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
|
||||||
|
|
||||||
|
//Replace Batch Norm's input to MVN
|
||||||
|
node_proto.set_input(0, mvnParams.name);
|
||||||
|
layerParams.type = "BatchNorm";
|
||||||
|
}
|
||||||
else if (layer_type == "BatchNormalization")
|
else if (layer_type == "BatchNormalization")
|
||||||
{
|
{
|
||||||
if (node_proto.input_size() != 5)
|
if (node_proto.input_size() != 5)
|
||||||
|
@ -76,6 +76,14 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_layers, InstanceNorm)
|
||||||
|
{
|
||||||
|
if (target == DNN_TARGET_MYRIAD)
|
||||||
|
testONNXModels("instancenorm", npy, 0, 0, false, false);
|
||||||
|
else
|
||||||
|
testONNXModels("instancenorm", npy);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, MaxPooling)
|
TEST_P(Test_ONNX_layers, MaxPooling)
|
||||||
{
|
{
|
||||||
testONNXModels("maxpooling");
|
testONNXModels("maxpooling");
|
||||||
|
Loading…
Reference in New Issue
Block a user