mirror of
https://github.com/opencv/opencv.git
synced 2025-07-20 11:06:38 +08:00
Refactor NMS procedure at RegionLayer
This commit is contained in:
parent
047ad4ff71
commit
c67e75b68f
@ -482,7 +482,7 @@ namespace cv {
|
|||||||
}
|
}
|
||||||
else if (layer_type == "region")
|
else if (layer_type == "region")
|
||||||
{
|
{
|
||||||
float thresh = 0.001; // in the original Darknet is equal to the detection threshold set by the user
|
float thresh = getParam<float>(layer_params, "thresh", 0.001);
|
||||||
int coords = getParam<int>(layer_params, "coords", 4);
|
int coords = getParam<int>(layer_params, "coords", 4);
|
||||||
int classes = getParam<int>(layer_params, "classes", -1);
|
int classes = getParam<int>(layer_params, "classes", -1);
|
||||||
int num_of_anchors = getParam<int>(layer_params, "num", -1);
|
int num_of_anchors = getParam<int>(layer_params, "num", -1);
|
||||||
|
@ -43,7 +43,7 @@
|
|||||||
#include "../precomp.hpp"
|
#include "../precomp.hpp"
|
||||||
#include <opencv2/dnn/shape_utils.hpp>
|
#include <opencv2/dnn/shape_utils.hpp>
|
||||||
#include <opencv2/dnn/all_layers.hpp>
|
#include <opencv2/dnn/all_layers.hpp>
|
||||||
#include <iostream>
|
#include "nms.inl.hpp"
|
||||||
#include "opencl_kernels_dnn.hpp"
|
#include "opencl_kernels_dnn.hpp"
|
||||||
|
|
||||||
namespace cv
|
namespace cv
|
||||||
@ -173,8 +173,7 @@ public:
|
|||||||
if (nmsThreshold > 0) {
|
if (nmsThreshold > 0) {
|
||||||
Mat mat = outBlob.getMat(ACCESS_WRITE);
|
Mat mat = outBlob.getMat(ACCESS_WRITE);
|
||||||
float *dstData = mat.ptr<float>();
|
float *dstData = mat.ptr<float>();
|
||||||
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
|
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
|
||||||
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -263,128 +262,48 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (nmsThreshold > 0) {
|
if (nmsThreshold > 0) {
|
||||||
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
|
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
|
||||||
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline float rectOverlap(const Rect2f& a, const Rect2f& b)
|
||||||
struct box {
|
|
||||||
float x, y, w, h;
|
|
||||||
float *probs;
|
|
||||||
};
|
|
||||||
|
|
||||||
float overlap(float x1, float w1, float x2, float w2)
|
|
||||||
{
|
{
|
||||||
float l1 = x1 - w1 / 2;
|
return 1.0f - jaccardDistance(a, b);
|
||||||
float l2 = x2 - w2 / 2;
|
|
||||||
float left = l1 > l2 ? l1 : l2;
|
|
||||||
float r1 = x1 + w1 / 2;
|
|
||||||
float r2 = x2 + w2 / 2;
|
|
||||||
float right = r1 < r2 ? r1 : r2;
|
|
||||||
return right - left;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float box_intersection(box a, box b)
|
void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh)
|
||||||
{
|
{
|
||||||
float w = overlap(a.x, a.w, b.x, b.w);
|
std::vector<Rect2f> boxes(total);
|
||||||
float h = overlap(a.y, a.h, b.y, b.h);
|
std::vector<float> scores(total);
|
||||||
if (w < 0 || h < 0) return 0;
|
|
||||||
float area = w*h;
|
|
||||||
return area;
|
|
||||||
}
|
|
||||||
|
|
||||||
float box_union(box a, box b)
|
for (int i = 0; i < total; ++i)
|
||||||
{
|
{
|
||||||
float i = box_intersection(a, b);
|
Rect2f &b = boxes[i];
|
||||||
float u = a.w*a.h + b.w*b.h - i;
|
|
||||||
return u;
|
|
||||||
}
|
|
||||||
|
|
||||||
float box_iou(box a, box b)
|
|
||||||
{
|
|
||||||
return box_intersection(a, b) / box_union(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct sortable_bbox {
|
|
||||||
int index;
|
|
||||||
float *probs;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct nms_comparator {
|
|
||||||
int k;
|
|
||||||
nms_comparator(int _k) : k(_k) {}
|
|
||||||
bool operator ()(sortable_bbox v1, sortable_bbox v2) {
|
|
||||||
return v2.probs[k] < v1.probs[k];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void do_nms_sort(float *detections, int total, float nms_thresh)
|
|
||||||
{
|
|
||||||
std::vector<box> boxes(total);
|
|
||||||
for (int i = 0; i < total; ++i) {
|
|
||||||
box &b = boxes[i];
|
|
||||||
int box_index = i * (classes + coords + 1);
|
int box_index = i * (classes + coords + 1);
|
||||||
b.x = detections[box_index + 0];
|
b.width = detections[box_index + 2];
|
||||||
b.y = detections[box_index + 1];
|
b.height = detections[box_index + 3];
|
||||||
b.w = detections[box_index + 2];
|
b.x = detections[box_index + 0] - b.width / 2;
|
||||||
b.h = detections[box_index + 3];
|
b.y = detections[box_index + 1] - b.height / 2;
|
||||||
int class_index = i * (classes + 5) + 5;
|
|
||||||
b.probs = (detections + class_index);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<sortable_bbox> s(total);
|
std::vector<int> indices;
|
||||||
|
for (int k = 0; k < classes; ++k)
|
||||||
for (int i = 0; i < total; ++i) {
|
{
|
||||||
s[i].index = i;
|
for (int i = 0; i < total; ++i)
|
||||||
int class_index = i * (classes + 5) + 5;
|
|
||||||
s[i].probs = (detections + class_index);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int k = 0; k < classes; ++k) {
|
|
||||||
std::stable_sort(s.begin(), s.end(), nms_comparator(k));
|
|
||||||
for (int i = 0; i < total; ++i) {
|
|
||||||
if (boxes[s[i].index].probs[k] == 0) continue;
|
|
||||||
box a = boxes[s[i].index];
|
|
||||||
for (int j = i + 1; j < total; ++j) {
|
|
||||||
box b = boxes[s[j].index];
|
|
||||||
if (box_iou(a, b) > nms_thresh) {
|
|
||||||
boxes[s[j].index].probs[k] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void do_nms(float *detections, int total, float nms_thresh)
|
|
||||||
{
|
{
|
||||||
std::vector<box> boxes(total);
|
|
||||||
for (int i = 0; i < total; ++i) {
|
|
||||||
box &b = boxes[i];
|
|
||||||
int box_index = i * (classes + coords + 1);
|
int box_index = i * (classes + coords + 1);
|
||||||
b.x = detections[box_index + 0];
|
int class_index = box_index + 5;
|
||||||
b.y = detections[box_index + 1];
|
scores[i] = detections[class_index + k];
|
||||||
b.w = detections[box_index + 2];
|
detections[class_index + k] = 0;
|
||||||
b.h = detections[box_index + 3];
|
|
||||||
int class_index = i * (classes + 5) + 5;
|
|
||||||
b.probs = (detections + class_index);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < total; ++i) {
|
|
||||||
bool any = false;
|
|
||||||
for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0);
|
|
||||||
if (!any) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (int j = i + 1; j < total; ++j) {
|
|
||||||
if (box_iou(boxes[i], boxes[j]) > nms_thresh) {
|
|
||||||
for (int k = 0; k < classes; ++k) {
|
|
||||||
if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0;
|
|
||||||
else boxes[j].probs[k] = 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap);
|
||||||
|
for (int i = 0, n = indices.size(); i < n; ++i)
|
||||||
|
{
|
||||||
|
int box_index = indices[i] * (classes + coords + 1);
|
||||||
|
int class_index = box_index + 5;
|
||||||
|
detections[class_index + k] = scores[indices[i]];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user