mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 09:25:45 +08:00
Add high level API (Merge pull request #14780)
* Add high level API * Fix Model * Add DetectionModel * Add ClassificationModel * Fix classify * Add python test * Fix pytest * Fix comments to review * Fix detect * Fix docs * Modify DetectionOutput postprocessing * Fix test * Extract ref boxes * Fix draw rect * fix test * Add rect wrap * Fix wrap * Fix detect * Fix Rect wrap * Fix OCL_FP16 * Fix MyriadX * Fix nms * Fix NMS * Fix coords
This commit is contained in:
parent
f482050f9a
commit
778f42ad34
@ -992,6 +992,155 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
CV_OUT std::vector<int>& indices,
|
||||
const float eta = 1.f, const int top_k = 0);
|
||||
|
||||
|
||||
/** @brief This class represents a high-level API for neural networks.
 *
 * Model allows setting the parameters used to preprocess an input image.
 * Model creates a net from a file with trained weights and a config,
 * sets the preprocessing parameters and runs a forward pass.
 */
class CV_EXPORTS_W Model : public Net
{
public:
    /**
     * @brief Create model from deep learning network represented in one of the supported formats.
     * The order of the @p model and @p config arguments does not matter.
     * @param[in] model Binary file containing trained weights.
     * @param[in] config Text file containing network configuration.
     */
    CV_WRAP Model(const String& model, const String& config = "");

    /**
     * @brief Create model from deep learning network.
     * @param[in] network Net object.
     */
    CV_WRAP Model(const Net& network);

    /** @brief Set input size for frame.
     * @param[in] size New input size.
     * @return Reference to this model, so that setters can be chained.
     * @note The size is stored as given. While the stored size is empty,
     * predict() reports an "Input size not specified" error.
     */
    Model& setInputSize(const Size& size);

    /** @brief Set input size for frame.
     * @param[in] width New input width.
     * @param[in] height New input height.
     * @return Reference to this model, so that setters can be chained.
     * @note The size is stored as given. While the stored size is empty,
     * predict() reports an "Input size not specified" error.
     */
    Model& setInputSize(int width, int height);

    /** @brief Set mean value for frame.
     * @param[in] mean Scalar with mean values which are subtracted from channels.
     * @return Reference to this model, so that setters can be chained.
     */
    Model& setInputMean(const Scalar& mean);

    /** @brief Set scalefactor value for frame.
     * @param[in] scale Multiplier for frame values.
     * @return Reference to this model, so that setters can be chained.
     */
    Model& setInputScale(double scale);

    /** @brief Set flag crop for frame.
     * @param[in] crop Flag which indicates whether image will be cropped after resize or not.
     * @return Reference to this model, so that setters can be chained.
     */
    Model& setInputCrop(bool crop);

    /** @brief Set flag swapRB for frame.
     * @param[in] swapRB Flag which indicates whether the first and last channels are swapped.
     * @return Reference to this model, so that setters can be chained.
     */
    Model& setInputSwapRB(bool swapRB);

    /** @brief Set preprocessing parameters for frame.
     * @param[in] scale Multiplier for frame values.
     * @param[in] size New input size.
     * @param[in] mean Scalar with mean values which are subtracted from channels.
     * @param[in] swapRB Flag which indicates whether the first and last channels are swapped.
     * @param[in] crop Flag which indicates whether image will be cropped after resize or not.
     *
     * blob(n, c, y, x) = scale * ( resize( frame(y, x, c) ) - mean(c) )
     */
    CV_WRAP void setInputParams(double scale = 1.0, const Size& size = Size(),
                                const Scalar& mean = Scalar(), bool swapRB = false, bool crop = false);

    /** @brief Given the input @p frame, create input blob, run net and return the output blobs in @p outs.
     * @param[in] frame The input image.
     * @param[out] outs Allocated output blobs, which will store results of the computation.
     */
    CV_WRAP void predict(InputArray frame, OutputArrayOfArrays outs);

protected:
    // Preprocessing state is kept behind an opaque implementation struct.
    struct Impl;
    Ptr<Impl> impl;
};
|
||||
|
||||
/** @brief This class represents a high-level API for classification models.
 *
 * ClassificationModel allows setting the parameters used to preprocess an input image.
 * ClassificationModel creates a net from a file with trained weights and a config,
 * sets the preprocessing parameters, runs a forward pass and returns the top-1 prediction.
 */
class CV_EXPORTS_W ClassificationModel : public Model
{
public:
    /**
     * @brief Create classification model from network represented in one of the supported formats.
     * The order of the @p model and @p config arguments does not matter.
     * @param[in] model Binary file containing trained weights.
     * @param[in] config Text file containing network configuration.
     */
    CV_WRAP ClassificationModel(const String& model, const String& config = "");

    /**
     * @brief Create model from deep learning network.
     * @param[in] network Net object.
     */
    CV_WRAP ClassificationModel(const Net& network);

    /** @brief Given the input @p frame, create input blob, run net and return top-1 prediction.
     * @param[in] frame The input image.
     * @return Pair of (class index, confidence) for the highest-scoring class.
     */
    std::pair<int, float> classify(InputArray frame);

    /** @overload
     * @param[in] frame The input image.
     * @param[out] classId Index of the highest-scoring class.
     * @param[out] conf Confidence of the highest-scoring class.
     */
    CV_WRAP void classify(InputArray frame, CV_OUT int& classId, CV_OUT float& conf);
};
|
||||
|
||||
/** @brief This class represents a high-level API for object detection networks.
 *
 * DetectionModel allows setting the parameters used to preprocess an input image.
 * DetectionModel creates a net from a file with trained weights and a config,
 * sets the preprocessing parameters, runs a forward pass and returns the resulting detections.
 * For DetectionModel SSD, Faster R-CNN, YOLO topologies are supported.
 */
class CV_EXPORTS_W DetectionModel : public Model
{
public:
    /**
     * @brief Create detection model from network represented in one of the supported formats.
     * The order of the @p model and @p config arguments does not matter.
     * @param[in] model Binary file containing trained weights.
     * @param[in] config Text file containing network configuration.
     */
    CV_WRAP DetectionModel(const String& model, const String& config = "");

    /**
     * @brief Create model from deep learning network.
     * @param[in] network Net object.
     */
    CV_WRAP DetectionModel(const Net& network);

    /** @brief Given the input @p frame, create input blob, run net and return result detections.
     * @param[in] frame The input image.
     * @param[out] classIds Class indexes in result detection.
     * @param[out] confidences A set of corresponding confidences.
     * @param[out] boxes A set of bounding boxes.
     * @param[in] confThreshold A threshold used to filter boxes by confidences.
     * @param[in] nmsThreshold A threshold used in non maximum suppression.
     * With the default value of 0 non maximum suppression is not applied.
     */
    CV_WRAP void detect(InputArray frame, CV_OUT std::vector<int>& classIds,
                        CV_OUT std::vector<float>& confidences, CV_OUT std::vector<Rect>& boxes,
                        float confThreshold = 0.5f, float nmsThreshold = 0.0f);
};
|
||||
|
||||
//! @}
|
||||
CV__DNN_INLINE_NS_END
|
||||
}
|
||||
|
@ -21,15 +21,11 @@ def box2str(box):
|
||||
width, height = box[2] - left, box[3] - top
|
||||
return '[%f x %f from (%f, %f)]' % (width, height, left, top)
|
||||
|
||||
def normAssertDetections(test, ref, out, confThreshold=0.0, scores_diff=1e-5, boxes_iou_diff=1e-4):
|
||||
ref = np.array(ref, np.float32)
|
||||
refClassIds, testClassIds = ref[:, 1], out[:, 1]
|
||||
refScores, testScores = ref[:, 2], out[:, 2]
|
||||
refBoxes, testBoxes = ref[:, 3:], out[:, 3:]
|
||||
|
||||
def normAssertDetections(test, refClassIds, refScores, refBoxes, testClassIds, testScores, testBoxes,
|
||||
confThreshold=0.0, scores_diff=1e-5, boxes_iou_diff=1e-4):
|
||||
matchedRefBoxes = [False] * len(refBoxes)
|
||||
errMsg = ''
|
||||
for i in range(len(refBoxes)):
|
||||
for i in range(len(testBoxes)):
|
||||
testScore = testScores[i]
|
||||
if testScore < confThreshold:
|
||||
continue
|
||||
@ -136,6 +132,38 @@ class dnn_test(NewOpenCVTests):
|
||||
normAssert(self, blob, target)
|
||||
|
||||
|
||||
def test_model(self):
    # Run the high-level DetectionModel on MobileNet-SSD and compare against
    # reference detections, then check that cv.rectangle accepts a box in
    # every supported Python representation.
    img_path = self.find_dnn_file("dnn/street.png")
    weights = self.find_dnn_file("dnn/MobileNetSSD_deploy.caffemodel")
    config = self.find_dnn_file("dnn/MobileNetSSD_deploy.prototxt")

    frame = cv.imread(img_path)
    model = cv.dnn_DetectionModel(weights, config)
    model.setInputParams(size=(300, 300), mean=(127.5, 127.5, 127.5), scale=1.0 / 127.5)

    # Tolerances and thresholds used for the comparison.
    confThreshold = 0.0001
    nmsThreshold = 0
    scoreDiff = 1e-3
    iouDiff = 0.05

    classIds, confidences, boxes = model.detect(frame, confThreshold, nmsThreshold)

    # Reference detections: class ids, scores and (x, y, width, height) boxes.
    expectedClassIds = (7, 15)
    expectedConfidences = (0.9998, 0.8793)
    expectedBoxes = ((328, 238, 85, 102), (101, 188, 34, 138))

    normAssertDetections(self, expectedClassIds, expectedConfidences, expectedBoxes,
                         classIds, confidences, boxes, confThreshold, scoreDiff, iouDiff)

    # The same box is passed as returned, as ndarray, as tuple and as list.
    green = (0, 255, 0)
    for box in boxes:
        for rect in (box, np.array(box), tuple(box), list(box)):
            cv.rectangle(frame, rect, green)
|
||||
|
||||
|
||||
def test_face_detection(self):
|
||||
testdata_required = bool(os.environ.get('OPENCV_DNN_TEST_REQUIRE_TESTDATA', False))
|
||||
proto = self.find_dnn_file('dnn/opencv_face_detector.prototxt', required=testdata_required)
|
||||
@ -166,7 +194,13 @@ class dnn_test(NewOpenCVTests):
|
||||
scoresDiff = 4e-3 if target in [cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD] else 1e-5
|
||||
iouDiff = 2e-2 if target in [cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD] else 1e-4
|
||||
|
||||
normAssertDetections(self, ref, out, 0.5, scoresDiff, iouDiff)
|
||||
ref = np.array(ref, np.float32)
|
||||
refClassIds, testClassIds = ref[:, 1], out[:, 1]
|
||||
refScores, testScores = ref[:, 2], out[:, 2]
|
||||
refBoxes, testBoxes = ref[:, 3:], out[:, 3:]
|
||||
|
||||
normAssertDetections(self, refClassIds, refScores, refBoxes, testClassIds,
|
||||
testScores, testBoxes, 0.5, scoresDiff, iouDiff)
|
||||
|
||||
def test_async(self):
|
||||
timeout = 500*10**6 # in nanoseconds (500ms)
|
||||
|
268
modules/dnn/src/model.cpp
Normal file
268
modules/dnn/src/model.cpp
Normal file
@ -0,0 +1,268 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "precomp.hpp"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <iterator>
|
||||
|
||||
#include <opencv2/imgproc.hpp>
|
||||
|
||||
namespace cv {
|
||||
namespace dnn {
|
||||
|
||||
struct Model::Impl
|
||||
{
|
||||
Size size;
|
||||
Scalar mean;
|
||||
double scale = 1.0;
|
||||
bool swapRB = false;
|
||||
bool crop = false;
|
||||
Mat blob;
|
||||
std::vector<String> outNames;
|
||||
|
||||
void predict(Net& net, const Mat& frame, std::vector<Mat>& outs)
|
||||
{
|
||||
if (size.empty())
|
||||
CV_Error(Error::StsBadSize, "Input size not specified");
|
||||
|
||||
blob = blobFromImage(frame, scale, size, mean, swapRB, crop);
|
||||
net.setInput(blob);
|
||||
|
||||
// Faster-RCNN or R-FCN
|
||||
if (net.getLayer(0)->outputNameToIndex("im_info") != -1)
|
||||
{
|
||||
Mat imInfo = (Mat_<float>(1, 3) << size.height, size.width, 1.6f);
|
||||
net.setInput(imInfo, "im_info");
|
||||
}
|
||||
net.forward(outs, outNames);
|
||||
}
|
||||
};
|
||||
|
||||
// Reads the network from the given weights/config files and remembers the
// names of its unconnected output layers as the outputs to request later.
Model::Model(const String& model, const String& config)
    : Net(readNet(model, config)), impl(new Impl)
{
    impl->outNames = getUnconnectedOutLayersNames();
}
|
||||
|
||||
// Wraps an existing network and remembers the names of its unconnected
// output layers as the outputs to request later.
Model::Model(const Net& network) : Net(network), impl(new Impl)
{
    impl->outNames = getUnconnectedOutLayersNames();
}
|
||||
|
||||
// Stores the blob size used by later predictions; returns *this for chaining.
Model& Model::setInputSize(const Size& size)
{
    Impl& params = *impl;
    params.size = size;
    return *this;
}
|
||||
|
||||
// Convenience overload: packs width/height into a Size and forwards to the
// Size-based setter; returns *this for chaining.
Model& Model::setInputSize(int width, int height)
{
    return setInputSize(Size(width, height));
}
|
||||
|
||||
// Stores the mean used by later predictions; returns *this for chaining.
Model& Model::setInputMean(const Scalar& mean)
{
    Impl& params = *impl;
    params.mean = mean;
    return *this;
}
|
||||
|
||||
// Stores the scale factor used by later predictions; returns *this for chaining.
Model& Model::setInputScale(double scale)
{
    Impl& params = *impl;
    params.scale = scale;
    return *this;
}
|
||||
|
||||
// Stores the crop flag used by later predictions; returns *this for chaining.
Model& Model::setInputCrop(bool crop)
{
    Impl& params = *impl;
    params.crop = crop;
    return *this;
}
|
||||
|
||||
// Stores the swapRB flag used by later predictions; returns *this for chaining.
Model& Model::setInputSwapRB(bool swapRB)
{
    Impl& params = *impl;
    params.swapRB = swapRB;
    return *this;
}
|
||||
|
||||
void Model::setInputParams(double scale, const Size& size, const Scalar& mean,
|
||||
bool swapRB, bool crop)
|
||||
{
|
||||
impl->size = size;
|
||||
impl->mean = mean;
|
||||
impl->scale = scale;
|
||||
impl->crop = crop;
|
||||
impl->swapRB = swapRB;
|
||||
}
|
||||
|
||||
// Preprocesses `frame`, runs the network and returns the output blobs in `outs`.
//
// The forward pass fills a local vector, so the results have to be handed
// back to `outs` explicitly; without that step the caller never sees them.
// The blobs are assigned by Mat header, not deep-copied.
void Model::predict(InputArray frame, OutputArrayOfArrays outs)
{
    std::vector<Mat> outputs;
    impl->predict(*this, frame.getMat(), outputs);

    // Nothing to write when the caller passed no output array.
    if (!outs.needed())
        return;

    // Resize the destination to one entry per output blob, then assign each one.
    const int count = static_cast<int>(outputs.size());
    outs.create(count, 1, 0);
    for (int i = 0; i < count; ++i)
        outs.getMatRef(i) = outputs[i];
}
|
||||
|
||||
// Builds the classification model from weights/config files; all work is
// done by the Model base constructor.
ClassificationModel::ClassificationModel(const String& model, const String& config)
    : Model(model, config) {}
|
||||
|
||||
// Wraps an existing network; all work is done by the Model base constructor.
ClassificationModel::ClassificationModel(const Net& network) : Model(network) {}
|
||||
|
||||
std::pair<int, float> ClassificationModel::classify(InputArray frame)
|
||||
{
|
||||
std::vector<Mat> outs;
|
||||
impl->predict(*this, frame.getMat(), outs);
|
||||
CV_Assert(outs.size() == 1);
|
||||
|
||||
double conf;
|
||||
cv::Point maxLoc;
|
||||
minMaxLoc(outs[0].reshape(1, 1), nullptr, &conf, nullptr, &maxLoc);
|
||||
return {maxLoc.x, static_cast<float>(conf)};
|
||||
}
|
||||
|
||||
// Out-parameter flavour of classify(): stores the top-1 class index in
// `classId` and its confidence in `conf`.
void ClassificationModel::classify(InputArray frame, int& classId, float& conf)
{
    const std::pair<int, float> top = classify(frame);
    classId = top.first;
    conf = top.second;
}
|
||||
|
||||
// Builds the detection model from weights/config files; all work is done by
// the Model base constructor.
DetectionModel::DetectionModel(const String& model, const String& config)
    : Model(model, config) {}
|
||||
|
||||
// Wraps an existing network; all work is done by the Model base constructor.
DetectionModel::DetectionModel(const Net& network) : Model(network) {}
|
||||
|
||||
// Runs the network on `frame` and converts its raw output into parallel
// lists of class ids, confidences and pixel boxes.
//
// The post-processing depends on the type of the last layer of the network:
// "DetectionOutput" and "Region" are handled, any other type raises
// Error::StsNotImplemented. Detections below `confThreshold` are dropped.
// When `nmsThreshold` is non-zero the remaining boxes are additionally
// filtered with NMSBoxes().
void DetectionModel::detect(InputArray frame, CV_OUT std::vector<int>& classIds,
                            CV_OUT std::vector<float>& confidences, CV_OUT std::vector<Rect>& boxes,
                            float confThreshold, float nmsThreshold)
{
    std::vector<Mat> detections;
    impl->predict(*this, frame.getMat(), detections);

    // Outputs are rebuilt from scratch on every call.
    boxes.clear();
    confidences.clear();
    classIds.clear();

    // Dimensions used to scale and clip the boxes. Networks with an
    // "im_info" input use the configured network input size instead of the
    // frame size (presumably because their boxes refer to the resized
    // input - confirm).
    int frameWidth = frame.cols();
    int frameHeight = frame.rows();
    if (getLayer(0)->outputNameToIndex("im_info") != -1)
    {
        frameWidth = impl->size.width;
        frameHeight = impl->size.height;
    }

    // The output format is selected by the type of the last listed layer.
    std::vector<String> layerNames = getLayerNames();
    int lastLayerId = getLayerId(layerNames.back());
    Ptr<Layer> lastLayer = getLayer(lastLayerId);

    // Candidates collected before the optional non maximum suppression.
    std::vector<int> predClassIds;
    std::vector<Rect> predBoxes;
    std::vector<float> predConf;
    if (lastLayer->type == "DetectionOutput")
    {
        // Network produces output blob with a shape 1x1xNx7 where N is a number of
        // detections and an every detection is a vector of values
        // [batchId, classId, confidence, left, top, right, bottom]
        for (int i = 0; i < detections.size(); ++i)
        {
            float* data = (float*)detections[i].data;
            for (int j = 0; j < detections[i].total(); j += 7)
            {
                float conf = data[j + 2];
                if (conf < confThreshold)
                    continue;

                // First interpret the coordinates as absolute pixel values
                // (truncated to int).
                int left = data[j + 3];
                int top = data[j + 4];
                int right = data[j + 5];
                int bottom = data[j + 6];
                int width = right - left + 1;
                int height = bottom - top + 1;

                // A box of area <= 1 under that interpretation is taken to
                // mean the coordinates are relative, so they are scaled by
                // the frame dimensions instead.
                if (width * height <= 1)
                {
                    left = data[j + 3] * frameWidth;
                    top = data[j + 4] * frameHeight;
                    right = data[j + 5] * frameWidth;
                    bottom = data[j + 6] * frameHeight;
                    width = right - left + 1;
                    height = bottom - top + 1;
                }

                // Clip the box to the frame; it is kept at least 1x1.
                left = std::max(0, std::min(left, frameWidth - 1));
                top = std::max(0, std::min(top, frameHeight - 1));
                width = std::max(1, std::min(width, frameWidth - left));
                height = std::max(1, std::min(height, frameHeight - top));
                predBoxes.emplace_back(left, top, width, height);

                predClassIds.push_back(static_cast<int>(data[j + 1]));
                predConf.push_back(conf);
            }
        }
    }
    else if (lastLayer->type == "Region")
    {
        for (int i = 0; i < detections.size(); ++i)
        {
            // Network produces output blob with a shape NxC where N is a number of
            // detected objects. The first 4 numbers of every row are
            // [center_x, center_y, width, height]; the per-class scores are read
            // from column 5 onwards, so column 4 is not used here
            // (presumably the objectness score - confirm).
            float* data = (float*)detections[i].data;
            for (int j = 0; j < detections[i].rows; ++j, data += detections[i].cols)
            {
                // The best class is the column with the maximum score.
                Mat scores = detections[i].row(j).colRange(5, detections[i].cols);
                Point classIdPoint;
                double conf;
                minMaxLoc(scores, nullptr, &conf, nullptr, &classIdPoint);

                if (static_cast<float>(conf) < confThreshold)
                    continue;

                // Box values are scaled by the frame dimensions and truncated to int.
                int centerX = data[0] * frameWidth;
                int centerY = data[1] * frameHeight;
                int width = data[2] * frameWidth;
                int height = data[3] * frameHeight;

                // Convert center/size to top-left/size and clip to the frame.
                int left = std::max(0, std::min(centerX - width / 2, frameWidth - 1));
                int top = std::max(0, std::min(centerY - height / 2, frameHeight - 1));
                width = std::max(1, std::min(width, frameWidth - left));
                height = std::max(1, std::min(height, frameHeight - top));

                predClassIds.push_back(classIdPoint.x);
                predConf.push_back(static_cast<float>(conf));
                predBoxes.emplace_back(left, top, width, height);
            }
        }
    }
    else
        CV_Error(Error::StsNotImplemented, "Unknown output layer type: \"" + lastLayer->type + "\"");

    if (nmsThreshold)
    {
        // Keep only the candidates selected by non maximum suppression.
        // All candidates are passed together, regardless of class id.
        std::vector<int> indices;
        NMSBoxes(predBoxes, predConf, confThreshold, nmsThreshold, indices);

        boxes.reserve(indices.size());
        confidences.reserve(indices.size());
        classIds.reserve(indices.size());

        for (int idx : indices)
        {
            boxes.push_back(predBoxes[idx]);
            confidences.push_back(predConf[idx]);
            classIds.push_back(predClassIds[idx]);
        }
    }
    else
    {
        // No suppression requested: hand over all candidates without copying.
        boxes = std::move(predBoxes);
        classIds = std::move(predClassIds);
        confidences = std::move(predConf);
    }
}
|
||||
|
||||
}} // namespace
|
207
modules/dnn/test/test_model.cpp
Normal file
207
modules/dnn/test/test_model.cpp
Normal file
@ -0,0 +1,207 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "test_precomp.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
#include "npy_blob.hpp"
|
||||
namespace opencv_test { namespace {
|
||||
|
||||
template<typename TString>
|
||||
static std::string _tf(TString filename)
|
||||
{
|
||||
String rootFolder = "dnn/";
|
||||
return findDataFile(rootFolder + filename);
|
||||
}
|
||||
|
||||
|
||||
class Test_Model : public DNNTestLayer
|
||||
{
|
||||
public:
|
||||
void testDetectModel(const std::string& weights, const std::string& cfg,
|
||||
const std::string& imgPath, const std::vector<int>& refClassIds,
|
||||
const std::vector<float>& refConfidences,
|
||||
const std::vector<Rect2d>& refBoxes,
|
||||
double scoreDiff, double iouDiff,
|
||||
double confThreshold = 0.24, double nmsThreshold = 0.0,
|
||||
const Size& size = {-1, -1}, Scalar mean = Scalar(),
|
||||
double scale = 1.0, bool swapRB = false, bool crop = false)
|
||||
{
|
||||
checkBackend();
|
||||
|
||||
Mat frame = imread(imgPath);
|
||||
DetectionModel model(weights, cfg);
|
||||
|
||||
model.setInputSize(size).setInputMean(mean).setInputScale(scale)
|
||||
.setInputSwapRB(swapRB).setInputCrop(crop);
|
||||
|
||||
model.setPreferableBackend(backend);
|
||||
model.setPreferableTarget(target);
|
||||
|
||||
std::vector<int> classIds;
|
||||
std::vector<float> confidences;
|
||||
std::vector<Rect> boxes;
|
||||
|
||||
model.detect(frame, classIds, confidences, boxes, confThreshold, nmsThreshold);
|
||||
|
||||
std::vector<Rect2d> boxesDouble(boxes.size());
|
||||
for (int i = 0; i < boxes.size(); i++) {
|
||||
boxesDouble[i] = boxes[i];
|
||||
}
|
||||
normAssertDetections(refClassIds, refConfidences, refBoxes, classIds,
|
||||
confidences, boxesDouble, "",
|
||||
confThreshold, scoreDiff, iouDiff);
|
||||
}
|
||||
|
||||
void testClassifyModel(const std::string& weights, const std::string& cfg,
|
||||
const std::string& imgPath, std::pair<int, float> ref, float norm,
|
||||
const Size& size = {-1, -1}, Scalar mean = Scalar(),
|
||||
double scale = 1.0, bool swapRB = false, bool crop = false)
|
||||
{
|
||||
checkBackend();
|
||||
|
||||
Mat frame = imread(imgPath);
|
||||
ClassificationModel model(weights, cfg);
|
||||
model.setInputSize(size).setInputMean(mean).setInputScale(scale)
|
||||
.setInputSwapRB(swapRB).setInputCrop(crop);
|
||||
|
||||
std::pair<int, float> prediction = model.classify(frame);
|
||||
EXPECT_EQ(prediction.first, ref.first);
|
||||
ASSERT_NEAR(prediction.second, ref.second, norm);
|
||||
}
|
||||
};
|
||||
|
||||
// AlexNet on the Grace Hopper image: expects class 652 with confidence ~0.6418.
TEST_P(Test_Model, Classify)
{
    const std::pair<int, float> expected(652, 0.641789);

    const std::string imagePath = _tf("grace_hopper_227.png");
    const std::string prototxt = _tf("bvlc_alexnet.prototxt");
    const std::string caffemodel = _tf("bvlc_alexnet.caffemodel");

    const Size inputSize(227, 227);
    const float tolerance = 1e-4;

    testClassifyModel(caffemodel, prototxt, imagePath, expected, tolerance, inputSize);
}
|
||||
|
||||
|
||||
// YOLO-VOC (Region output layer) on the dog image, with non maximum suppression.
TEST_P(Test_Model, DetectRegion)
{
    applyTestTag(CV_TEST_TAG_LONG, CV_TEST_TAG_MEMORY_1GB);

#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_GE(2019010000)
    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16);
#endif

#if defined(INF_ENGINE_RELEASE)
    if (target == DNN_TARGET_MYRIAD
        && getInferenceEngineVPUType() == CV_DNN_INFERENCE_ENGINE_VPU_TYPE_MYRIAD_X)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD_X);
#endif

    // Reference detections.
    const std::vector<int> expectedIds = {6, 1, 11};
    const std::vector<float> expectedScores = {0.750469f, 0.780879f, 0.901615f};
    const std::vector<Rect2d> expectedBoxes = {Rect2d(240, 53, 135, 72),
                                               Rect2d(112, 109, 192, 200),
                                               Rect2d(58, 141, 117, 249)};

    const std::string imagePath = _tf("dog416.png");
    const std::string weightsPath = _tf("yolo-voc.weights");
    const std::string configPath = _tf("yolo-voc.cfg");

    // Preprocessing parameters.
    const double inputScale = 1.0 / 255.0;
    const Size inputSize(416, 416);
    const bool swapChannels = true;

    // Thresholds; the tolerances are relaxed on the FP16 and Myriad targets.
    const bool relaxed = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD);
    const double confThreshold = 0.24;
    const double scoreDiff = relaxed ? 1e-2 : 8e-5;
    const double iouDiff = relaxed ? 1.6e-2 : 1e-5;
    const double nmsThreshold = (target == DNN_TARGET_MYRIAD) ? 0.397 : 0.4;

    testDetectModel(weightsPath, configPath, imagePath, expectedIds, expectedScores,
                    expectedBoxes, scoreDiff, iouDiff, confThreshold, nmsThreshold, inputSize,
                    Scalar(), inputScale, swapChannels);
}
|
||||
|
||||
// R-FCN (DetectionOutput layer with an im_info input) on the dog image.
TEST_P(Test_Model, DetectionOutput)
{
#if defined(INF_ENGINE_RELEASE)
    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16);

    if (target == DNN_TARGET_MYRIAD)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD);
#endif

    // Reference detections.
    const std::vector<int> expectedIds = {7, 12};
    const std::vector<float> expectedScores = {0.991359f, 0.94786f};
    const std::vector<Rect2d> expectedBoxes = {Rect2d(491, 81, 212, 98),
                                               Rect2d(132, 223, 207, 344)};

    const std::string imagePath = _tf("dog416.png");
    const std::string weightsPath = _tf("resnet50_rfcn_final.caffemodel");
    const std::string configPath = _tf("rfcn_pascal_voc_resnet50.prototxt");

    // Preprocessing parameters.
    const Scalar inputMean(102.9801, 115.9465, 122.7717);
    const Size inputSize(800, 600);

    // Tolerances are relaxed on the OpenCL FP16 target.
    double scoreDiff = default_l1;
    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
        scoreDiff = 4e-3;
    const double iouDiff = (target == DNN_TARGET_OPENCL_FP16) ? 1.8e-1 : 1e-5;
    const float confThreshold = 0.8;
    const double nmsThreshold = 0.0;

    testDetectModel(weightsPath, configPath, imagePath, expectedIds, expectedScores, expectedBoxes,
                    scoreDiff, iouDiff, confThreshold, nmsThreshold, inputSize, inputMean);
}
|
||||
|
||||
|
||||
// MobileNet-SSD on the street image; the reference detections are loaded
// from a stored network output and converted to pixel boxes.
TEST_P(Test_Model, DetectionMobilenetSSD)
{
    Mat ref = blobFromNPY(_tf("mobilenet_ssd_caffe_out.npy"));
    ref = ref.reshape(1, ref.size[2]);

    const std::string imagePath = _tf("street.png");
    const Mat frame = imread(imagePath);
    const int frameWidth = frame.cols;
    const int frameHeight = frame.rows;

    // Each reference row is [batchId, classId, confidence, left, top, right, bottom]
    // with box coordinates scaled here by the frame dimensions.
    std::vector<int> expectedIds;
    std::vector<float> expectedScores;
    std::vector<Rect2d> expectedBoxes;
    for (int row = 0; row < ref.rows; ++row)
    {
        const auto cell = [&ref, row](int col) { return ref.at<float>(row, col); };

        expectedIds.push_back(static_cast<int>(cell(1)));
        expectedScores.push_back(cell(2));

        const int left = static_cast<int>(cell(3) * frameWidth);
        const int top = static_cast<int>(cell(4) * frameHeight);
        const int right = static_cast<int>(cell(5) * frameWidth);
        const int bottom = static_cast<int>(cell(6) * frameHeight);
        expectedBoxes.emplace_back(left, top, right - left + 1, bottom - top + 1);
    }

    const std::string weightsPath = _tf("MobileNetSSD_deploy.caffemodel");
    const std::string configPath = _tf("MobileNetSSD_deploy.prototxt");

    // Preprocessing parameters.
    const Scalar inputMean(127.5, 127.5, 127.5);
    const double inputScale = 1.0 / 127.5;
    const Size inputSize(300, 300);

    // Tolerances are relaxed on the FP16 and Myriad targets.
    const bool relaxedScores = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD);
    const bool relaxedIou = (target == DNN_TARGET_OPENCL_FP16 || (target == DNN_TARGET_MYRIAD &&
        getInferenceEngineVPUType() == CV_DNN_INFERENCE_ENGINE_VPU_TYPE_MYRIAD_X));
    const double scoreDiff = relaxedScores ? 1.7e-2 : 1e-5;
    const double iouDiff = relaxedIou ? 6.91e-2 : 1e-5;

    const float confThreshold = FLT_MIN;
    const double nmsThreshold = 0.0;

    testDetectModel(weightsPath, configPath, imagePath, expectedIds, expectedScores, expectedBoxes,
                    scoreDiff, iouDiff, confThreshold, nmsThreshold, inputSize, inputMean, inputScale);
}
|
||||
|
||||
// Instantiate every Test_Model case for each value produced by dnnBackendsAndTargets().
INSTANTIATE_TEST_CASE_P(/**/, Test_Model, dnnBackendsAndTargets());
|
||||
|
||||
}} // namespace
|
@ -752,15 +752,6 @@ PyObject* pyopencv_from(const Size_<float>& sz)
|
||||
return Py_BuildValue("(ff)", sz.width, sz.height);
|
||||
}
|
||||
|
||||
template<>
|
||||
bool pyopencv_to(PyObject* obj, Rect& r, const char* name)
|
||||
{
|
||||
CV_UNUSED(name);
|
||||
if(!obj || obj == Py_None)
|
||||
return true;
|
||||
return PyArg_ParseTuple(obj, "iiii", &r.x, &r.y, &r.width, &r.height) > 0;
|
||||
}
|
||||
|
||||
template<>
|
||||
PyObject* pyopencv_from(const Rect& r)
|
||||
{
|
||||
@ -1366,6 +1357,25 @@ template<> struct pyopencvVecConverter<RotatedRect>
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
bool pyopencv_to(PyObject* obj, Rect& r, const char* name)
|
||||
{
|
||||
CV_UNUSED(name);
|
||||
if(!obj || obj == Py_None)
|
||||
return true;
|
||||
|
||||
if (PyTuple_Check(obj))
|
||||
return PyArg_ParseTuple(obj, "iiii", &r.x, &r.y, &r.width, &r.height) > 0;
|
||||
else
|
||||
{
|
||||
std::vector<int> value(4);
|
||||
pyopencvVecConverter<int>::to(obj, value, ArgInfo(name, 0));
|
||||
r = Rect(value[0], value[1], value[2], value[3]);
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<>
|
||||
bool pyopencv_to(PyObject *obj, TermCriteria& dst, const char *name)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user