mirror of https://github.com/opencv/opencv.git
synced 2025-06-07 09:25:45 +08:00

Update tutorials. A new cv::dnn::readNet function

This commit is contained in:
parent 8e4fe30db6
commit f2440ceae6
@@ -13,50 +13,53 @@ We will demonstrate results of this example on the following picture.

Source Code
-----------

We will be using snippets from the example application, which can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/caffe_googlenet.cpp).
We will be using snippets from the example application, which can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/classification.cpp).

@include dnn/caffe_googlenet.cpp
@include dnn/classification.cpp

Explanation
-----------

-# Firstly, download the GoogLeNet model files:
   [bvlc_googlenet.prototxt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/bvlc_googlenet.prototxt) and
   [bvlc_googlenet.prototxt](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/bvlc_googlenet.prototxt) and
   [bvlc_googlenet.caffemodel](http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel)

   You also need a file with the names of the [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
   [synset_words.txt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/synset_words.txt).
   [classification_classes_ILSVRC2012.txt](https://github.com/opencv/opencv/tree/master/samples/dnn/classification_classes_ILSVRC2012.txt).

   Put these files into the working directory of this example program.

-# Read and initialize the network using the paths to the .prototxt and .caffemodel files
   @snippet dnn/caffe_googlenet.cpp Read and initialize network
   @snippet dnn/classification.cpp Read and initialize network

-# Check that the network was read successfully
   @snippet dnn/caffe_googlenet.cpp Check that network was read successfully
   You can omit the `framework` argument if one of the `model` or `config` files has a
   `.caffemodel` or `.prototxt` extension.
   This way the cv::dnn::readNet function can automatically detect a model's format.
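   A minimal usage sketch of this note (not part of the tutorial sample; it assumes the two GoogLeNet files listed above are in the working directory):
   @code
   #include <opencv2/dnn.hpp>

   int main()
   {
       // The `framework` argument is omitted: cv::dnn::readNet infers the Caffe
       // format from the .caffemodel/.prototxt extensions, in either argument order.
       cv::dnn::Net net = cv::dnn::readNet("bvlc_googlenet.caffemodel", "bvlc_googlenet.prototxt");
       CV_Assert(!net.empty());
       return 0;
   }
   @endcode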
-# Read the input image and convert it to a blob acceptable by GoogLeNet
   @snippet dnn/caffe_googlenet.cpp Prepare blob
   We convert the image to a 4-dimensional blob (a so-called batch) with shape `1x3x224x224` after applying the necessary pre-processing, such as resizing and mean subtraction, using the cv::dnn::blobFromImage function.
   @snippet dnn/classification.cpp Open a video file or an image file or a camera stream

   cv::VideoCapture can load both images and videos.

   @snippet dnn/classification.cpp Create a 4D blob from a frame
   We convert the image to a 4-dimensional blob (a so-called batch) with shape `1x3x224x224`
   after applying the necessary pre-processing, such as resizing and mean subtraction
   `(-104, -117, -123)` for the blue, green and red channels respectively, using the cv::dnn::blobFromImage function.
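   For reference, a standalone sketch (not taken from the sample) of a blobFromImage call with the GoogLeNet parameters described above; the image path is a placeholder:
   @code
   #include <opencv2/dnn.hpp>
   #include <opencv2/imgcodecs.hpp>

   int main()
   {
       cv::Mat frame = cv::imread("space_shuttle.jpg");  // placeholder input image
       // Resize to 224x224, subtract the mean (104, 117, 123) per BGR channel,
       // keep the BGR channel order (swapRB=false), do not crop.
       cv::Mat blob = cv::dnn::blobFromImage(frame, 1.0, cv::Size(224, 224),
                                             cv::Scalar(104, 117, 123), false, false);
       CV_Assert(blob.dims == 4);  // 1x3x224x224
       return 0;
   }
   @endcode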
-# Pass the blob to the network
   @snippet dnn/caffe_googlenet.cpp Set input blob
   In bvlc_googlenet.prototxt the network input blob is named "data", therefore this blob is labeled ".data" in the opencv_dnn API.

   Other blobs are labeled as "name_of_layer.name_of_layer_output".
   @snippet dnn/classification.cpp Set input blob

-# Make a forward pass
   @snippet dnn/caffe_googlenet.cpp Make forward pass
   During the forward pass the output of each network layer is computed, but in this example we need the output of the "prob" layer only.
   @snippet dnn/classification.cpp Make forward pass
   During the forward pass the output of each network layer is computed, but in this example we need the output of the last layer only.

-# Determine the best class
   @snippet dnn/caffe_googlenet.cpp Gather output
   We put the output of the "prob" layer, which contains probabilities for each of the 1000 ILSVRC2012 image classes, into the `prob` blob,
   and find the index of the element with the maximal value in it. This index corresponds to the class of the image.
   @snippet dnn/classification.cpp Get a class with a highest score
   We put the output of the network, which contains probabilities for each of the 1000 ILSVRC2012 image classes, into the `prob` blob,
   and find the index of the element with the maximal value in it. This index corresponds to the class of the image.
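   A condensed fragment (not a snippet from the sample) tying these last steps together; it assumes a `net` and `blob` prepared as shown above:
   @code
   // Assumes `net` was created with cv::dnn::readNet and `blob` with cv::dnn::blobFromImage.
   net.setInput(blob);
   cv::Mat prob = net.forward();             // output of the last (softmax) layer

   cv::Point classIdPoint;
   double confidence;
   cv::minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
   int classId = classIdPoint.x;             // index into the ILSVRC2012 class list
   @endcode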
-# Print results
   @snippet dnn/caffe_googlenet.cpp Print results
   For our image we get:

   > Best class: #812 'space shuttle'
   >
   > Probability: 99.6378%

-# Run the example from the command line
   @code
   ./example_dnn_classification --model=bvlc_googlenet.caffemodel --config=bvlc_googlenet.prototxt --width=224 --height=224 --classes=classification_classes_ILSVRC2012.txt --input=space_shuttle.jpg --mean="104 117 123"
   @endcode
   For our image we get a prediction of the class `space shuttle` with more than 99% confidence.
@@ -74,46 +74,7 @@ When you build OpenCV add the following configuration flags:

- `HALIDE_ROOT_DIR` - path to the Halide build directory

## Sample

@include dnn/squeezenet_halide.cpp

## Explanation

Download the Caffe model from the SqueezeNet repository: [train_val.prototxt](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/train_val.prototxt) and [squeezenet_v1.1.caffemodel](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/squeezenet_v1.1.caffemodel).

You also need a file with the names of the [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
[synset_words.txt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/synset_words.txt).

Put these files into the working directory of this example program.

-# Read and initialize the network using the paths to the .prototxt and .caffemodel files
   @snippet dnn/squeezenet_halide.cpp Read and initialize network

-# Check that the network was read successfully
   @snippet dnn/squeezenet_halide.cpp Check that network was read successfully

-# Read the input image and convert it to a 4-dimensional blob acceptable by SqueezeNet v1.1
   @snippet dnn/squeezenet_halide.cpp Prepare blob

-# Pass the blob to the network
   @snippet dnn/squeezenet_halide.cpp Set input blob

-# Enable the Halide backend for layers where it is implemented
   @snippet dnn/squeezenet_halide.cpp Enable Halide backend

-# Make a forward pass
   @snippet dnn/squeezenet_halide.cpp Make forward pass
   Remember that the first forward pass after initialization takes considerably more
   time than the following ones, because Halide pipelines are compiled at runtime
   on the first invocation.
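   A rough fragment (an assumption, not part of the sample) of how this warm-up cost could be observed; it presumes a `net` with the Halide backend enabled and a prepared `blob`:
   @code
   // Requires <iostream>. Assumes `net` and `blob` are already set up and net.setInput(blob) was called.
   int64 t0 = cv::getTickCount();
   net.forward();                            // first pass: triggers Halide pipeline compilation
   double first = (cv::getTickCount() - t0) / cv::getTickFrequency();

   t0 = cv::getTickCount();
   net.forward();                            // subsequent passes reuse the compiled pipelines
   double second = (cv::getTickCount() - t0) / cv::getTickFrequency();

   std::cout << "first: " << first << " s, second: " << second << " s" << std::endl;
   @endcode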
-# Determine the best class
   @snippet dnn/squeezenet_halide.cpp Determine the best class

-# Print results
   @snippet dnn/squeezenet_halide.cpp Print results
   For our image we get:

   > Best class: #812 'space shuttle'
   >
   > Probability: 97.9812%

## Set Halide as a preferable backend
@code
net.setPreferableBackend(DNN_BACKEND_HALIDE);
@endcode
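A placement sketch (an assumption, not part of the sample) showing where the backend selection fits relative to reading the network:
@code
// Read the SqueezeNet model named above, then request the Halide backend before
// the first forward pass; layers without a Halide implementation fall back to the
// default backend.
Net net = readNetFromCaffe("train_val.prototxt", "squeezenet_v1.1.caffemodel");
net.setPreferableBackend(DNN_BACKEND_HALIDE);
@endcode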
@@ -683,6 +683,29 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     */
    CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);

    /**
     * @brief Read a deep learning network represented in one of the supported formats.
     * @param[in] model Binary file containing trained weights. The following file
     *                  extensions are expected for models from different frameworks:
     *                  * `*.caffemodel` (Caffe, http://caffe.berkeleyvision.org/)
     *                  * `*.pb` (TensorFlow, https://www.tensorflow.org/)
     *                  * `*.t7` | `*.net` (Torch, http://torch.ch/)
     *                  * `*.weights` (Darknet, https://pjreddie.com/darknet/)
     * @param[in] config Text file containing the network configuration. It could be a
     *                  file with one of the following extensions:
     *                  * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/)
     *                  * `*.pbtxt` (TensorFlow, https://www.tensorflow.org/)
     *                  * `*.cfg` (Darknet, https://pjreddie.com/darknet/)
     * @param[in] framework Explicit framework name tag to determine a format.
     * @returns Net object.
     *
     * This function automatically detects the origin framework of the trained model
     * and calls an appropriate function such as @ref readNetFromCaffe, @ref readNetFromTensorflow,
     * @ref readNetFromTorch or @ref readNetFromDarknet. The order of the @p model and @p config
     * arguments does not matter.
     */
    CV_EXPORTS_W Net readNet(String model, String config = "", String framework = "");
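A brief, hypothetical usage sketch of this API (the TensorFlow file names are placeholders; the Torch file name is one used in the tests below):
@code
#include <opencv2/dnn.hpp>

int main()
{
    // TensorFlow: the .pb/.pbtxt extensions are enough, so `framework` can stay empty
    // and the model/config order is irrelevant ("frozen_graph.pb" and "graph.pbtxt"
    // are placeholder file names).
    cv::dnn::Net tfNet = cv::dnn::readNet("graph.pbtxt", "frozen_graph.pb");

    // Torch: a single binary file is sufficient.
    cv::dnn::Net torchNet = cv::dnn::readNet("openface_nn4.small2.v1.t7");

    CV_Assert(!tfNet.empty() && !torchNet.empty());
    return 0;
}
@endcode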

    /** @brief Loads blob which was serialized as torch.Tensor object of Torch7 framework.
     *  @warning This function has the same limitations as readNetFromTorch().
     */
@@ -2805,5 +2805,41 @@ BackendWrapper::BackendWrapper(const Ptr<BackendWrapper>& base, const MatShape&

BackendWrapper::~BackendWrapper() {}

Net readNet(String model, String config, String framework)
{
    framework = framework.toLowerCase();
    const std::string modelExt = model.substr(model.rfind('.') + 1);
    const std::string configExt = config.substr(config.rfind('.') + 1);
    if (framework == "caffe" || modelExt == "caffemodel" || configExt == "caffemodel" ||
        modelExt == "prototxt" || configExt == "prototxt")
    {
        if (modelExt == "prototxt" || configExt == "caffemodel")
            std::swap(model, config);
        return readNetFromCaffe(config, model);
    }
    if (framework == "tensorflow" || modelExt == "pb" || configExt == "pb" ||
        modelExt == "pbtxt" || configExt == "pbtxt")
    {
        if (modelExt == "pbtxt" || configExt == "pb")
            std::swap(model, config);
        return readNetFromTensorflow(model, config);
    }
    if (framework == "torch" || modelExt == "t7" || modelExt == "net" ||
        configExt == "t7" || configExt == "net")
    {
        return readNetFromTorch(model.empty() ? config : model);
    }
    if (framework == "darknet" || modelExt == "weights" || configExt == "weights" ||
        modelExt == "cfg" || configExt == "cfg")
    {
        if (modelExt == "cfg" || configExt == "weights")
            std::swap(model, config);
        return readNetFromDarknet(config, model);
    }
    CV_Error(Error::StsError, "Cannot determine an origin framework of files: " +
                              model + (config.empty() ? "" : ", " + config));
    return Net();
}

CV__DNN_EXPERIMENTAL_NS_END
}} // namespace
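A small, hypothetical illustration of the error path at the end of this dispatch logic (the file name is a placeholder):
@code
#include <iostream>
#include <opencv2/dnn.hpp>

int main()
{
    try
    {
        // Unknown extension and no explicit framework tag: readNet cannot dispatch
        // and raises cv::Exception via CV_Error ("model.bin" is a placeholder name).
        cv::dnn::Net net = cv::dnn::readNet("model.bin");
    }
    catch (const cv::Exception& e)
    {
        std::cout << "readNet failed: " << e.what() << std::endl;
    }
    return 0;
}
@endcode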
@@ -57,4 +57,22 @@ TEST(imagesFromBlob, Regression)
    }
}

TEST(readNet, Regression)
{
    Net net = readNet(findDataFile("dnn/squeezenet_v1.1.prototxt", false),
                      findDataFile("dnn/squeezenet_v1.1.caffemodel", false));
    EXPECT_FALSE(net.empty());
    net = readNet(findDataFile("dnn/opencv_face_detector.caffemodel", false),
                  findDataFile("dnn/opencv_face_detector.prototxt", false));
    EXPECT_FALSE(net.empty());
    net = readNet(findDataFile("dnn/openface_nn4.small2.v1.t7", false));
    EXPECT_FALSE(net.empty());
    net = readNet(findDataFile("dnn/tiny-yolo-voc.cfg", false),
                  findDataFile("dnn/tiny-yolo-voc.weights", false));
    EXPECT_FALSE(net.empty());
    net = readNet(findDataFile("dnn/ssd_mobilenet_v1_coco.pbtxt", false),
                  findDataFile("dnn/ssd_mobilenet_v1_coco.pb", false));
    EXPECT_FALSE(net.empty());
}

}} // namespace
@@ -14,14 +14,12 @@
| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
| [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |

### Classification
| Model | Scale | Size WxH | Mean subtraction | Channels order |
|---------------|-------|-----------|--------------------|-------|
| GoogLeNet | `1.0` | `224x224` | `104 117 123` | BGR |
| [SqueezeNet](https://github.com/DeepScale/SqueezeNet) | `1.0` | `227x227` | `0 0 0` | BGR |

## References
* [Models downloading script](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/download_models.py)
* [Configuration files adopted for OpenCV](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn)
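For instance, the SqueezeNet row above could map to a cv::dnn::blobFromImage call roughly like this (a sketch, not taken from the samples; cropping is left disabled as in the classification sample):
@code
// Scale 1.0, size 227x227, mean 0 0 0, BGR channel order (swapRB=false), no crop.
Mat blob = blobFromImage(frame, 1.0, Size(227, 227), Scalar(0, 0, 0), false, false);
@endcode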
@@ -2,8 +2,9 @@
#include <iostream>
#include <sstream>

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

const char* keys =
    "{ help h | | Print help message. }"
@@ -33,8 +34,6 @@ using namespace dnn;

std::vector<std::string> classes;

Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");

int main(int argc, char** argv)
{
    CommandLineParser parser(argc, argv, keys);
@@ -49,6 +48,11 @@ int main(int argc, char** argv)
    bool swapRB = parser.get<bool>("rgb");
    int inpWidth = parser.get<int>("width");
    int inpHeight = parser.get<int>("height");
    String model = parser.get<String>("model");
    String config = parser.get<String>("config");
    String framework = parser.get<String>("framework");
    int backendId = parser.get<int>("backend");
    int targetId = parser.get<int>("target");

    // Parse mean values.
    Scalar mean;
@@ -77,22 +81,24 @@ int main(int argc, char** argv)
        }
    }

    // Load a model.
    CV_Assert(parser.has("model"));
    Net net = readNet(parser.get<String>("model"), parser.get<String>("config"), parser.get<String>("framework"));
    net.setPreferableBackend(parser.get<int>("backend"));
    net.setPreferableTarget(parser.get<int>("target"));
    //! [Read and initialize network]
    Net net = readNet(model, config, framework);
    net.setPreferableBackend(backendId);
    net.setPreferableTarget(targetId);
    //! [Read and initialize network]

    // Create a window
    static const std::string kWinName = "Deep learning image classification in OpenCV";
    namedWindow(kWinName, WINDOW_NORMAL);

    // Open a video file or an image file or a camera stream.
    //! [Open a video file or an image file or a camera stream]
    VideoCapture cap;
    if (parser.has("input"))
        cap.open(parser.get<String>("input"));
    else
        cap.open(0);
    //! [Open a video file or an image file or a camera stream]

    // Process frames.
    Mat frame, blob;
@@ -105,24 +111,29 @@ int main(int argc, char** argv)
            break;
        }

        // Create a 4D blob from a frame.
        //! [Create a 4D blob from a frame]
        blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
        //! [Create a 4D blob from a frame]

        // Run a model.
        //! [Set input blob]
        net.setInput(blob);
        Mat out = net.forward();
        out = out.reshape(1, 1);
        //! [Set input blob]
        //! [Make forward pass]
        Mat prob = net.forward();
        //! [Make forward pass]

        // Get a class with a highest score.
        //! [Get a class with a highest score]
        Point classIdPoint;
        double confidence;
        minMaxLoc(out, 0, &confidence, 0, &classIdPoint);
        minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
        int classId = classIdPoint.x;
        //! [Get a class with a highest score]

        // Put efficiency information.
        std::vector<double> layersTimes;
        double t = net.getPerfProfile(layersTimes);
        std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
        double freq = getTickFrequency() / 1000;
        double t = net.getPerfProfile(layersTimes) / freq;
        std::string label = format("Inference time: %.2f ms", t);
        putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));

        // Print predicted class.
@@ -135,19 +146,3 @@ int main(int argc, char** argv)
    }
    return 0;
}

Net readNet(const std::string& model, const std::string& config, const std::string& framework)
{
    std::string modelExt = model.substr(model.rfind('.'));
    if (framework == "caffe" || modelExt == ".caffemodel")
        return readNetFromCaffe(config, model);
    else if (framework == "tensorflow" || modelExt == ".pb")
        return readNetFromTensorflow(model, config);
    else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
        return readNetFromTorch(model);
    else if (framework == "darknet" || modelExt == ".weights")
        return readNetFromDarknet(config, model);
    else
        CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
    return Net();
}
@@ -48,19 +48,7 @@ if args.classes:
        classes = f.read().rstrip('\n').split('\n')

# Load a network
modelExt = args.model[args.model.rfind('.'):]
if args.framework == 'caffe' or modelExt == '.caffemodel':
    net = cv.dnn.readNetFromCaffe(args.config, args.model)
elif args.framework == 'tensorflow' or modelExt == '.pb':
    net = cv.dnn.readNetFromTensorflow(args.model, args.config)
elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
    net = cv.dnn.readNetFromTorch(args.model)
elif args.framework == 'darknet' or modelExt == '.weights':
    net = cv.dnn.readNetFromDarknet(args.config, args.model)
else:
    print('Cannot determine an origin framework of model from file %s' % args.model)
    sys.exit(0)

net = cv.dnn.readNet(args.model, args.config, args.framework)
net.setPreferableBackend(args.backend)
net.setPreferableTarget(args.target)
@@ -2,8 +2,9 @@
#include <iostream>
#include <sstream>

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

const char* keys =
    "{ help h | | Print help message. }"
@@ -35,8 +36,6 @@ using namespace dnn;
float confThreshold;
std::vector<std::string> classes;

Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");

void postprocess(Mat& frame, const Mat& out, Net& net);

void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame);
@@ -95,7 +94,7 @@ int main(int argc, char** argv)
    // Create a window
    static const std::string kWinName = "Deep learning object detection in OpenCV";
    namedWindow(kWinName, WINDOW_NORMAL);
    int initialConf = confThreshold * 100;
    int initialConf = (int)(confThreshold * 100);
    createTrackbar("Confidence threshold, %", kWinName, &initialConf, 99, callback);

    // Open a video file or an image file or a camera stream.
@@ -135,8 +134,9 @@ int main(int argc, char** argv)

        // Put efficiency information.
        std::vector<double> layersTimes;
        double t = net.getPerfProfile(layersTimes);
        std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
        double freq = getTickFrequency() / 1000;
        double t = net.getPerfProfile(layersTimes) / freq;
        std::string label = format("Inference time: %.2f ms", t);
        putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));

        imshow(kWinName, frame);
@@ -160,10 +160,10 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
            float confidence = data[i + 2];
            if (confidence > confThreshold)
            {
                int left = data[i + 3];
                int top = data[i + 4];
                int right = data[i + 5];
                int bottom = data[i + 6];
                int left = (int)data[i + 3];
                int top = (int)data[i + 4];
                int right = (int)data[i + 5];
                int bottom = (int)data[i + 6];
                int classId = (int)(data[i + 1]) - 1;  // Skip 0th background class id.
                drawPred(classId, confidence, left, top, right, bottom, frame);
            }
@@ -208,7 +208,7 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
                int height = (int)(data[3] * frame.rows);
                int left = centerX - width / 2;
                int top = centerY - height / 2;
                drawPred(classId, confidence, left, top, left + width, top + height, frame);
                drawPred(classId, (float)confidence, left, top, left + width, top + height, frame);
            }
        }
    }
@@ -238,21 +238,5 @@ void drawPred(int classId, float conf, int left, int top, int right, int bottom,

void callback(int pos, void*)
{
    confThreshold = pos * 0.01;
}

Net readNet(const std::string& model, const std::string& config, const std::string& framework)
{
    std::string modelExt = model.substr(model.rfind('.'));
    if (framework == "caffe" || modelExt == ".caffemodel")
        return readNetFromCaffe(config, model);
    else if (framework == "tensorflow" || modelExt == ".pb")
        return readNetFromTensorflow(model, config);
    else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
        return readNetFromTorch(model);
    else if (framework == "darknet" || modelExt == ".weights")
        return readNetFromDarknet(config, model);
    else
        CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
    return Net();
    confThreshold = pos * 0.01f;
}
@@ -49,19 +49,7 @@ if args.classes:
        classes = f.read().rstrip('\n').split('\n')

# Load a network
modelExt = args.model[args.model.rfind('.'):]
if args.framework == 'caffe' or modelExt == '.caffemodel':
    net = cv.dnn.readNetFromCaffe(args.config, args.model)
elif args.framework == 'tensorflow' or modelExt == '.pb':
    net = cv.dnn.readNetFromTensorflow(args.model, args.config)
elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
    net = cv.dnn.readNetFromTorch(args.model)
elif args.framework == 'darknet' or modelExt == '.weights':
    net = cv.dnn.readNetFromDarknet(args.config, args.model)
else:
    print('Cannot determine an origin framework of model from file %s' % args.model)
    sys.exit(0)

net = cv.dnn.readNet(args.model, args.config, args.framework)
net.setPreferableBackend(args.backend)
net.setPreferableTarget(args.target)