boost NMS performance

This commit is contained in:
Qoo 2021-02-24 04:32:45 -05:00
parent 7ffc4b57aa
commit 47337e2196
2 changed files with 18 additions and 4 deletions

View File

@ -133,6 +133,12 @@ public:
typedef std::map<int, std::vector<util::NormalizedBBox> > LabelBBox;
inline int getNumOfTargetClasses() {
unsigned numBackground =
(_backgroundLabelId >= 0 && _backgroundLabelId < _numClasses) ? 1 : 0;
return (_numClasses - numBackground);
}
bool getParameterDict(const LayerParams &params,
const std::string &parameterName,
DictValue& result)
@ -584,12 +590,13 @@ public:
LabelBBox::const_iterator label_bboxes = decodeBBoxes.find(label);
if (label_bboxes == decodeBBoxes.end())
CV_Error_(cv::Error::StsError, ("Could not find location predictions for label %d", label));
int limit = (getNumOfTargetClasses() == 1) ? _keepTopK : std::numeric_limits<int>::max();
if (_bboxesNormalized)
NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK,
indices[c], util::caffe_norm_box_overlap);
indices[c], util::caffe_norm_box_overlap, limit);
else
NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK,
indices[c], util::caffe_box_overlap);
indices[c], util::caffe_box_overlap, limit);
numDetections += indices[c].size();
}
if (_keepTopK > -1 && numDetections > (size_t)_keepTopK)

View File

@ -62,12 +62,15 @@ inline void GetMaxScoreIndex(const std::vector<float>& scores, const float thres
// score_threshold: a threshold used to filter detection results.
// nms_threshold: a threshold used in non maximum suppression.
// top_k: if not > 0, keep at most top_k picked indices.
// limit: early terminate once the # of picked indices has reached it.
// indices: the kept indices of bboxes after nms.
template <typename BoxType>
inline void NMSFast_(const std::vector<BoxType>& bboxes,
const std::vector<float>& scores, const float score_threshold,
const float nms_threshold, const float eta, const int top_k,
std::vector<int>& indices, float (*computeOverlap)(const BoxType&, const BoxType&))
std::vector<int>& indices,
float (*computeOverlap)(const BoxType&, const BoxType&),
int limit = std::numeric_limits<int>::max())
{
CV_Assert(bboxes.size() == scores.size());
@ -86,8 +89,12 @@ inline void NMSFast_(const std::vector<BoxType>& bboxes,
float overlap = computeOverlap(bboxes[idx], bboxes[kept_idx]);
keep = overlap <= adaptive_threshold;
}
if (keep)
if (keep) {
indices.push_back(idx);
if (indices.size() >= limit) {
break;
}
}
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}