mirror of
https://github.com/opencv/opencv.git
synced 2025-07-20 11:06:38 +08:00
Refactor NMS procedure at RegionLayer
This commit is contained in:
parent
047ad4ff71
commit
c67e75b68f
@ -482,7 +482,7 @@ namespace cv {
|
||||
}
|
||||
else if (layer_type == "region")
|
||||
{
|
||||
float thresh = 0.001; // in the original Darknet is equal to the detection threshold set by the user
|
||||
float thresh = getParam<float>(layer_params, "thresh", 0.001);
|
||||
int coords = getParam<int>(layer_params, "coords", 4);
|
||||
int classes = getParam<int>(layer_params, "classes", -1);
|
||||
int num_of_anchors = getParam<int>(layer_params, "num", -1);
|
||||
|
@ -43,7 +43,7 @@
|
||||
#include "../precomp.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
#include <opencv2/dnn/all_layers.hpp>
|
||||
#include <iostream>
|
||||
#include "nms.inl.hpp"
|
||||
#include "opencl_kernels_dnn.hpp"
|
||||
|
||||
namespace cv
|
||||
@ -173,8 +173,7 @@ public:
|
||||
if (nmsThreshold > 0) {
|
||||
Mat mat = outBlob.getMat(ACCESS_WRITE);
|
||||
float *dstData = mat.ptr<float>();
|
||||
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
|
||||
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
|
||||
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
|
||||
}
|
||||
|
||||
}
|
||||
@ -263,128 +262,48 @@ public:
|
||||
}
|
||||
|
||||
if (nmsThreshold > 0) {
|
||||
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
|
||||
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
|
||||
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
struct box {
|
||||
float x, y, w, h;
|
||||
float *probs;
|
||||
};
|
||||
|
||||
float overlap(float x1, float w1, float x2, float w2)
|
||||
static inline float rectOverlap(const Rect2f& a, const Rect2f& b)
|
||||
{
|
||||
float l1 = x1 - w1 / 2;
|
||||
float l2 = x2 - w2 / 2;
|
||||
float left = l1 > l2 ? l1 : l2;
|
||||
float r1 = x1 + w1 / 2;
|
||||
float r2 = x2 + w2 / 2;
|
||||
float right = r1 < r2 ? r1 : r2;
|
||||
return right - left;
|
||||
return 1.0f - jaccardDistance(a, b);
|
||||
}
|
||||
|
||||
float box_intersection(box a, box b)
|
||||
void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh)
|
||||
{
|
||||
float w = overlap(a.x, a.w, b.x, b.w);
|
||||
float h = overlap(a.y, a.h, b.y, b.h);
|
||||
if (w < 0 || h < 0) return 0;
|
||||
float area = w*h;
|
||||
return area;
|
||||
}
|
||||
std::vector<Rect2f> boxes(total);
|
||||
std::vector<float> scores(total);
|
||||
|
||||
float box_union(box a, box b)
|
||||
{
|
||||
float i = box_intersection(a, b);
|
||||
float u = a.w*a.h + b.w*b.h - i;
|
||||
return u;
|
||||
}
|
||||
|
||||
float box_iou(box a, box b)
|
||||
{
|
||||
return box_intersection(a, b) / box_union(a, b);
|
||||
}
|
||||
|
||||
struct sortable_bbox {
|
||||
int index;
|
||||
float *probs;
|
||||
};
|
||||
|
||||
struct nms_comparator {
|
||||
int k;
|
||||
nms_comparator(int _k) : k(_k) {}
|
||||
bool operator ()(sortable_bbox v1, sortable_bbox v2) {
|
||||
return v2.probs[k] < v1.probs[k];
|
||||
}
|
||||
};
|
||||
|
||||
void do_nms_sort(float *detections, int total, float nms_thresh)
|
||||
{
|
||||
std::vector<box> boxes(total);
|
||||
for (int i = 0; i < total; ++i) {
|
||||
box &b = boxes[i];
|
||||
for (int i = 0; i < total; ++i)
|
||||
{
|
||||
Rect2f &b = boxes[i];
|
||||
int box_index = i * (classes + coords + 1);
|
||||
b.x = detections[box_index + 0];
|
||||
b.y = detections[box_index + 1];
|
||||
b.w = detections[box_index + 2];
|
||||
b.h = detections[box_index + 3];
|
||||
int class_index = i * (classes + 5) + 5;
|
||||
b.probs = (detections + class_index);
|
||||
b.width = detections[box_index + 2];
|
||||
b.height = detections[box_index + 3];
|
||||
b.x = detections[box_index + 0] - b.width / 2;
|
||||
b.y = detections[box_index + 1] - b.height / 2;
|
||||
}
|
||||
|
||||
std::vector<sortable_bbox> s(total);
|
||||
|
||||
for (int i = 0; i < total; ++i) {
|
||||
s[i].index = i;
|
||||
int class_index = i * (classes + 5) + 5;
|
||||
s[i].probs = (detections + class_index);
|
||||
}
|
||||
|
||||
for (int k = 0; k < classes; ++k) {
|
||||
std::stable_sort(s.begin(), s.end(), nms_comparator(k));
|
||||
for (int i = 0; i < total; ++i) {
|
||||
if (boxes[s[i].index].probs[k] == 0) continue;
|
||||
box a = boxes[s[i].index];
|
||||
for (int j = i + 1; j < total; ++j) {
|
||||
box b = boxes[s[j].index];
|
||||
if (box_iou(a, b) > nms_thresh) {
|
||||
boxes[s[j].index].probs[k] = 0;
|
||||
}
|
||||
}
|
||||
std::vector<int> indices;
|
||||
for (int k = 0; k < classes; ++k)
|
||||
{
|
||||
for (int i = 0; i < total; ++i)
|
||||
{
|
||||
int box_index = i * (classes + coords + 1);
|
||||
int class_index = box_index + 5;
|
||||
scores[i] = detections[class_index + k];
|
||||
detections[class_index + k] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void do_nms(float *detections, int total, float nms_thresh)
|
||||
{
|
||||
std::vector<box> boxes(total);
|
||||
for (int i = 0; i < total; ++i) {
|
||||
box &b = boxes[i];
|
||||
int box_index = i * (classes + coords + 1);
|
||||
b.x = detections[box_index + 0];
|
||||
b.y = detections[box_index + 1];
|
||||
b.w = detections[box_index + 2];
|
||||
b.h = detections[box_index + 3];
|
||||
int class_index = i * (classes + 5) + 5;
|
||||
b.probs = (detections + class_index);
|
||||
}
|
||||
|
||||
for (int i = 0; i < total; ++i) {
|
||||
bool any = false;
|
||||
for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0);
|
||||
if (!any) {
|
||||
continue;
|
||||
}
|
||||
for (int j = i + 1; j < total; ++j) {
|
||||
if (box_iou(boxes[i], boxes[j]) > nms_thresh) {
|
||||
for (int k = 0; k < classes; ++k) {
|
||||
if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0;
|
||||
else boxes[j].probs[k] = 0;
|
||||
}
|
||||
}
|
||||
NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap);
|
||||
for (int i = 0, n = indices.size(); i < n; ++i)
|
||||
{
|
||||
int box_index = indices[i] * (classes + coords + 1);
|
||||
int class_index = box_index + 5;
|
||||
detections[class_index + k] = scores[indices[i]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user