Modify the detection and decoding parts of the WeChat QRCodeDetector

This commit is contained in:
cswccc 2024-11-26 20:34:55 +08:00
parent fee096f4e8
commit f5ecad6dd4
7 changed files with 464 additions and 459 deletions

View File

@ -17,12 +17,3 @@ if(HAVE_QUIRC)
ocv_include_directories(${QUIRC_INCLUDE}) ocv_include_directories(${QUIRC_INCLUDE})
ocv_target_link_libraries(${the_module} quirc) ocv_target_link_libraries(${the_module} quirc)
endif() endif()
if(CMAKE_VERSION VERSION_GREATER "3.11")
find_package(Iconv QUIET)
if(Iconv_FOUND)
ocv_target_link_libraries(${the_module} Iconv::Iconv)
else()
ocv_target_compile_definitions(${the_module} PRIVATE "NO_ICONV=1")
endif()
endif()

View File

@ -868,15 +868,15 @@ public:
/** @brief Initialize the QRCodeDetectorWeChat. /** @brief Initialize the QRCodeDetectorWeChat.
* *
* Parameters allow to load _optional_ Detection and Super Resolution DNN model for better quality. * Parameters allow to load _optional_ Detection and Super Resolution DNN model for better quality.
* @param detection_model_path_ model file path for the detection model * @param detection_model_path model file path for the detection model
* @param super_resolution_model_path_ model file path for the super resolution model * @param super_resolution_model_path model file path for the super resolution model
* @param graphical_detector detector to be optimized * @param graphical_detector detector to be optimized
* @param detector_iou_thres nms iou threshold for detection part * @param detector_iou_thres nms iou threshold for detection part
* @param score_thres score threshold for detection part * @param score_thres score threshold for detection part
* @param reference_size the length of the image to align during pre-processing before detection * @param reference_size the length of the image to align during pre-processing before detection
*/ */
CV_WRAP QRCodeDetectorWeChat(const std::string& detection_model_path_ = "", CV_WRAP QRCodeDetectorWeChat(const std::string& detection_model_path = "",
const std::string& super_resolution_model_path_ = "", const std::string& super_resolution_model_path = "",
Ptr<GraphicalCodeDetector> graphical_detector = cv::makePtr<QRCodeDetectorAruco>(), Ptr<GraphicalCodeDetector> graphical_detector = cv::makePtr<QRCodeDetectorAruco>(),
const float detector_iou_thres = 0.6, const float detector_iou_thres = 0.6,
const float score_thres = 0.3, const float score_thres = 0.3,

View File

@ -1019,7 +1019,7 @@ class QRDecode
{ {
public: public:
QRDecode(bool useAlignmentMarkers); QRDecode(bool useAlignmentMarkers);
void init(const Mat &src, const vector<Point2f> &points); void init(const Mat &src, const vector<Point2f> &points, float sr_scale_=1.f);
Mat getIntermediateBarcode() { return intermediate; } Mat getIntermediateBarcode() { return intermediate; }
Mat getStraightBarcode() { return straight; } Mat getStraightBarcode() { return straight; }
size_t getVersion() { return version; } size_t getVersion() { return version; }
@ -1028,6 +1028,7 @@ public:
bool curvedDecodingProcess(); bool curvedDecodingProcess();
vector<Point2f> alignment_coords; vector<Point2f> alignment_coords;
float coeff_expansion = 1.f; float coeff_expansion = 1.f;
float sr_scale = 1.f;
vector<Point2f> getOriginalPoints() {return original_points;} vector<Point2f> getOriginalPoints() {return original_points;}
bool useAlignmentMarkers; bool useAlignmentMarkers;
@ -1137,7 +1138,7 @@ float static getMinSideLen(const vector<Point2f> &points) {
} }
void QRDecode::init(const Mat &src, const vector<Point2f> &points) void QRDecode::init(const Mat &src, const vector<Point2f> &points, float sr_scale_)
{ {
CV_TRACE_FUNCTION(); CV_TRACE_FUNCTION();
vector<Point2f> bbox = points; vector<Point2f> bbox = points;
@ -1150,6 +1151,7 @@ void QRDecode::init(const Mat &src, const vector<Point2f> &points)
version_size = 0; version_size = 0;
test_perspective_size = max(getMinSideLen(points)+1.f, 251.f); test_perspective_size = max(getMinSideLen(points)+1.f, 251.f);
result_info = ""; result_info = "";
sr_scale = sr_scale_;
} }
inline double QRDecode::pointPosition(Point2f a, Point2f b , Point2f c) inline double QRDecode::pointPosition(Point2f a, Point2f b , Point2f c)
@ -3625,6 +3627,10 @@ void QRDetectMulti::findQRCodeContours(vector<Point2f>& tmp_localization_points,
int count_contours = num_qrcodes; int count_contours = num_qrcodes;
if (all_contours_points.size() < size_t(num_qrcodes)) if (all_contours_points.size() < size_t(num_qrcodes))
count_contours = (int)all_contours_points.size(); count_contours = (int)all_contours_points.size();
// If no contours were found, return early; otherwise kmeans would raise an error.
if (all_contours_points.size() == 0)
return;
kmeans(all_contours_points, count_contours, qrcode_labels, kmeans(all_contours_points, count_contours, qrcode_labels,
TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 10, 0.1), TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 10, 0.1),
count_contours, KMEANS_PP_CENTERS, clustered_localization_points); count_contours, KMEANS_PP_CENTERS, clustered_localization_points);
@ -4008,9 +4014,9 @@ class ParallelDecodeProcess : public ParallelLoopBody
{ {
public: public:
ParallelDecodeProcess(Mat& inarr_, vector<QRDecode>& qrdec_, vector<std::string>& decoded_info_, ParallelDecodeProcess(Mat& inarr_, vector<QRDecode>& qrdec_, vector<std::string>& decoded_info_,
vector<Mat>& straight_barcode_, vector< vector< Point2f > >& src_points_, std::shared_ptr<SuperScale> sr_ = nullptr) vector<Mat>& straight_barcode_, vector< vector< Point2f > >& src_points_, std::shared_ptr<SuperScale> sr_ = nullptr, bool use_sr_model_ = false)
: inarr(inarr_), qrdec(qrdec_), decoded_info(decoded_info_) : inarr(inarr_), qrdec(qrdec_), decoded_info(decoded_info_)
, straight_barcode(straight_barcode_), src_points(src_points_), sr_(sr_) , straight_barcode(straight_barcode_), src_points(src_points_), sr(sr_), use_sr_model(use_sr_model_)
{ {
// nothing // nothing
} }
@ -4018,13 +4024,13 @@ public:
{ {
for (int i = range.start; i < range.end; i++) for (int i = range.start; i < range.end; i++)
{ {
if (sr_ != nullptr) { // Modified the input image for decoding,
// by extracting the part containing the QR code and adjusting the resolution.
// Only QRCodeDetectorWeChat uses this decoding path.
if (sr != nullptr) {
int width = inarr.size().width, height = inarr.size().height; int width = inarr.size().width, height = inarr.size().height;
int min_x = src_points[i][0].x, min_y = src_points[i][0].y;
int min_x = src_points[i][0].x; int max_x = src_points[i][0].x, max_y = src_points[i][0].y;
int min_y = src_points[i][0].y;
int max_x = src_points[i][0].x;
int max_y = src_points[i][0].y;
for (const auto& point : src_points[i]) { for (const auto& point : src_points[i]) {
min_x = min_x > point.x ? point.x : min_x; min_x = min_x > point.x ? point.x : min_x;
min_y = min_y > point.y ? point.y : min_y; min_y = min_y > point.y ? point.y : min_y;
@ -4035,30 +4041,33 @@ public:
max_x = min(max(0, max_x), width); max_x = min(max(0, max_x), width);
min_y = min(max(0, min_y), height); min_y = min(max(0, min_y), height);
max_y = min(max(0, max_y), height); max_y = min(max(0, max_y), height);
cv::Rect cropRect(min_x, min_y, max_x - min_x, max_y - min_y); cv::Rect cropRect(min_x, min_y, max_x - min_x, max_y - min_y);
auto scale_list = sr_->getScaleList(max_y - min_y, max_x - min_x); auto scale_list = sr->getScaleList(max_x - min_x, max_y - min_y);
Mat crop_image = inarr(cropRect).clone(); Mat crop_image = inarr(cropRect).clone();
for (auto cur_scale : scale_list) { for (auto cur_scale : scale_list) {
std::lock_guard<std::mutex> lock(sr_mutex); Mat scaled_img;
Mat scaled_img = sr_->ProcessImageScale(crop_image, cur_scale, true); // If a super-resolution model is loaded, parallelism cannot be performed
if (use_sr_model)
std::lock_guard<std::mutex> lock(sr_mutex);
sr->processImageScale(crop_image, scaled_img, cur_scale, use_sr_model);
vector<Point2f> points; vector<Point2f> points;
for (const auto& point : src_points[i]) for (const auto& point : src_points[i])
points.push_back(Point2f((point.x - min_x) * cur_scale, (point.y - min_y) * cur_scale)); points.push_back(Point2f((point.x - min_x) * cur_scale, (point.y - min_y) * cur_scale));
qrdec[i].init(scaled_img, points); qrdec[i].init(scaled_img, points, cur_scale);
bool ok = qrdec[i].straightDecodingProcess(); bool ok = qrdec[i].straightDecodingProcess();
if (ok) if (ok)
{ {
decoded_info[i] = qrdec[i].getDecodeInformation(); decoded_info[i] = qrdec[i].getDecodeInformation();
straight_barcode[i] = qrdec[i].getStraightBarcode(); straight_barcode[i] = qrdec[i].getStraightBarcode();
break;
} }
if (decoded_info[i].empty()) }
decoded_info[i] = ""; if (decoded_info[i].empty())
} decoded_info[i] = "";
} }
// The old decoding attempt.
else { else {
qrdec[i].init(inarr, src_points[i]); qrdec[i].init(inarr, src_points[i]);
bool ok = qrdec[i].straightDecodingProcess(); bool ok = qrdec[i].straightDecodingProcess();
@ -4102,7 +4111,8 @@ private:
vector<std::string>& decoded_info; vector<std::string>& decoded_info;
vector<Mat>& straight_barcode; vector<Mat>& straight_barcode;
vector< vector< Point2f > >& src_points; vector< vector< Point2f > >& src_points;
std::shared_ptr<SuperScale> sr_; std::shared_ptr<SuperScale> sr;
bool use_sr_model;
mutable std::mutex sr_mutex; mutable std::mutex sr_mutex;
}; };
@ -4198,8 +4208,20 @@ bool ImplContour::decodeMulti(
updateQrCorners.resize(src_points.size()*4ull); updateQrCorners.resize(src_points.size()*4ull);
for (size_t i = 0ull; i < src_points.size(); i++) { for (size_t i = 0ull; i < src_points.size(); i++) {
alignmentMarkers[i] = qrdec[i].alignment_coords; alignmentMarkers[i] = qrdec[i].alignment_coords;
for (size_t j = 0ull; j < 4ull; j++) for (size_t j = 0ull; j < 4ull; j++) {
updateQrCorners[i*4ull+j] = qrdec[i].getOriginalPoints()[j] * qrdec[i].coeff_expansion; updateQrCorners[i*4ull+j] = qrdec[i].getOriginalPoints()[j] * qrdec[i].coeff_expansion / qrdec[i].sr_scale;
if (sr_ != nullptr) {
int min_x = src_points[i][0].x, min_y = src_points[i][0].y;
for (const auto& point : src_points[i]) {
min_x = min_x > point.x ? point.x : min_x;
min_y = min_y > point.y ? point.y : min_y;
}
min_x = min(max(0, min_x), inarr.size().width);
min_y = min(max(0, min_y), inarr.size().height);
updateQrCorners[i*4ull+j].x += min_x;
updateQrCorners[i*4ull+j].y += min_y;
}
}
} }
if (!decoded_info.empty()) if (!decoded_info.empty())
return true; return true;
@ -4761,13 +4783,11 @@ public:
PimplQRWeChat(std::shared_ptr<GraphicalCodeDetector> graphical_detector) PimplQRWeChat(std::shared_ptr<GraphicalCodeDetector> graphical_detector)
: graphical_detector_(std::move(graphical_detector)) { : graphical_detector_(std::move(graphical_detector)) {
sr_ = std::make_shared<SuperScale>();
detector_ = std::make_shared<Detector>(); detector_ = std::make_shared<Detector>();
sr_ = std::make_shared<SuperScale>();
} }
bool detectMulti(InputArray img, OutputArray points) const override; bool detectMulti(InputArray img, OutputArray points) const override;
bool detectAndDecodeMulti(InputArray img, std::vector<cv::String>& decoded_info, OutputArray points,
OutputArrayOfArrays straight_qrcode) const override;
}; };
bool PimplQRWeChat::detectMulti(InputArray img, OutputArray points) const bool PimplQRWeChat::detectMulti(InputArray img, OutputArray points) const
@ -4778,69 +4798,106 @@ bool PimplQRWeChat::detectMulti(InputArray img, OutputArray points) const
return false; return false;
} }
std::vector<DetectInfo> _detect_results;
if (use_det_model_) { if (use_det_model_) {
detector_->detect(gray, _detect_results); std::vector<DetectInfo> detect_results;
std::vector<Rect> crop_rects;
detector_->detect(gray, detect_results);
vector<Point2f> results; for (size_t k = 0; k < detect_results.size(); k++) {
for (size_t k = 0; k < _detect_results.size(); k++) { int x0 = detect_results[k].x, y0 = detect_results[k].y;
int x0 = _detect_results[k].x, y0 = _detect_results[k].y; int width = detect_results[k].width, height = detect_results[k].height;
int width = _detect_results[k].width, height = _detect_results[k].height;
int x2 = x0 + width - 1, y2 = y0 + width -1; int x2 = x0 + width - 1, y2 = y0 + width -1;
int padx = max(0.5f * width, static_cast<float>(20)); int padx = max(0.5f * width, static_cast<float>(20));
int pady = max(0.5f * height, static_cast<float>(20)); int pady = max(0.5f * height, static_cast<float>(20));
int crop_x_ = max(x0 - padx, 0); int crop_x_ = max(x0 - padx, 0), crop_y_ = max(y0 - pady, 0);
int crop_y_ = max(y0 - pady, 0); int end_x = min(x2 + padx, gray.cols), end_y = min(y2 + pady, gray.rows);
int end_x = min(x2 + padx, gray.cols); crop_rects.push_back(Rect(crop_x_, crop_y_, end_x - crop_x_, end_y - crop_y_));
int end_y = min(y2 + pady, gray.rows); }
cv::Rect cropRect(crop_x_, crop_y_, end_x - crop_x_, end_y - crop_y_);
vector<Point2f> corners; vector<vector<Point2f> > crop_images_corners;
vector<Point2f> results;
auto scale_list = sr_->getScaleList(end_y - crop_y_, end_x - crop_x_); for (auto& rect : crop_rects) {
Mat crop_image = gray(cropRect).clone(); Mat crop_image = gray(rect).clone();
for (auto cur_scale : scale_list) { int width = rect.width, height = rect.height;
Mat scaled_img = sr_->ProcessImageScale(crop_image, cur_scale, use_sr_model_); auto scale_lists = sr_->getScaleList(width, height);
if (graphical_detector_->detectMulti(scaled_img, corners)) { bool flag = false;
for (size_t i = 0; i < corners.size(); i += 4) { for (auto cur_scale : scale_lists) {
bool flag = true; Mat scaled_image;
std::vector<Point2f> pts_i, pts_j; sr_->processImageScale(crop_image, scaled_image, cur_scale, use_sr_model_);
for (size_t p = 0; p < 4; p++) { vector<Point2f> corners;
int x = corners[i+p].x/cur_scale + crop_x_, y = corners[i+p].y/cur_scale + crop_y_; if (graphical_detector_->detectMulti(scaled_image, corners)) {
pts_i.push_back(Point2f(min(max(x, 0), gray.cols-1), min(max(y, 0), gray.rows-1))); if (corners.size() % 4 != 0)
} continue;
float area1 = cv::contourArea(pts_i); for (size_t i = 0; i < corners.size(); i++) {
corners[i].x = corners[i].x/cur_scale + rect.x;
for (size_t j = 0; j < results.size(); j+= 4) { corners[i].x = min(max(corners[i].x, 0.0f), gray.cols * 1.0f);
pts_j.clear(); corners[i].y = corners[i].y/cur_scale + rect.y;
for (size_t p = 0; p < 4; p++) corners[i].y = min(max(corners[i].y, 0.0f), gray.rows * 1.0f);
pts_j.push_back(Point2f(results[j+p].x, results[j+p].y));
float area2 = cv::contourArea(pts_j);
float intersectionArea = 0.0;
std::vector<cv::Point2f> intersection;
cv::rotatedRectangleIntersection(cv::minAreaRect(pts_i), cv::minAreaRect(pts_j), intersection);
if (!intersection.empty())
intersectionArea = cv::contourArea(intersection);
double iou = intersectionArea / (area1 + area2 - intersectionArea);
double cover = intersectionArea / min(area1, area2);
if (iou > 0.7 || cover > 0.96) {
flag = false;
break;
}
}
if (flag) {
for (auto p : pts_i)
results.push_back(p);
}
} }
crop_images_corners.push_back(corners);
flag = true;
break; break;
} }
} }
if (!flag) {
crop_images_corners.push_back(vector<Point2f>());
}
}
for (size_t j = 0; j < crop_images_corners.size(); j++) {
auto& corners = crop_images_corners[j];
if (corners.size() != 0) {
for (size_t p_1 = 0; p_1 < corners.size(); p_1 += 4) {
if (corners[p_1].x < 0) continue;
Point2f topLeft, bottomRight;
topLeft.x = min(min(corners[p_1].x, corners[p_1+1].x), min(corners[p_1+2].x, corners[p_1+3].x));
topLeft.y = min(min(corners[p_1].y, corners[p_1+1].y), min(corners[p_1+2].y, corners[p_1+3].y));
bottomRight.x = max(max(corners[p_1].x, corners[p_1+1].x), max(corners[p_1+2].x, corners[p_1+3].x));
bottomRight.y = max(max(corners[p_1].y, corners[p_1+1].y), max(corners[p_1+2].y, corners[p_1+3].y));
float area1 = (bottomRight.x - topLeft.x) * (bottomRight.y - topLeft.y);
for (size_t i = j+1; i < crop_images_corners.size(); i++) {
vector<Point2f>& corners_i = crop_images_corners[i];
for (size_t p_i = 0; p_i < corners_i.size(); p_i += 4) {
if (corners_i[p_i].x < 0) continue;
Point2f topLeft_i, bottomRight_i;
topLeft_i.x = min(min(corners_i[p_i].x, corners_i[p_i+1].x), min(corners_i[p_i+2].x, corners_i[p_i+3].x));
topLeft_i.y = min(min(corners_i[p_i].y, corners_i[p_i+1].y), min(corners_i[p_i+2].y, corners_i[p_i+3].y));
bottomRight_i.x = max(max(corners_i[p_i].x, corners_i[p_i+1].x), max(corners_i[p_i+2].x, corners_i[p_i+3].x));
bottomRight_i.y = max(max(corners_i[p_i].y, corners_i[p_i+1].y), max(corners_i[p_i+2].y, corners_i[p_i+3].y));
float interLeft = std::max(topLeft.x, topLeft_i.x), interTop = std::max(topLeft.y, topLeft_i.y);
float interRight = std::min(bottomRight.x, bottomRight_i.x), interBottom = std::min(bottomRight.y, bottomRight_i.y);
float interWidth = max(interRight - interLeft, 0.0f), interHeight = max(interBottom - interTop, 0.0f);
float interArea = interWidth * interHeight;
float area2 = (bottomRight_i.x - topLeft_i.x) * (bottomRight_i.y - topLeft_i.y);
float iou = interArea / (area1 + area2 - interArea);
if (iou > 0.5) {
if (area2 < area1) {
topLeft = topLeft_i;
bottomRight = bottomRight_i;
area1 = area2;
corners[p_1] = corners_i[p_i];
corners[p_1+1] = corners_i[p_i+1];
corners[p_1+2] = corners_i[p_i+2];
corners[p_1+3] = corners_i[p_i+3];
}
corners_i[p_i].x = -1;
corners_i[p_i+1].x = -1;
corners_i[p_i+2].x = -1;
corners_i[p_i+3].x = -1;
break;
}
}
}
}
}
for (auto p : corners)
if (p.x >= 0) {
results.push_back(p);
}
} }
if (results.size() >= 4) { if (results.size() >= 4) {
@ -4855,22 +4912,8 @@ bool PimplQRWeChat::detectMulti(InputArray img, OutputArray points) const
} }
} }
bool PimplQRWeChat::detectAndDecodeMulti( QRCodeDetectorWeChat::QRCodeDetectorWeChat(const std::string& detection_model_path,
InputArray img, const std::string& super_resolution_model_path,
CV_OUT std::vector<cv::String>& decoded_info,
OutputArray points_,
OutputArrayOfArrays straight_qrcode
) const {
bool ok = detectMulti(img, points_);
if (ok)
return decodeMulti(img, points_, decoded_info, straight_qrcode);
return false;
}
QRCodeDetectorWeChat::QRCodeDetectorWeChat(const std::string& detection_model_path_,
const std::string& super_resolution_model_path_,
Ptr<GraphicalCodeDetector> graphical_detector, Ptr<GraphicalCodeDetector> graphical_detector,
const float detector_iou_thres, const float detector_iou_thres,
const float score_thres, const float score_thres,
@ -4878,15 +4921,15 @@ QRCodeDetectorWeChat::QRCodeDetectorWeChat(const std::string& detection_model_pa
Ptr<PimplQRWeChat> p_ = std::make_shared<PimplQRWeChat>(std::move(graphical_detector)); Ptr<PimplQRWeChat> p_ = std::make_shared<PimplQRWeChat>(std::move(graphical_detector));
p = p_; p = p_;
if (!super_resolution_model_path_.empty()) { if (!super_resolution_model_path.empty()) {
CV_Assert(utils::fs::exists(super_resolution_model_path_)); CV_Assert(utils::fs::exists(super_resolution_model_path));
int res = p_->sr_->init(super_resolution_model_path_); int res = p_->sr_->init(super_resolution_model_path);
CV_Assert(res == 0); CV_Assert(res == 0);
p_->use_sr_model_ = true; p_->use_sr_model_ = true;
} }
if (!detection_model_path_.empty()) { if (!detection_model_path.empty()) {
CV_Assert(utils::fs::exists(detection_model_path_)); CV_Assert(utils::fs::exists(detection_model_path));
int res = p_->detector_->init(detection_model_path_); int res = p_->detector_->init(detection_model_path);
CV_Assert(res == 0); CV_Assert(res == 0);
p_->use_det_model_ = true; p_->use_det_model_ = true;
} }

View File

@ -10,234 +10,164 @@
#define CLIP(x, x1, x2) (std::fmax<float>)(x1, (std::fmin<float>)(x, x2)) #define CLIP(x, x1, x2) (std::fmax<float>)(x1, (std::fmin<float>)(x, x2))
namespace cv { namespace cv {
int Detector::init(const std::string &det_path) int Detector::init(const std::string &det_path)
{ {
try dnn::Net network = dnn::readNetFromONNX(det_path);
{ this->qbar_detector = std::make_shared<dnn::Net>(network);
dnn::Net network = dnn::readNetFromONNX(det_path);
if(network.empty())
{
return -101;
}
this->qbar_detector = std::make_shared<dnn::Net>(network);
}
catch (const std::exception &e)
{
printf("%s", e.what());
return -3;
}
return 0; return 0;
}
bool Detector::detect(const Mat &image,std::vector<DetectInfo> &bboxes)
{
Mat input_blob;
bool ret = this->pre_process_det(image,input_blob);
input_blob = dnn::blobFromImage(input_blob);
std::vector<Mat> outputs;
//forward
this->qbar_detector->setInput(input_blob, "input");
std::vector<std::string> output_names;
output_names.push_back("cls_pred_stride_8");
output_names.push_back("dis_pred_stride_8");
output_names.push_back("cls_pred_stride_16");
output_names.push_back("dis_pred_stride_16");
output_names.push_back("cls_pred_stride_32");
output_names.push_back("dis_pred_stride_32");
this->qbar_detector->forward(outputs, output_names);
bboxes.clear();
if (outputs.size() == 0)
return false;
std::vector<BoxInfo> det_bboxes;
ret = this->post_process_det(outputs,input_blob.size[3],input_blob.size[2],det_bboxes);
if (ret)
{
for (size_t i = 0; i < det_bboxes.size(); i++)
{
DetectInfo object;
object.class_id = det_bboxes[i].label + 1;
object.prob = det_bboxes[i].score;
object.x = det_bboxes[i].x1 * image.cols;
object.y = det_bboxes[i].y1 * image.rows;
object.width = det_bboxes[i].x2 * image.cols - object.x;
object.height = det_bboxes[i].y2 * image.rows - object.y;
bboxes.push_back(object);
}
} }
int Detector::detect(const Mat &image,std::vector<DetectInfo> &bboxes) return ret;
{ }
Mat input_blob;
int ret = this->pre_process_det(image,input_blob);
input_blob = dnn::blobFromImage(input_blob);
std::vector<Mat> outputs;
//forward bool Detector::pre_process_det(const Mat &image,Mat &out_blob)
{
int set_size = this->reference_size;
int setWidth, setHeight;
this->qbar_detector->setInput(input_blob, "input"); if (image.cols <= set_size && image.rows <= set_size) {
if (image.cols >= image.rows) {
std::vector<std::string> output_names; setWidth = set_size;
output_names.push_back("cls_pred_stride_8"); setHeight = std::ceil(image.rows * 1.0 * set_size / image.cols);
output_names.push_back("dis_pred_stride_8");
output_names.push_back("cls_pred_stride_16");
output_names.push_back("dis_pred_stride_16");
output_names.push_back("cls_pred_stride_32");
output_names.push_back("dis_pred_stride_32");
this->qbar_detector->forward(outputs, output_names);
if(outputs.size()==0)
{
return -103;
} }
else {
setHeight = set_size;
std::vector<BoxInfo> det_bboxes; setWidth = std::ceil(image.cols * 1.0 * set_size / image.rows);
ret = this->post_process_det(outputs,input_blob.size[3],input_blob.size[2],det_bboxes);
if (!ret)
{
for (size_t i = 0; i < det_bboxes.size(); i++)
{
DetectInfo object;
object.class_id = det_bboxes[i].label + 1;
object.prob = det_bboxes[i].score;
object.x = det_bboxes[i].x1 * image.cols;
object.y = det_bboxes[i].y1 * image.rows;
object.width = det_bboxes[i].x2 * image.cols - object.x;
object.height = det_bboxes[i].y2 * image.rows - object.y;
bboxes.push_back(object);
}
} }
}
return ret; else {
float resizeRatio = sqrt(image.cols * image.rows * 1.0 / (set_size * set_size));
setWidth = image.cols / resizeRatio;
setHeight = image.rows / resizeRatio;
} }
setHeight = static_cast<int>((setHeight + 32 - 1) / 32) * 32;
setWidth = static_cast<int>((setWidth + 32 - 1) / 32) * 32;
resize(image,out_blob,Size(setWidth, setHeight));
out_blob.convertTo(out_blob,CV_32FC1,1.0f/128.0f);
out_blob = out_blob - Scalar(1.0);
return true;
}
int Detector::pre_process_det(const Mat &image,Mat &out_blob) bool Detector::post_process_det(std::vector<Mat> outputs,float inputWidth,float inputHeight,std::vector<BoxInfo>& dets)
{
// step 1: extract
std::vector<std::vector<int>> outShape(6);
float *outPtr[6];
for (int i = 0; i < 6; i++)
{ {
int reference_size = this->reference_size; outPtr[i] = reinterpret_cast<float *>(outputs[i].data);
int setWidth, setHeight; if (outPtr[i] == NULL)
return false;
if (image.cols <= reference_size && image.rows <= reference_size) { for(int j = 0;j<outputs[i].dims;j++)
if (image.cols >= image.rows) outShape[i].push_back(outputs[i].size[j]);
{
setWidth = reference_size;
setHeight = std::ceil(image.rows * 1.0 * reference_size / image.cols);
}
else
{
setHeight = reference_size;
setWidth = std::ceil(image.cols * 1.0 * reference_size / image.rows);
}
}
else
{
float resizeRatio = sqrt(image.cols * image.rows * 1.0 / (reference_size * reference_size));
setWidth = image.cols / resizeRatio;
setHeight = image.rows / resizeRatio;
}
setHeight = static_cast<int>((setHeight + 32 - 1) / 32) * 32;
setWidth = static_cast<int>((setWidth + 32 - 1) / 32) * 32;
resize(image,out_blob,Size(setWidth, setHeight));
out_blob.convertTo(out_blob,CV_32FC1,1.0f/128.0f);
out_blob = out_blob - Scalar(1.0);
return 0;
} }
// step2: decode s8\s16\s32
std::vector<std::vector<BoxInfo>> results;
int numClasses = outShape[0][2];
results.resize(numClasses);
int Detector::post_process_det(std::vector<Mat> outputs,float inputWidth,float inputHeight,std::vector<BoxInfo>& dets) this->decode_infer(outPtr[0], outPtr[1], 8, results, outShape[0], outShape[1], score_thres,inputWidth, inputHeight);
this->decode_infer(outPtr[2], outPtr[3], 16, results, outShape[2], outShape[3], score_thres,inputWidth, inputHeight);
this->decode_infer(outPtr[4], outPtr[5], 32, results, outShape[4], outShape[5], score_thres,inputWidth, inputHeight);
// step3: nms
std::vector<BoxInfo> rets;
for (size_t i = 0; i < results.size(); i++)
{ {
// step 1: extract this->nms(results[i], iou_thres);
std::vector<std::vector<int>> outShape(6); for (auto & box : results[i])
float *outPtr[6];
for (int i = 0; i < 6; i++)
{ {
if (box.score > score_thres)
outPtr[i] = reinterpret_cast<float *>(outputs[i].data);
if (outPtr[i] == NULL)
return -1;
for(int j = 0;j<outputs[i].dims;j++)
{ {
outShape[i].push_back(outputs[i].size[j]); rets.push_back(box);
}
}
// step2: decode s8\s16\s32
std::vector<std::vector<BoxInfo>> results;
int numClasses = outShape[0][2];
results.resize(numClasses);
this->decode_infer(outPtr[0], outPtr[1], 8, results, outShape[0], outShape[1], score_thres,inputWidth, inputHeight);
this->decode_infer(outPtr[2], outPtr[3], 16, results, outShape[2], outShape[3], score_thres,inputWidth, inputHeight);
this->decode_infer(outPtr[4], outPtr[5], 32, results, outShape[4], outShape[5], score_thres,inputWidth, inputHeight);
// step3: nms
std::vector<BoxInfo> rets;
for (size_t i = 0; i < results.size(); i++)
{
this->nms(results[i], iou_thres); // 0.5
for (auto & box : results[i])
{
if (box.score > score_thres)
{
rets.push_back(box);
}
} }
} }
// step4: multi-class nms to gen BoxMultiInfo results
this->multiclass_nms(rets, dets, iou_thres, inputWidth, inputHeight);
return 0;
} }
// step4: multi-class nms to gen BoxMultiInfo results
this->multiclass_nms(rets, dets, iou_thres, inputWidth, inputHeight);
void Detector::multiclass_nms(std::vector<BoxInfo> &input_boxes, std::vector<BoxInfo> &output_boxes, float thr, int inputWidth, int inputHeight) return true;
}
void Detector::multiclass_nms(std::vector<BoxInfo> &input_boxes, std::vector<BoxInfo> &output_boxes, float thr, int inputWidth, int inputHeight)
{
if (input_boxes.size() <= 0) return;
output_boxes.clear();
std::vector<bool> skip(input_boxes.size());
for (size_t i = 0; i < input_boxes.size(); ++i)
skip[i] = false;
// merge overlapped results
for (size_t i = 0; i < input_boxes.size(); ++i)
{ {
if (input_boxes.size() <= 0) return; if (skip[i])
output_boxes.clear(); continue;
std::vector<bool> skip(input_boxes.size()); skip[i] = true;
for (size_t i = 0; i < input_boxes.size(); ++i) BoxInfo box;
{ box.x1 = input_boxes[i].x1;
skip[i] = false; box.y1 = input_boxes[i].y1;
} box.x2 = input_boxes[i].x2;
box.y2 = input_boxes[i].y2;
box.score = input_boxes[i].score;
// merge overlapped results for (size_t j = i + 1; j < input_boxes.size(); ++j)
for (size_t i = 0; i < input_boxes.size(); ++i)
{ {
int labeli = input_boxes[i].label; if (skip[j])
if (skip[i])
continue; continue;
skip[i] = true;
BoxInfo box;
box.x1 = input_boxes[i].x1;
box.y1 = input_boxes[i].y1;
box.x2 = input_boxes[i].x2;
box.y2 = input_boxes[i].y2;
box.score = input_boxes[i].score;
for (size_t j = i + 1; j < input_boxes.size(); ++j)
{
int labelj = input_boxes[j].label;
if (skip[j])
continue;
{
float area_i = (input_boxes[i].x2 - input_boxes[i].x1 + 1) * (input_boxes[i].y2 - input_boxes[i].y1 + 1);
float area_j = (input_boxes[j].x2 - input_boxes[j].x1 + 1) * (input_boxes[j].y2 - input_boxes[j].y1 + 1);
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
float yy2 = (std::min)(input_boxes[i].y2, input_boxes[j].y2);
float w = (std::max)(static_cast<float>(0), xx2 - xx1 + 1);
float h = (std::max)(static_cast<float>(0), yy2 - yy1 + 1);
float inter = w * h;
float ovr = inter / (area_i + area_j - inter);
float cover = inter / (std::min)(area_i, area_j);
if (ovr > thr || cover > 0.96){
box.x1 = (std::min)(box.x1, input_boxes[j].x1);
box.y1 = (std::min)(box.y1, input_boxes[j].y1);
box.x2 = (std::max)(box.x2, input_boxes[j].x2);
box.y2 = (std::max)(box.y2, input_boxes[j].y2);
box.score = (std::max)(box.score, input_boxes[j].score);
box.label = 5;
skip[j] = true;
}
}
}
box.x1 = CLIP(box.x1 / inputWidth, 0, 1);
box.y1 = CLIP(box.y1 / inputHeight, 0, 1);
box.x2 = CLIP(box.x2 / inputWidth, 0, 1);
box.y2 = CLIP(box.y2 / inputHeight, 0, 1);
output_boxes.push_back(box);
}
}
void Detector::nms(std::vector<BoxInfo>& input_boxes, float NMS_THRESH)
{
if (input_boxes.size() <= 1) return;
std::sort(input_boxes.begin(), input_boxes.end(), [](BoxInfo a, BoxInfo b) { return a.score > b.score; });
std::vector<float> vArea(input_boxes.size());
for (size_t i = 0; i < input_boxes.size(); ++i)
{
vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1)
* (input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
}
for (size_t i = 0; i < input_boxes.size(); ++i)
{
for (size_t j = i + 1; j < input_boxes.size();)
{ {
float area_i = (input_boxes[i].x2 - input_boxes[i].x1 + 1) * (input_boxes[i].y2 - input_boxes[i].y1 + 1);
float area_j = (input_boxes[j].x2 - input_boxes[j].x1 + 1) * (input_boxes[j].y2 - input_boxes[j].y1 + 1);
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1); float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1); float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2); float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
@ -245,110 +175,155 @@ namespace cv {
float w = (std::max)(static_cast<float>(0), xx2 - xx1 + 1); float w = (std::max)(static_cast<float>(0), xx2 - xx1 + 1);
float h = (std::max)(static_cast<float>(0), yy2 - yy1 + 1); float h = (std::max)(static_cast<float>(0), yy2 - yy1 + 1);
float inter = w * h; float inter = w * h;
float ovr = inter / (vArea[i] + vArea[j] - inter); float ovr = inter / (area_i + area_j - inter);
float cover = inter / (std::min)(vArea[i], vArea[j]); float cover = inter / (std::min)(area_i, area_j);
if (ovr >= NMS_THRESH) if (ovr > thr || cover > 0.96){
if (input_boxes[j].score > box.score) {
box.x1 = input_boxes[j].x1;
box.y1 = input_boxes[j].y1;
box.x2 = input_boxes[j].x2;
box.y2 = input_boxes[j].y2;
box.score = input_boxes[j].score;
box.label = 5;
}
skip[j] = true;
}
}
}
box.x1 = CLIP(box.x1 / inputWidth, 0, 1);
box.y1 = CLIP(box.y1 / inputHeight, 0, 1);
box.x2 = CLIP(box.x2 / inputWidth, 0, 1);
box.y2 = CLIP(box.y2 / inputHeight, 0, 1);
output_boxes.push_back(box);
}
}
void Detector::nms(std::vector<BoxInfo>& input_boxes, float NMS_THRESH)
{
if (input_boxes.size() <= 1) return;
std::sort(input_boxes.begin(), input_boxes.end(), [](BoxInfo a, BoxInfo b) { return a.score > b.score; });
std::vector<float> vArea(input_boxes.size());
for (size_t i = 0; i < input_boxes.size(); ++i)
{
vArea[i] = (input_boxes.at(i).x2 - input_boxes.at(i).x1 + 1)
* (input_boxes.at(i).y2 - input_boxes.at(i).y1 + 1);
}
for (size_t i = 0; i < input_boxes.size(); ++i)
{
for (size_t j = i + 1; j < input_boxes.size();)
{
float xx1 = (std::max)(input_boxes[i].x1, input_boxes[j].x1);
float yy1 = (std::max)(input_boxes[i].y1, input_boxes[j].y1);
float xx2 = (std::min)(input_boxes[i].x2, input_boxes[j].x2);
float yy2 = (std::min)(input_boxes[i].y2, input_boxes[j].y2);
float w = (std::max)(static_cast<float>(0), xx2 - xx1 + 1);
float h = (std::max)(static_cast<float>(0), yy2 - yy1 + 1);
float inter = w * h;
float ovr = inter / (vArea[i] + vArea[j] - inter);
float cover = inter / (std::min)(vArea[i], vArea[j]);
if (ovr >= NMS_THRESH)
{
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
}
else if (cover >= 0.96)
{
if (vArea[i] > vArea[j])
{ {
input_boxes.erase(input_boxes.begin() + j); input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j); vArea.erase(vArea.begin() + j);
} }
else if (cover >= 0.96) // qiantao
{
if (vArea[i] > vArea[j])
{
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
}
else
{
input_boxes[i].x1 = input_boxes[j].x1;
input_boxes[i].y1 = input_boxes[j].y1;
input_boxes[i].x2 = input_boxes[j].x2;
input_boxes[i].y2 = input_boxes[j].y2;
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
}
}
else else
{ {
j++; input_boxes[i].x1 = input_boxes[j].x1;
input_boxes[i].y1 = input_boxes[j].y1;
input_boxes[i].x2 = input_boxes[j].x2;
input_boxes[i].y2 = input_boxes[j].y2;
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
} }
} }
} else
}
inline float fast_exp(float x)
{
union {
uint32_t i;
float f;
} v{};
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
return v.f;
}
template<typename _Tp>
int activation_function_softmax(const _Tp* src, _Tp* dst, int length)
{
const _Tp alpha = *(std::max_element)(src, src + length);
_Tp denominator{ 0 };
for (int i = 0; i < length; ++i)
{
dst[i] = fast_exp(src[i] - alpha);
denominator += dst[i];
}
for (int i = 0; i < length; ++i)
{
dst[i] /= denominator;
}
return 0;
}
void Detector::decode_infer(float *clsPred, float *disPred, int stride, std::vector<std::vector<BoxInfo>> &results, const std::vector<int> &outShapeCls, const std::vector<int> &outShapeDis, float scoreThres,float inputWidth,float inputHeight)
{
int numClasses = outShapeCls[2];
int RegMax = (outShapeDis[2] / 4) - 1;
int lenFeat = outShapeCls[1];
float * pScore = clsPred;
std::vector<int> inputShape;
int featWidth = std::ceil(inputWidth / stride);
for (int idx = 0; idx < lenFeat; idx++)
{
for (int label = 0; label < numClasses; label ++)
{ {
float score = pScore[idx * numClasses + label]; j++;
if (score > scoreThres)
{
const float * bboxPred = disPred + idx*outShapeDis[2]; //(RegMax + 1)*4;
int row = idx / featWidth;
int col = idx % featWidth;
// disPred2Bbox
float centerX = col * stride;
float centerY = row * stride;
std::vector<float> disPred_;
disPred_.resize(4);
for (int i = 0; i < 4; i++){
float dis = 0;
float* disAfterSoftmax = new float[RegMax + 1];
activation_function_softmax(bboxPred + i*(RegMax + 1), disAfterSoftmax, RegMax+1);
for (int j = 0; j < (RegMax+1); j++){
dis += j * disAfterSoftmax[j];
}
dis *= stride;
disPred_[i] = dis;
delete[] disAfterSoftmax;
}
float xMin = (std::max)(centerX - disPred_[0], .0f);
float yMin = (std::max)(centerY - disPred_[1], .0f);
float xMax = (std::min)(centerX + disPred_[2], inputWidth);
float yMax = (std::min)(centerY + disPred_[3], inputHeight);
BoxInfo bbox = {xMin, yMin, xMax, yMax, score, label};
results[label].push_back(bbox);
}
} }
} }
} }
}
/**
 * @brief Fast approximation of exp(x) built directly from IEEE-754 single-precision bits.
 *
 * The value 2^23 * (x * log2(e) + bias) is written into the bit pattern of a
 * float, so the integer part lands in the exponent field and the fractional
 * part in the mantissa. For inputs whose bit pattern fits a finite positive
 * float the result is identical to the previous implementation.
 *
 * Out-of-range inputs are saturated instead of being converted blindly:
 * converting a negative or too-large double to uint32_t is undefined behaviour
 * and used to produce garbage for strongly negative arguments (which softmax
 * produces routinely, since it passes `value - max`).
 *
 * @param x exponent argument
 * @return approximately exp(x); 0 on underflow, +infinity on overflow, NaN for NaN input
 */
inline float fast_exp(float x)
{
    // NaN compares unequal to itself; propagate it rather than converting it to an integer.
    if (x != x)
    {
        return x;
    }

    // Same arithmetic as before (evaluated in double because of the double literal).
    const double scaled = (1 << 23) * (1.4426950409 * x + 126.93490512f);

    // Below the smallest representable bit pattern: exp(x) underflows to zero.
    if (scaled <= 0.0)
    {
        return 0.0f;
    }

    union {
        uint32_t i;
        float f;
    } v{};

    // 0x7F800000 is the bit pattern of +infinity; anything at or above it is an overflow.
    const double infBits = 2139095040.0;
    if (scaled >= infBits)
    {
        v.i = 0x7F800000u;
        return v.f;
    }

    v.i = static_cast<uint32_t>(scaled);
    return v.f;
}
/**
 * @brief Softmax over a contiguous array, using fast_exp as the exponential.
 *
 * The maximum element is subtracted from every input before exponentiation,
 * so every argument passed to fast_exp is <= 0. The exponentials are then
 * divided by their sum.
 *
 * @param src    input values, `length` elements
 * @param dst    output buffer, `length` elements
 * @param length number of elements; when <= 0 nothing is read or written
 * @return always 0
 */
template<typename _Tp>
int activation_function_softmax(const _Tp* src, _Tp* dst, int length)
{
    // On an empty range std::max_element returns `src` itself, which must not be dereferenced.
    if (length <= 0)
    {
        return 0;
    }

    const _Tp peak = *(std::max_element)(src, src + length);

    _Tp total{ 0 };
    for (int k = 0; k < length; ++k)
    {
        dst[k] = fast_exp(src[k] - peak);
        total += dst[k];
    }

    for (int k = 0; k < length; ++k)
    {
        dst[k] /= total;
    }
    return 0;
}
void Detector::decode_infer(float *clsPred, float *disPred, int stride, std::vector<std::vector<BoxInfo>> &results, const std::vector<int> &outShapeCls, const std::vector<int> &outShapeDis, float scoreThres,float inputWidth,float inputHeight)
{
int numClasses = outShapeCls[2];
int RegMax = (outShapeDis[2] / 4) - 1;
int lenFeat = outShapeCls[1];
float * pScore = clsPred;
std::vector<int> inputShape;
int featWidth = std::ceil(inputWidth / stride);
for (int idx = 0; idx < lenFeat; idx++)
{
for (int label = 0; label < numClasses; label ++)
{
float score = pScore[idx * numClasses + label];
if (score > scoreThres)
{
const float * bboxPred = disPred + idx*outShapeDis[2];
int row = idx / featWidth;
int col = idx % featWidth;
float centerX = col * stride;
float centerY = row * stride;
std::vector<float> disPred_;
disPred_.resize(4);
for (int i = 0; i < 4; i++){
float dis = 0;
float* disAfterSoftmax = new float[RegMax + 1];
activation_function_softmax(bboxPred + i*(RegMax + 1), disAfterSoftmax, RegMax+1);
for (int j = 0; j < (RegMax+1); j++){
dis += j * disAfterSoftmax[j];
}
dis *= stride;
disPred_[i] = dis;
delete[] disAfterSoftmax;
}
float xMin = (std::max)(centerX - disPred_[0], .0f);
float yMin = (std::max)(centerY - disPred_[1], .0f);
float xMax = (std::min)(centerX + disPred_[2], inputWidth);
float yMax = (std::min)(centerY + disPred_[3], inputHeight);
BoxInfo bbox = {xMin, yMin, xMax, yMax, score, label};
results[label].push_back(bbox);
}
}
}
}
} // namespace cv } // namespace cv

View File

@ -45,14 +45,14 @@ namespace cv {
Detector(){}; Detector(){};
~Detector(){}; ~Detector(){};
int init(const std::string &config_path); int init(const std::string &config_path);
int detect(const Mat &image,std::vector<DetectInfo> &bboxes); bool detect(const Mat &image,std::vector<DetectInfo> &bboxes);
void setReferenceSize(int reference_size) {this->reference_size = reference_size;} void setReferenceSize(int reference_size_) {this->reference_size = reference_size_;}
void setScoreThres(float score_thres) {this->score_thres = score_thres;} void setScoreThres(float score_thres_) {this->score_thres = score_thres_;}
void setIouThres(float iou_thres) {this->iou_thres = iou_thres;} void setIouThres(float iou_thres_) {this->iou_thres = iou_thres_;}
private: private:
int post_process_det(std::vector<Mat> outputs,float inputWidth,float inputHeight,std::vector<BoxInfo>& dets); bool post_process_det(std::vector<Mat> outputs,float inputWidth,float inputHeight,std::vector<BoxInfo>& dets);
int pre_process_det(const Mat &image,Mat &out_blob); bool pre_process_det(const Mat &image,Mat &out_blob);
void multiclass_nms(std::vector<BoxInfo> &input_boxes, std::vector<BoxInfo> &output_boxes, float thr, int inputWidth, int inputHeight); void multiclass_nms(std::vector<BoxInfo> &input_boxes, std::vector<BoxInfo> &output_boxes, float thr, int inputWidth, int inputHeight);
void decode_infer(float *clsPred, float *disPred, int stride, std::vector<std::vector<BoxInfo>> &results, const std::vector<int> &outShapeCls, const std::vector<int> &outShapeDis, float scoreThres,float inputHeight,float inputWidth); void decode_infer(float *clsPred, float *disPred, int stride, std::vector<std::vector<BoxInfo>> &results, const std::vector<int> &outShapeCls, const std::vector<int> &outShapeDis, float scoreThres,float inputHeight,float inputWidth);
void nms(std::vector<BoxInfo>& input_boxes, float NMS_THRESH); void nms(std::vector<BoxInfo>& input_boxes, float NMS_THRESH);

View File

@ -11,57 +11,51 @@
#define CLIP(x, x1, x2) max(x1, min(x, x2)) #define CLIP(x, x1, x2) max(x1, min(x, x2))
namespace cv { namespace cv {
int SuperScale::init(const std::string &sr_path) { int SuperScale::init(const std::string &sr_path) {
try dnn::Net network = dnn::readNetFromONNX(sr_path);
{ this->qbar_sr = std::make_shared<dnn::Net>(network);
dnn::Net network = dnn::readNetFromONNX(sr_path);
if(network.empty())
{
return -101;
}
this->qbar_sr = std::make_shared<dnn::Net>(network);
}
catch (const std::exception &e)
{
printf("%s", e.what());
return -3;
}
net_loaded_ = true;
return 0; return 0;
} }
std::vector<float> SuperScale::getScaleList(const int width, const int height) { std::vector<float> SuperScale::getScaleList(const int width, const int height) {
if (width < 320 || height < 320) return {1.0, 2.0, 0.5}; float min_side = min(width, height);
if (width < 640 && height < 640) return {1.0, 0.5}; if (min_side <= 450.f) {
return {0.5, 1.0}; return {1.0f, 300.0f / min_side, 2.0f};
}
else
return {1.0f, 450.f / min_side};
} }
Mat SuperScale::ProcessImageScale(const Mat &src, float scale, const bool &use_sr, void SuperScale::processImageScale(const Mat &src, Mat &dst, float scale, bool use_sr,
int sr_max_size) { int sr_max_size) {
Mat dst = src; scale = min(scale, MAX_SCALE);
if (scale == 1.0) { // src if (scale > .0 && scale < 1.0)
return dst; { // down sample
}
int width = src.cols;
int height = src.rows;
if (scale == 2.0) { // upsample
int SR_TH = sr_max_size;
if (use_sr && (int)sqrt(width * height * 1.0) < SR_TH && net_loaded_) {
int ret = SuperResoutionScale(src, dst);
if (ret == 0) return dst;
}
{ resize(src, dst, Size(), scale, scale, INTER_CUBIC); }
} else if (scale < 1.0) { // downsample
resize(src, dst, Size(), scale, scale, INTER_AREA); resize(src, dst, Size(), scale, scale, INTER_AREA);
} }
else if (scale >= 1.0 && scale < 2.0)
return dst; {
resize(src, dst, Size(), scale, scale, INTER_CUBIC);
}
else if (scale >= 2.0)
{
int width = src.cols;
int height = src.rows;
if (use_sr && (int) sqrt(width * height * 1.0) < sr_max_size && !qbar_sr->empty())
{
superResolutionScale(src, dst);
if (scale > 2.0)
{
processImageScale(dst, dst, scale / 2.0f, use_sr);
}
}
else
{ resize(src, dst, Size(), scale, scale, INTER_CUBIC); }
}
} }
int SuperScale::SuperResoutionScale(const Mat &src, Mat &dst) { int SuperScale::superResolutionScale(const Mat &src, Mat &dst) {
// cv::resize(src, dst, Size(), 2, 2, INTER_CUBIC);
Mat blob; Mat blob;
dnn::blobFromImage(src, blob, 1.0, Size(src.cols, src.rows), {0.0f}, false, false); dnn::blobFromImage(src, blob, 1.0, Size(src.cols, src.rows), {0.0f}, false, false);
@ -70,6 +64,7 @@ int SuperScale::SuperResoutionScale(const Mat &src, Mat &dst) {
dst = Mat(prob.size[2], prob.size[3], CV_32F, prob.ptr<float>()); dst = Mat(prob.size[2], prob.size[3], CV_32F, prob.ptr<float>());
dst.convertTo(dst, CV_8UC1); dst.convertTo(dst, CV_8UC1);
return 0; return 0;
} }
} // namespace cv } // namespace cv

View File

@ -14,18 +14,19 @@
using namespace std; using namespace std;
namespace cv { namespace cv {
constexpr static float MAX_SCALE = 4.0f;
class SuperScale { class SuperScale {
public: public:
SuperScale(){}; SuperScale(){};
~SuperScale(){}; ~SuperScale(){};
int init(const std::string &config_path); int init(const std::string &config_path);
std::vector<float> getScaleList(const int width, const int height); std::vector<float> getScaleList(const int width, const int height);
Mat ProcessImageScale(const Mat &src, float scale, const bool &use_sr, int sr_max_size = 160); void processImageScale(const Mat &src, Mat &dst, float scale, bool use_sr, int sr_max_size = 160);
private: private:
std::shared_ptr<dnn::Net> qbar_sr; std::shared_ptr<dnn::Net> qbar_sr;
bool net_loaded_ = false; int superResolutionScale(const Mat &src, Mat &dst);
int SuperResoutionScale(const Mat &src, Mat &dst);
}; };
} // namespace cv } // namespace cv
#endif // __SCALE_SUPER_SCALE_HPP_ #endif // __SCALE_SUPER_SCALE_HPP_