mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 20:09:23 +08:00
nn.BatchNormalization and nn.Dropout layers from Torch
This commit is contained in:
parent
fc9e031454
commit
bbbec300a6
@ -119,8 +119,9 @@ public:
|
||||
CV_Assert(inputs.size() == 1);
|
||||
|
||||
Mat &inpBlob = *inputs[0];
|
||||
int rows = inpBlob.size[2];
|
||||
int cols = inpBlob.size[3];
|
||||
CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
|
||||
int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
|
||||
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
|
||||
|
||||
for (size_t ii = 0; ii < outputs.size(); ii++)
|
||||
{
|
||||
|
@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
|
||||
curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
|
||||
readObject();
|
||||
}
|
||||
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization")
|
||||
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
|
||||
nnName == "BatchNormalization")
|
||||
{
|
||||
newModule->apiType = "BatchNorm";
|
||||
readTorchTable(scalarParams, tensorParams);
|
||||
@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
|
||||
|
||||
curModule->modules.push_back(newModule);
|
||||
}
|
||||
else if (nnName == "SpatialDropout")
|
||||
else if (nnName == "SpatialDropout" || nnName == "Dropout")
|
||||
{
|
||||
readTorchTable(scalarParams, tensorParams);
|
||||
CV_Assert(scalarParams.has("p"));
|
||||
|
||||
float scale = 1 - scalarParams.get<double>("p");
|
||||
if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
|
||||
{
|
||||
newModule->apiType = "Identity";
|
||||
}
|
||||
else
|
||||
{
|
||||
float scale = 1 - scalarParams.get<double>("p");
|
||||
|
||||
CV_Assert(scale > 0);
|
||||
CV_Assert(scale > 0);
|
||||
|
||||
newModule->apiType = "Power";
|
||||
layerParams.set("scale", scale);
|
||||
newModule->apiType = "Power";
|
||||
layerParams.set("scale", scale);
|
||||
}
|
||||
curModule->modules.push_back(newModule);
|
||||
}
|
||||
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style
|
||||
|
@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
|
||||
runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
|
||||
}
|
||||
|
||||
TEST(Torch_Importer, net_non_spatial)
|
||||
{
|
||||
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
|
||||
}
|
||||
|
||||
TEST(Torch_Importer, ENet_accuracy)
|
||||
{
|
||||
Net net;
|
||||
|
Loading…
Reference in New Issue
Block a user