OpenCL GPU target for Inference Engine deep learning backend
Enable FP16 GPU target for DL Inference Engine backend.
This commit is contained in: commit 709cf5d038 (parent 72cb06abf0)
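For context, the new targets plug into the existing cv::dnn selection API; a minimal usage sketch (model file names are placeholders):

    #include <opencv2/dnn.hpp>

    // Placeholder model files; any format with a readNetFrom* loader works.
    cv::dnn::Net net = cv::dnn::readNetFromCaffe("deploy.prototxt", "model.caffemodel");

    // Route execution through Intel's Inference Engine and request the GPU
    // plugin; DNN_TARGET_OPENCL keeps FP32, DNN_TARGET_OPENCL_FP16 (added by
    // this commit) converts weights to half precision.
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_INFERENCE_ENGINE);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_OPENCL_FP16);

    cv::Mat input(cv::Size(227, 227), CV_32FC3);
    cv::randu(input, 0.0f, 1.0f);
    net.setInput(cv::dnn::blobFromImage(input));
    cv::Mat out = net.forward();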
@@ -80,7 +80,8 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     enum Target
     {
         DNN_TARGET_CPU,
-        DNN_TARGET_OPENCL
+        DNN_TARGET_OPENCL,
+        DNN_TARGET_OPENCL_FP16
     };
 
     /** @brief This class provides all data needed to initialize layer.
@@ -13,7 +13,7 @@
 namespace opencv_test {
 
 CV_ENUM(DNNBackend, DNN_BACKEND_DEFAULT, DNN_BACKEND_HALIDE, DNN_BACKEND_INFERENCE_ENGINE)
-CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL)
+CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL, DNN_TARGET_OPENCL_FP16)
 
 class DNNTestNetwork : public ::perf::TestBaseWithParam< tuple<DNNBackend, DNNTarget> >
 {
@@ -41,8 +41,6 @@ public:
                 throw cvtest::SkipTestException("OpenCL is not available/disabled in OpenCV");
             }
         }
-        if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL)
-            throw SkipTestException("Skip OpenCL target of Inference Engine backend");
 
         randu(input, 0.0f, 1.0f);
 
@@ -89,24 +87,32 @@ public:
 
 PERF_TEST_P_(DNNTestNetwork, AlexNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt",
             "alexnet.yml", Mat(cv::Size(227, 227), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, GoogLeNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/bvlc_googlenet.caffemodel", "dnn/bvlc_googlenet.prototxt",
             "", Mat(cv::Size(224, 224), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, ResNet_50)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/ResNet-50-model.caffemodel", "dnn/ResNet-50-deploy.prototxt",
             "resnet_50.yml", Mat(cv::Size(224, 224), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, SqueezeNet_v1_1)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/squeezenet_v1.1.caffemodel", "dnn/squeezenet_v1.1.prototxt",
             "squeezenet_v1_1.yml", Mat(cv::Size(227, 227), CV_32FC3));
 }
@@ -135,14 +141,18 @@ PERF_TEST_P_(DNNTestNetwork, SSD)
 
 PERF_TEST_P_(DNNTestNetwork, OpenFace)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/openface_nn4.small2.v1.t7", "", "",
             Mat(cv::Size(96, 96), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/MobileNetSSD_deploy.caffemodel", "dnn/MobileNetSSD_deploy.prototxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
 }
@@ -150,7 +160,8 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
 PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_TensorFlow)
 {
     if (backend == DNN_BACKEND_DEFAULT && target == DNN_TARGET_OPENCL ||
-        backend == DNN_BACKEND_HALIDE)
+        backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
         throw SkipTestException("");
     processNet("dnn/ssd_mobilenet_v1_coco.pb", "ssd_mobilenet_v1_coco.pbtxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
@@ -158,7 +169,9 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_TensorFlow)
 
 PERF_TEST_P_(DNNTestNetwork, DenseNet_121)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/DenseNet_121.caffemodel", "dnn/DenseNet_121.prototxt", "",
             Mat(cv::Size(224, 224), CV_32FC3));
 }
@@ -189,7 +202,7 @@ PERF_TEST_P_(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
 PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
 {
     if (backend == DNN_BACKEND_HALIDE ||
-        backend == DNN_BACKEND_DEFAULT && target == DNN_TARGET_OPENCL)
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
         throw SkipTestException("");
     processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
@@ -197,7 +210,9 @@ PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
 
 PERF_TEST_P_(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/ssd_inception_v2_coco_2017_11_17.pb", "ssd_inception_v2_coco_2017_11_17.pbtxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
 }
@@ -209,6 +224,8 @@ const tuple<DNNBackend, DNNTarget> testCases[] = {
 #endif
 #ifdef HAVE_INF_ENGINE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_CPU),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL_FP16),
 #endif
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_DEFAULT, DNN_TARGET_CPU),
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_DEFAULT, DNN_TARGET_OPENCL)
@@ -1154,7 +1154,7 @@ struct Net::Impl
                 ld.skip = true;
             }
             layers[lastLayerId].skip = false;
-            ieNode->net->init();
+            ieNode->net->init(preferableTarget);
             return;
         }
 
@@ -1167,17 +1167,17 @@ struct Net::Impl
         for (it = layers.begin(); it != layers.end(); ++it)
         {
             LayerData &ld = it->second;
-            ld.skip = true; // Initially skip all Inference Engine supported layers.
-            Ptr<Layer> layer = ld.layerInstance;
+            bool fused = ld.skip && ld.id != 0;
 
+            Ptr<Layer> layer = ld.layerInstance;
             if (!layer->supportBackend(preferableBackend))
             {
                 addInfEngineNetOutputs(ld);
-                ld.skip = false;
                 net = Ptr<InfEngineBackendNet>();
                 netBlobsWrappers.clear();
                 continue;
             }
+            ld.skip = true;  // Initially skip all Inference Engine supported layers.
 
             // Create a new network if one of inputs from different Inference Engine graph.
             for (int i = 0; i < ld.inputBlobsId.size(); ++i)
@@ -1217,19 +1217,16 @@ struct Net::Impl
             }
             netBlobsWrappers[ld.id] = ld.outputBlobsWrappers[0];
 
-            bool fused = false;
             Ptr<BackendNode> node;
             if (!net.empty())
             {
-                // Try to fuse.
-                bool inPlace = ld.inputBlobsId.size() == 1 && ld.outputBlobs.size() == 1 &&
-                               ld.inputBlobs[0]->data == ld.outputBlobs[0].data;
-                if (inPlace)
+                if (fused)
                 {
-                    node = layer->tryAttach(layers[ld.inputBlobsId[0].lid].backendNodes[preferableBackend]);
-                    fused = !node.empty();
-                    if (fused)
-                        ld.inputBlobsWrappers = layers[ld.inputBlobsId[0].lid].inputBlobsWrappers;
+                    bool inPlace = ld.inputBlobsId.size() == 1 && ld.outputBlobs.size() == 1 &&
+                                   ld.inputBlobs[0]->data == ld.outputBlobs[0].data;
+                    CV_Assert(inPlace);
+                    node = layers[ld.inputBlobsId[0].lid].backendNodes[preferableBackend];
+                    ld.inputBlobsWrappers = layers[ld.inputBlobsId[0].lid].inputBlobsWrappers;
                 }
             }
             else
@@ -1247,6 +1244,19 @@ struct Net::Impl
             CV_Assert(!ieNode.empty());
             ieNode->net = net;
 
+            if (preferableTarget == DNN_TARGET_OPENCL_FP16 && !fused)
+            {
+                ieNode->layer->precision = InferenceEngine::Precision::FP16;
+                auto weightableLayer = std::dynamic_pointer_cast<InferenceEngine::WeightableLayer>(ieNode->layer);
+                if (weightableLayer)
+                {
+                    if (weightableLayer->_weights)
+                        weightableLayer->_weights = convertFp16(weightableLayer->_weights);
+                    if (weightableLayer->_biases)
+                        weightableLayer->_biases = convertFp16(weightableLayer->_biases);
+                }
+            }
+
             ieNode->connect(ld.inputBlobsWrappers, ld.outputBlobsWrappers);
             net->addBlobs(ld.inputBlobsWrappers);
             net->addBlobs(ld.outputBlobsWrappers);
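The block added above converts the weights and biases of each weightable layer to FP16 once, while the network is being built; the !fused guard skips layers whose parameters were already folded into a predecessor. Half precision carries roughly three decimal digits, which is why the tests later in this commit relax their l1/lInf thresholds. A small round-trip sketch with OpenCV's converter (illustration only):

    #include <opencv2/core.hpp>

    // FP32 -> FP16 -> FP32 round trip; halves live in a CV_16S matrix.
    cv::Mat floats(1, 1000, CV_32F), halfs, restored;
    cv::randu(floats, -1.0f, 1.0f);
    cv::convertFp16(floats, halfs);    // to half precision
    cv::convertFp16(halfs, restored);  // back to float
    // Worst-case error is on the order of 1e-4 for values in [-1, 1].
    double lInf = cv::norm(floats, restored, cv::NORM_INF);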
@@ -1276,7 +1286,7 @@ struct Net::Impl
 
             if (!ieNode->net->isInitialized())
             {
-                ieNode->net->init();
+                ieNode->net->init(preferableTarget);
                 ld.skip = false;
             }
         }
@@ -1380,7 +1390,8 @@ struct Net::Impl
 
     void fuseLayers(const std::vector<LayerPin>& blobsToKeep_)
     {
-        if( !fusion || preferableBackend != DNN_BACKEND_DEFAULT)
+        if( !fusion || preferableBackend != DNN_BACKEND_DEFAULT &&
+                       preferableBackend != DNN_BACKEND_INFERENCE_ENGINE)
             return;
 
         CV_TRACE_FUNCTION();
@@ -1407,7 +1418,7 @@ struct Net::Impl
             // some other layers.
 
             // TODO: OpenCL target support more fusion styles.
-            if ( preferableTarget == DNN_TARGET_OPENCL &&
+            if ( preferableBackend == DNN_BACKEND_DEFAULT && preferableTarget == DNN_TARGET_OPENCL &&
                  (!cv::ocl::useOpenCL() || (ld.layerInstance->type != "Convolution" &&
                  ld.layerInstance->type != "MVN")) )
                 continue;
@@ -1442,6 +1453,9 @@ struct Net::Impl
                     break;
                 }
 
+                if (preferableBackend != DNN_BACKEND_DEFAULT)
+                    continue;  // Go to the next layer.
+
                 // For now, OpenCL target support fusion with activation of ReLU/ChannelsPReLU/Power/Tanh
                 if ( preferableTarget != DNN_TARGET_OPENCL ||
                      (preferableTarget == DNN_TARGET_OPENCL &&
@@ -1583,6 +1597,9 @@ struct Net::Impl
                 }
             }
 
+            if (preferableBackend != DNN_BACKEND_DEFAULT)
+                continue;  // Go to the next layer.
+
             // the optimization #2. if there is no layer that takes max pooling layer's computed
             // max indices (and only some semantical segmentation networks might need this;
             // many others only take the maximum values), then we switch the max pooling
@@ -234,19 +234,6 @@ public:
 #endif  // HAVE_HALIDE
             break;
         }
-        case DNN_BACKEND_INFERENCE_ENGINE:
-        {
-#ifdef HAVE_INF_ENGINE
-            auto base = node.dynamicCast<InfEngineBackendNode>();
-            auto conv = std::dynamic_pointer_cast<InferenceEngine::ConvolutionLayer>(base->layer);
-            if (conv)
-            {
-                fuseConvWeights(conv, weights_, bias_);
-                return base;
-            }
-#endif  // HAVE_INF_ENGINE
-            break;
-        }
         }
         return Ptr<BackendNode>();
     }
@@ -287,8 +274,9 @@ public:
         lp.precision = InferenceEngine::Precision::FP32;
         std::shared_ptr<InferenceEngine::ScaleShiftLayer> ieLayer(new InferenceEngine::ScaleShiftLayer(lp));
 
-        ieLayer->_weights = wrapToInfEngineBlob(weights_);
-        ieLayer->_biases = wrapToInfEngineBlob(bias_);
+        const int numChannels = weights_.total();
+        ieLayer->_weights = wrapToInfEngineBlob(weights_, {numChannels}, InferenceEngine::Layout::C);
+        ieLayer->_biases = wrapToInfEngineBlob(bias_, {numChannels}, InferenceEngine::Layout::C);
 
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
@@ -173,21 +173,21 @@ public:
     std::vector<float> biasvec;
     std::vector<float> reluslope;
     Ptr<ActivationLayer> activ;
+    bool newWeightAndBias;
+    bool fusedBias;
 
 #ifdef HAVE_OPENCL
     Ptr<OCL4DNNConvSpatial<float> > convolutionOp;
     std::vector<UMat> umat_blobs;
-    bool fusedBias;
-    bool newWeightAndBias;
     bool newActiv;
     ocl4dnnFusedActiv_t activType;
     float power;
 #endif
     ConvolutionLayerImpl(const LayerParams &params) : BaseConvolutionLayerImpl(params)
     {
-#ifdef HAVE_OPENCL
-        fusedBias = false;
         newWeightAndBias = false;
+        fusedBias = false;
+#ifdef HAVE_OPENCL
         newActiv = false;
         activType = OCL4DNN_CONV_FUSED_ACTIV_NONE;
         power = 0.f;
@@ -350,10 +350,8 @@ public:
                 biasvec[i] += b.at<float>(i);
         }
 
-#ifdef HAVE_OPENCL
         newWeightAndBias = !w.empty() || !b.empty();
         fusedBias = hasBias() || !b.empty();
-#endif
         biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
     }
 
@@ -433,9 +431,31 @@ public:
         ieLayer->_dilation_y = dilation.height;
         ieLayer->_group = group;
 
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0]);
-        if (hasBias())
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1]);
+        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], InferenceEngine::Layout::OIHW);
+        if (newWeightAndBias)
+        {
+            if (weightsMat.isContinuous())
+            {
+                Mat fusedWeights = weightsMat.reshape(1, blobs[0].dims, blobs[0].size);
+                ieLayer->_weights = wrapToInfEngineBlob(fusedWeights, InferenceEngine::Layout::OIHW);
+            }
+            else
+            {
+                ieLayer->_weights = InferenceEngine::make_shared_blob<float>(
+                                    InferenceEngine::Precision::FP32, InferenceEngine::Layout::OIHW,
+                                    ieLayer->_weights->dims());
+                ieLayer->_weights->allocate();
+
+                Mat newWeights = infEngineBlobToMat(ieLayer->_weights).reshape(1, outCn);
+                Mat fusedWeights = weightsMat.colRange(0, newWeights.cols);
+                fusedWeights.copyTo(newWeights);
+            }
+        }
+        if (hasBias() || fusedBias)
+        {
+            Mat biasesMat({outCn}, CV_32F, &biasvec[0]);
+            ieLayer->_biases = wrapToInfEngineBlob(biasesMat, {outCn}, InferenceEngine::Layout::C);
+        }
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
         return Ptr<BackendNode>();
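On the continuous path above, the fused 2-D weightsMat can be viewed as the original 4-D OIHW blob without copying; the else-branch copies instead because OpenCV may pad rows for alignment. A sketch of the zero-copy view (shapes are illustrative):

    // 64 output channels, 3 input channels, 3x3 kernels.
    cv::Mat weights2d(64, 3 * 3 * 3, CV_32F);
    cv::randu(weights2d, -1.0f, 1.0f);

    int sizes[] = {64, 3, 3, 3};
    cv::Mat weights4d = weights2d.reshape(1, 4, sizes);  // OIHW view, same data
    CV_Assert(weights4d.data == weights2d.data);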
@@ -412,9 +412,9 @@ public:
         std::shared_ptr<InferenceEngine::FullyConnectedLayer> ieLayer(new InferenceEngine::FullyConnectedLayer(lp));
 
         ieLayer->_out_num = blobs[0].size[0];
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0]);
+        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], {blobs[0].size[0], blobs[0].size[1], 1, 1}, InferenceEngine::Layout::OIHW);
         if (blobs.size() > 1)
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1]);
+            ieLayer->_biases = wrapToInfEngineBlob(blobs[1], {ieLayer->_out_num}, InferenceEngine::Layout::C);
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
         return Ptr<BackendNode>();
@@ -132,20 +132,6 @@ public:
 #endif  // HAVE_HALIDE
             break;
         }
-        case DNN_BACKEND_INFERENCE_ENGINE:
-        {
-#ifdef HAVE_INF_ENGINE
-            auto base = node.dynamicCast<InfEngineBackendNode>();
-            auto conv = std::dynamic_pointer_cast<InferenceEngine::ConvolutionLayer>(base->layer);
-            if (conv)
-            {
-                Mat bias = hasBias ? blobs[1] : Mat();
-                fuseConvWeights(conv, blobs[0], bias);
-                return base;
-            }
-#endif  // HAVE_INF_ENGINE
-            break;
-        }
         }
         return Ptr<BackendNode>();
     }
@@ -192,9 +178,10 @@ public:
         lp.precision = InferenceEngine::Precision::FP32;
         std::shared_ptr<InferenceEngine::ScaleShiftLayer> ieLayer(new InferenceEngine::ScaleShiftLayer(lp));
 
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0]);
+        const int numChannels = blobs[0].total();
+        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], {numChannels}, InferenceEngine::Layout::C);
         if (hasBias)
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1]);
+            ieLayer->_biases = wrapToInfEngineBlob(blobs[1], {numChannels}, InferenceEngine::Layout::C);
 
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
@@ -90,27 +90,6 @@ public:
         }
     }
 
-    virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node) CV_OVERRIDE
-    {
-        switch (node->backendId)
-        {
-            case DNN_BACKEND_INFERENCE_ENGINE:
-            {
-#ifdef HAVE_INF_ENGINE
-                auto base = node.dynamicCast<InfEngineBackendNode>();
-                auto conv = std::dynamic_pointer_cast<InferenceEngine::ConvolutionLayer>(base->layer);
-                if (conv)
-                {
-                    fuseConvWeights(conv, Mat(), blobs[0]);
-                    return base;
-                }
-#endif  // HAVE_INF_ENGINE
-                break;
-            }
-        }
-        return Ptr<BackendNode>();
-    }
-
     virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE
     {
 #ifdef HAVE_INF_ENGINE
@@ -59,22 +59,22 @@ static InferenceEngine::DataPtr wrapToInfEngineDataNode(const Mat& m, const std:
     std::vector<size_t> reversedShape(&m.size[0], &m.size[0] + m.dims);
     std::reverse(reversedShape.begin(), reversedShape.end());
     return InferenceEngine::DataPtr(
-        new InferenceEngine::Data(name, reversedShape, InferenceEngine::Precision::FP32,
-                                  InferenceEngine::Layout::ANY)
+        new InferenceEngine::Data(name, reversedShape, InferenceEngine::Precision::FP32)
     );
 }
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape)
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape,
+                                                       InferenceEngine::Layout layout)
 {
     return InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32,
-                                                    shape, (float*)m.data);
+                                                    layout, shape, (float*)m.data);
 }
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m)
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, InferenceEngine::Layout layout)
 {
     std::vector<size_t> reversedShape(&m.size[0], &m.size[0] + m.dims);
     std::reverse(reversedShape.begin(), reversedShape.end());
-    return wrapToInfEngineBlob(m, reversedShape);
+    return wrapToInfEngineBlob(m, reversedShape, layout);
 }
 
 InferenceEngine::DataPtr infEngineDataNode(const Ptr<BackendWrapper>& ptr)
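With the layout argument threaded through, callers now tag each blob explicitly; elsewhere in this commit convolution and fully-connected weights use OIHW while per-channel scales and biases use C. A hedged usage sketch (shapes are illustrative):

    int wsize[] = {64, 3, 3, 3};
    cv::Mat weights(4, wsize, CV_32F);   // conv weights in OIHW order
    auto wBlob = wrapToInfEngineBlob(weights, InferenceEngine::Layout::OIHW);

    cv::Mat scales(1, 64, CV_32F);       // one value per channel
    auto sBlob = wrapToInfEngineBlob(scales, {64}, InferenceEngine::Layout::C);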
@@ -109,10 +109,14 @@ void InfEngineBackendWrapper::setHostDirty()
 
 InfEngineBackendNet::InfEngineBackendNet()
 {
+    targetDevice = InferenceEngine::TargetDevice::eCPU;
+    precision = InferenceEngine::Precision::FP32;
 }
 
 InfEngineBackendNet::InfEngineBackendNet(InferenceEngine::CNNNetwork& net)
 {
+    targetDevice = InferenceEngine::TargetDevice::eCPU;
+    precision = InferenceEngine::Precision::FP32;
     inputs = net.getInputsInfo();
     outputs = net.getOutputsInfo();
     layers.resize(net.layerCount());  // A hack to execute InfEngineBackendNet::layerCount correctly.
@@ -126,9 +130,14 @@ void InfEngineBackendNet::Release() noexcept
     outputs.clear();
 }
 
+void InfEngineBackendNet::setPrecision(InferenceEngine::Precision p) noexcept
+{
+    precision = p;
+}
+
 InferenceEngine::Precision InfEngineBackendNet::getPrecision() noexcept
 {
-    return InferenceEngine::Precision::FP32;
+    return precision;
 }
 
 // Assume that outputs of network is unconnected blobs.
@@ -161,9 +170,8 @@ InferenceEngine::InputInfo::Ptr InfEngineBackendNet::getInput(const std::string
     return it->second;
 }
 
-void InfEngineBackendNet::getName(char *pName, size_t len) noexcept
+void InfEngineBackendNet::getName(char*, size_t) noexcept
 {
     CV_Error(Error::StsNotImplemented, "");
 }
 
 size_t InfEngineBackendNet::layerCount() noexcept
@@ -213,13 +221,15 @@ InfEngineBackendNet::getLayerByName(const char *layerName, InferenceEngine::CNNL
 
 void InfEngineBackendNet::setTargetDevice(InferenceEngine::TargetDevice device) noexcept
 {
-    if (device != InferenceEngine::TargetDevice::eCPU)
+    if (device != InferenceEngine::TargetDevice::eCPU &&
+        device != InferenceEngine::TargetDevice::eGPU)
         CV_Error(Error::StsNotImplemented, "");
+    targetDevice = device;
 }
 
 InferenceEngine::TargetDevice InfEngineBackendNet::getTargetDevice() noexcept
 {
-    return InferenceEngine::TargetDevice::eCPU;
+    return targetDevice;
 }
 
 InferenceEngine::StatusCode InfEngineBackendNet::setBatchSize(const size_t size) noexcept
@@ -234,7 +244,7 @@ size_t InfEngineBackendNet::getBatchSize() const noexcept
     return 0;
 }
 
-void InfEngineBackendNet::init()
+void InfEngineBackendNet::init(int targetId)
 {
     if (inputs.empty())
     {
@@ -307,6 +317,15 @@ void InfEngineBackendNet::init()
         outBlobs[it.first] = allBlobs[it.first];
     }
 
+    switch (targetId)
+    {
+    case DNN_TARGET_CPU: setTargetDevice(InferenceEngine::TargetDevice::eCPU); break;
+    case DNN_TARGET_OPENCL_FP16: setPrecision(InferenceEngine::Precision::FP16);  // Fallback to the next.
+    case DNN_TARGET_OPENCL: setTargetDevice(InferenceEngine::TargetDevice::eGPU); break;
+    default:
+        CV_Error(Error::StsError, format("Unknown target identifier: %d", targetId));
+    }
+
     if (!isInitialized())
         initPlugin(*this);
 }
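The fall-through in the switch above is deliberate: DNN_TARGET_OPENCL_FP16 first lowers the network precision and then reuses the eGPU device of the DNN_TARGET_OPENCL case. Spelled out as an explicit mapping (an illustrative helper, not part of the commit):

    struct IEConfig
    {
        InferenceEngine::TargetDevice device;
        InferenceEngine::Precision precision;
    };

    static IEConfig configFor(int targetId)
    {
        switch (targetId)
        {
        case DNN_TARGET_CPU:
            return {InferenceEngine::TargetDevice::eCPU, InferenceEngine::Precision::FP32};
        case DNN_TARGET_OPENCL:
            return {InferenceEngine::TargetDevice::eGPU, InferenceEngine::Precision::FP32};
        case DNN_TARGET_OPENCL_FP16:
            return {InferenceEngine::TargetDevice::eGPU, InferenceEngine::Precision::FP16};
        default:
            CV_Error(cv::Error::StsError, "Unknown target identifier");
        }
    }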
@@ -319,7 +338,7 @@ void InfEngineBackendNet::initPlugin(InferenceEngine::ICNNNetwork& net)
     InferenceEngine::ResponseDesc resp;
     const InferenceEngine::Version* v = InferenceEngine::GetInferenceEngineVersion();
 
-    plugin = InferenceEngine::PluginDispatcher({""}).getSuitablePlugin(InferenceEngine::TargetDevice::eCPU);
+    plugin = InferenceEngine::PluginDispatcher({""}).getSuitablePlugin(targetDevice);
     if (std::atoi(v->buildNumber) > 5855)
     {
 #ifdef _WIN32
@@ -360,7 +379,7 @@ void InfEngineBackendNet::forward()
     CV_Error(Error::StsAssert, resp.msg);
 }
 
-static inline Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob)
+Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob)
 {
     // NOTE: Inference Engine sizes are reversed.
     std::vector<size_t> dims = blob->dims();
@@ -369,56 +388,6 @@ static inline Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob)
     return Mat(size, CV_32F, (void*)blob->buffer());
 }
 
-void fuseConvWeights(const std::shared_ptr<InferenceEngine::ConvolutionLayer>& conv,
-                     const Mat& w, const Mat& b)
-{
-    CV_Assert(!w.empty() || !b.empty());
-    if (!w.empty())
-    {
-        // Get convolution's weights. Clone the data because Inference Engine can host it
-        // and conv->_weights->allocate() below will deallocate it.
-        Mat originWeights = infEngineBlobToMat(conv->_weights).clone();
-
-        // Create new weights blob.
-        conv->_weights = InferenceEngine::make_shared_blob<float>(
-                            InferenceEngine::Precision::FP32, conv->_weights->dims());
-        conv->_weights->allocate();
-
-        // Convolution weights have OIHW data layout.
-        // (conv(I) + b1 ) * w + b2
-        // w*conv(I) + b1 * w + b2
-        Mat fusedWeights = infEngineBlobToMat(conv->_weights);
-
-        const int numChannels = fusedWeights.size[0];
-        // Mat weights = blobs[0].reshape(1, 1);
-        // Mat bias = hasBias ? blobs[1].reshape(1, 1) : Mat();
-        CV_Assert(numChannels == w.total());
-        CV_Assert(b.empty() || numChannels == b.total());
-        for (int i = 0; i < numChannels; ++i)
-        {
-            cv::multiply(slice(originWeights, i), w.at<float>(i), slice(fusedWeights, i));
-        }
-    }
-    if (conv->_biases)
-    {
-        // The same for biases.
-        Mat originBiases = infEngineBlobToMat(conv->_biases).clone();
-
-        conv->_biases = InferenceEngine::make_shared_blob<float>(
-                            InferenceEngine::Precision::FP32, conv->_biases->dims());
-        conv->_biases->allocate();
-        Mat fusedBiases = infEngineBlobToMat(conv->_biases);
-        originBiases.copyTo(fusedBiases);
-
-        if (!w.empty())
-            cv::multiply(w.reshape(1, fusedBiases.dims, &fusedBiases.size[0]), fusedBiases, fusedBiases);
-        if (!b.empty())
-            cv::add(fusedBiases, b.reshape(1, fusedBiases.dims, &fusedBiases.size[0]), fusedBiases);
-    }
-    else
-        conv->_biases = wrapToInfEngineBlob(b);
-}
-
 InfEngineBackendLayer::InfEngineBackendLayer(const InferenceEngine::DataPtr& output_)
 {
     output = output_;
@@ -454,6 +423,16 @@ void InfEngineBackendLayer::forward(InputArrayOfArrays inputs, OutputArrayOfArra
     CV_Error(Error::StsInternal, "Choose Inference Engine as a preferable backend.");
 }
 
+InferenceEngine::TBlob<int16_t>::Ptr convertFp16(const InferenceEngine::Blob::Ptr& blob)
+{
+    auto halfs = InferenceEngine::make_shared_blob<int16_t>(InferenceEngine::Precision::FP16, blob->layout(), blob->dims());
+    halfs->allocate();
+    Mat floatsData(1, blob->size(), CV_32F, blob->buffer());
+    Mat halfsData(1, blob->size(), CV_16SC1, halfs->buffer());
+    convertFp16(floatsData, halfsData);
+    return halfs;
+}
+
 #endif  // HAVE_INF_ENGINE
 
 bool haveInfEngine()
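The new convertFp16 overload above allocates an FP16 blob backed by int16_t and converts in a single pass by wrapping both raw buffers in cv::Mat headers, which reference the memory without copying; the heavy lifting is done by cv::convertFp16. Wrapping external buffers works like this (illustration only):

    std::vector<float> src(16, 0.5f);
    std::vector<int16_t> dst(16);

    // Mat headers over existing memory: no allocation, no copy.
    cv::Mat srcMat(1, (int)src.size(), CV_32F, src.data());
    cv::Mat dstMat(1, (int)dst.size(), CV_16SC1, dst.data());
    cv::convertFp16(srcMat, dstMat);  // fills dst with FP16 bit patterns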
@@ -32,6 +32,8 @@ public:
 
     virtual void Release() noexcept CV_OVERRIDE;
 
+    void setPrecision(InferenceEngine::Precision p) noexcept;
+
     virtual InferenceEngine::Precision getPrecision() noexcept CV_OVERRIDE;
 
     virtual void getOutputsInfo(InferenceEngine::OutputsDataMap &out) noexcept /*CV_OVERRIDE*/;
@@ -68,7 +70,7 @@ public:
 
     virtual size_t getBatchSize() const noexcept CV_OVERRIDE;
 
-    void init();
+    void init(int targetId);
 
     void addBlobs(const std::vector<Ptr<BackendWrapper> >& wrappers);
 
@@ -83,6 +85,8 @@ private:
     InferenceEngine::BlobMap inpBlobs;
     InferenceEngine::BlobMap outBlobs;
     InferenceEngine::BlobMap allBlobs;
+    InferenceEngine::TargetDevice targetDevice;
+    InferenceEngine::Precision precision;
     InferenceEngine::InferenceEnginePluginPtr plugin;
 
     void initPlugin(InferenceEngine::ICNNNetwork& net);
@@ -116,15 +120,17 @@ public:
     InferenceEngine::TBlob<float>::Ptr blob;
 };
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m);
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, InferenceEngine::Layout layout = InferenceEngine::Layout::ANY);
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape);
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape, InferenceEngine::Layout layout);
 
 InferenceEngine::DataPtr infEngineDataNode(const Ptr<BackendWrapper>& ptr);
 
-// Fuses convolution weights and biases with channel-wise scales and shifts.
-void fuseConvWeights(const std::shared_ptr<InferenceEngine::ConvolutionLayer>& conv,
-                     const Mat& w, const Mat& b = Mat());
+Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob);
+
+// Convert Inference Engine blob with FP32 precision to FP16 precision.
+// Allocates memory for a new blob.
+InferenceEngine::TBlob<int16_t>::Ptr convertFp16(const InferenceEngine::Blob::Ptr& blob);
 
 // This is a fake class to run networks from Model Optimizer. Objects of that
 // class simulate responses of layers are imported by OpenCV and supported by
@@ -151,7 +157,6 @@ private:
     InferenceEngine::DataPtr output;
 };
 
-
 #endif  // HAVE_INF_ENGINE
 
 bool haveInfEngine();
@@ -100,6 +100,8 @@ public:
 
 TEST_P(DNNTestNetwork, AlexNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt",
                Size(227, 227), "prob",
                target == DNN_TARGET_OPENCL ? "dnn/halide_scheduler_opencl_alexnet.yml" :
@@ -108,6 +110,8 @@ TEST_P(DNNTestNetwork, AlexNet)
 
 TEST_P(DNNTestNetwork, ResNet_50)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/ResNet-50-model.caffemodel", "dnn/ResNet-50-deploy.prototxt",
                Size(224, 224), "prob",
                target == DNN_TARGET_OPENCL ? "dnn/halide_scheduler_opencl_resnet_50.yml" :
@@ -116,6 +120,8 @@ TEST_P(DNNTestNetwork, ResNet_50)
 
 TEST_P(DNNTestNetwork, SqueezeNet_v1_1)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/squeezenet_v1.1.caffemodel", "dnn/squeezenet_v1.1.prototxt",
                Size(227, 227), "prob",
                target == DNN_TARGET_OPENCL ? "dnn/halide_scheduler_opencl_squeezenet_v1_1.yml" :
@@ -124,6 +130,8 @@ TEST_P(DNNTestNetwork, SqueezeNet_v1_1)
 
 TEST_P(DNNTestNetwork, GoogLeNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/bvlc_googlenet.caffemodel", "dnn/bvlc_googlenet.prototxt",
                Size(224, 224), "prob");
 }
@@ -147,7 +155,9 @@ TEST_P(DNNTestNetwork, ENet)
 
 TEST_P(DNNTestNetwork, MobileNet_SSD_Caffe)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
     Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
 
@@ -157,7 +167,9 @@ TEST_P(DNNTestNetwork, MobileNet_SSD_Caffe)
 
 TEST_P(DNNTestNetwork, MobileNet_SSD_TensorFlow)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
     Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
     processNet("dnn/ssd_mobilenet_v1_coco.pb", "dnn/ssd_mobilenet_v1_coco.pbtxt",
@@ -177,35 +189,45 @@ TEST_P(DNNTestNetwork, SSD_VGG16)
 TEST_P(DNNTestNetwork, OpenPose_pose_coco)
 {
     if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    double l1 = target == DNN_TARGET_OPENCL_FP16 ? 3e-5 : 1e-5;
+    double lInf = target == DNN_TARGET_OPENCL_FP16 ? 3e-3 : 1e-4;
     processNet("dnn/openpose_pose_coco.caffemodel", "dnn/openpose_pose_coco.prototxt",
-               Size(368, 368), "");
+               Size(368, 368), "", "", l1, lInf);
 }
 
 TEST_P(DNNTestNetwork, OpenPose_pose_mpi)
 {
     if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    double l1 = target == DNN_TARGET_OPENCL_FP16 ? 4e-5 : 1e-5;
+    double lInf = target == DNN_TARGET_OPENCL_FP16 ? 7e-3 : 1e-4;
     processNet("dnn/openpose_pose_mpi.caffemodel", "dnn/openpose_pose_mpi.prototxt",
-               Size(368, 368), "");
+               Size(368, 368), "", "", l1, lInf);
 }
 
 TEST_P(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
 {
     if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    double l1 = target == DNN_TARGET_OPENCL_FP16 ? 5e-5 : 1e-5;
+    double lInf = target == DNN_TARGET_OPENCL_FP16 ? 5e-3 : 1e-4;
     // The same .caffemodel but modified .prototxt
     // See https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/pose/poseParameters.cpp
     processNet("dnn/openpose_pose_mpi.caffemodel", "dnn/openpose_pose_mpi_faster_4_stages.prototxt",
-               Size(368, 368), "");
+               Size(368, 368), "", "", l1, lInf);
 }
 
 TEST_P(DNNTestNetwork, OpenFace)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/openface_nn4.small2.v1.t7", "", Size(96, 96), "");
 }
 
 TEST_P(DNNTestNetwork, opencv_face_detector)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat img = imread(findDataFile("gpu/lbpcascade/er.png", false));
     Mat inp = blobFromImage(img, 1.0, Size(), Scalar(104.0, 177.0, 123.0), false, false);
     processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt",
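The OpenPose tests now pass explicit l1/lInf thresholds because FP16 accumulates noticeably more error than the FP32 default. The check the harness performs is essentially the following (a sketch of the usual normAssert-style comparison, not the exact helper):

    // ref: FP32 reference output; out: output from the tested backend/target.
    double l1   = cv::norm(ref, out, cv::NORM_L1) / ref.total();  // mean absolute error
    double lInf = cv::norm(ref, out, cv::NORM_INF);               // worst single element
    EXPECT_LE(l1, 3e-5);
    EXPECT_LE(lInf, 3e-3);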
@@ -214,13 +236,23 @@ TEST_P(DNNTestNetwork, opencv_face_detector)
 
 TEST_P(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
     Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
     processNet("dnn/ssd_inception_v2_coco_2017_11_17.pb", "dnn/ssd_inception_v2_coco_2017_11_17.pbtxt",
                inp, "detection_out");
 }
 
+TEST_P(DNNTestNetwork, DenseNet_121)
+{
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
+    processNet("dnn/DenseNet_121.caffemodel", "dnn/DenseNet_121.prototxt", Size(224, 224), "", "caffe");
+}
+
 const tuple<DNNBackend, DNNTarget> testCases[] = {
 #ifdef HAVE_HALIDE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_HALIDE, DNN_TARGET_CPU),
@@ -228,6 +260,8 @@ const tuple<DNNBackend, DNNTarget> testCases[] = {
 #endif
 #ifdef HAVE_INF_ENGINE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_CPU),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL_FP16),
 #endif
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_DEFAULT, DNN_TARGET_OPENCL)
 };
@@ -53,7 +53,7 @@ namespace opencv_test {
 using namespace cv::dnn;
 
 CV_ENUM(DNNBackend, DNN_BACKEND_DEFAULT, DNN_BACKEND_HALIDE, DNN_BACKEND_INFERENCE_ENGINE)
-CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL)
+CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL, DNN_TARGET_OPENCL_FP16)
 
 static testing::internal::ParamGenerator<DNNTarget> availableDnnTargets()
 {