mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 14:13:15 +08:00
Merge pull request #13497 from dkurt:dnn_torch_bn_train
This commit is contained in:
commit
14633bc857
@ -46,9 +46,9 @@
|
||||
#include <opencv2/core.hpp>
|
||||
|
||||
#if !defined CV_DOXYGEN && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
|
||||
#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v10 {
|
||||
#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v11 {
|
||||
#define CV__DNN_EXPERIMENTAL_NS_END }
|
||||
namespace cv { namespace dnn { namespace experimental_dnn_34_v10 { } using namespace experimental_dnn_34_v10; }}
|
||||
namespace cv { namespace dnn { namespace experimental_dnn_34_v11 { } using namespace experimental_dnn_34_v11; }}
|
||||
#else
|
||||
#define CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
#define CV__DNN_EXPERIMENTAL_NS_END
|
||||
@ -754,6 +754,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
* @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
|
||||
* @param model path to the file, dumped from Torch by using torch.save() function.
|
||||
* @param isBinary specifies whether the network was serialized in ascii mode or binary.
|
||||
* @param evaluate specifies testing phase of network. If true, it's similar to evaluate() method in Torch.
|
||||
* @returns Net object.
|
||||
*
|
||||
* @note Ascii mode of Torch serializer is more preferable, because binary mode extensively use `long` type of C language,
|
||||
@ -775,7 +776,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
*
|
||||
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
|
||||
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true, bool evaluate = true);
|
||||
|
||||
/**
|
||||
* @brief Read deep learning network represented in one of the supported formats.
|
||||
|
@ -129,13 +129,15 @@ struct TorchImporter
|
||||
Module *rootModule;
|
||||
Module *curModule;
|
||||
int moduleCounter;
|
||||
bool testPhase;
|
||||
|
||||
TorchImporter(String filename, bool isBinary)
|
||||
TorchImporter(String filename, bool isBinary, bool evaluate)
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
|
||||
rootModule = curModule = NULL;
|
||||
moduleCounter = 0;
|
||||
testPhase = evaluate;
|
||||
|
||||
file = cv::Ptr<THFile>(THDiskFile_new(filename, "r", 0), THFile_free);
|
||||
CV_Assert(file && THFile_isOpened(file));
|
||||
@ -680,7 +682,8 @@ struct TorchImporter
|
||||
layerParams.blobs.push_back(tensorParams["bias"].second);
|
||||
}
|
||||
|
||||
if (nnName == "InstanceNormalization")
|
||||
bool trainPhase = scalarParams.get<bool>("train", false);
|
||||
if (nnName == "InstanceNormalization" || (trainPhase && !testPhase))
|
||||
{
|
||||
cv::Ptr<Module> mvnModule(new Module(nnName));
|
||||
mvnModule->apiType = "MVN";
|
||||
@ -1243,18 +1246,18 @@ struct TorchImporter
|
||||
|
||||
Mat readTorchBlob(const String &filename, bool isBinary)
|
||||
{
|
||||
TorchImporter importer(filename, isBinary);
|
||||
TorchImporter importer(filename, isBinary, true);
|
||||
importer.readObject();
|
||||
CV_Assert(importer.tensors.size() == 1);
|
||||
|
||||
return importer.tensors.begin()->second;
|
||||
}
|
||||
|
||||
Net readNetFromTorch(const String &model, bool isBinary)
|
||||
Net readNetFromTorch(const String &model, bool isBinary, bool evaluate)
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
|
||||
TorchImporter importer(model, isBinary);
|
||||
TorchImporter importer(model, isBinary, evaluate);
|
||||
Net net;
|
||||
importer.populateNet(net);
|
||||
return net;
|
||||
|
@ -73,7 +73,7 @@ class Test_Torch_layers : public DNNTestLayer
|
||||
{
|
||||
public:
|
||||
void runTorchNet(const String& prefix, String outLayerName = "",
|
||||
bool check2ndBlob = false, bool isBinary = false,
|
||||
bool check2ndBlob = false, bool isBinary = false, bool evaluate = true,
|
||||
double l1 = 0.0, double lInf = 0.0)
|
||||
{
|
||||
String suffix = (isBinary) ? ".dat" : ".txt";
|
||||
@ -84,7 +84,7 @@ public:
|
||||
|
||||
checkBackend(backend, target, &inp, &outRef);
|
||||
|
||||
Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary);
|
||||
Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary, evaluate);
|
||||
ASSERT_FALSE(net.empty());
|
||||
|
||||
net.setPreferableBackend(backend);
|
||||
@ -114,7 +114,7 @@ TEST_P(Test_Torch_layers, run_convolution)
|
||||
// Output reference values are in range [23.4018, 72.0181]
|
||||
double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.08 : default_l1;
|
||||
double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.42 : default_lInf;
|
||||
runTorchNet("net_conv", "", false, true, l1, lInf);
|
||||
runTorchNet("net_conv", "", false, true, true, l1, lInf);
|
||||
}
|
||||
|
||||
TEST_P(Test_Torch_layers, run_pool_max)
|
||||
@ -147,7 +147,7 @@ TEST_P(Test_Torch_layers, run_reshape)
|
||||
TEST_P(Test_Torch_layers, run_reshape_single_sample)
|
||||
{
|
||||
// Reference output values in range [14.4586, 18.4492].
|
||||
runTorchNet("net_reshape_single_sample", "", false, false,
|
||||
runTorchNet("net_reshape_single_sample", "", false, false, true,
|
||||
(target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.0073 : default_l1,
|
||||
(target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.025 : default_lInf);
|
||||
}
|
||||
@ -166,7 +166,7 @@ TEST_P(Test_Torch_layers, run_concat)
|
||||
|
||||
TEST_P(Test_Torch_layers, run_depth_concat)
|
||||
{
|
||||
runTorchNet("net_depth_concat", "", false, true, 0.0,
|
||||
runTorchNet("net_depth_concat", "", false, true, true, 0.0,
|
||||
target == DNN_TARGET_OPENCL_FP16 ? 0.021 : 0.0);
|
||||
}
|
||||
|
||||
@ -182,6 +182,7 @@ TEST_P(Test_Torch_layers, run_deconv)
|
||||
TEST_P(Test_Torch_layers, run_batch_norm)
|
||||
{
|
||||
runTorchNet("net_batch_norm", "", false, true);
|
||||
runTorchNet("net_batch_norm_train", "", false, true, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_Torch_layers, net_prelu)
|
||||
@ -216,7 +217,7 @@ TEST_P(Test_Torch_layers, net_conv_gemm_lrn)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_MYRIAD)
|
||||
throw SkipTestException("");
|
||||
runTorchNet("net_conv_gemm_lrn", "", false, true,
|
||||
runTorchNet("net_conv_gemm_lrn", "", false, true, true,
|
||||
target == DNN_TARGET_OPENCL_FP16 ? 0.046 : 0.0,
|
||||
target == DNN_TARGET_OPENCL_FP16 ? 0.023 : 0.0);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user