Add ability to read thresh and nms_threshold from YOLO layer in YOLOV3 cfg file.

Currently the thresh is hard-coded to be 0.2 and nms_threshold as 0.4.
This commit is contained in:
Easton Liu 2019-03-07 09:55:48 +08:00
parent 3132c8ee08
commit fcfb29766b

View File

@ -371,7 +371,7 @@ namespace cv {
fused_layer_names.push_back(last_layer);
}
void setYolo(int classes, const std::vector<int>& mask, const std::vector<float>& anchors)
void setYolo(int classes, const std::vector<int>& mask, const std::vector<float>& anchors, float thresh, float nms_threshold)
{
cv::dnn::LayerParams region_param;
region_param.name = "Region-name";
@ -382,6 +382,8 @@ namespace cv {
region_param.set<int>("classes", classes);
region_param.set<int>("anchors", numAnchors);
region_param.set<bool>("logistic", true);
region_param.set<float>("thresh", thresh);
region_param.set<float>("nms_threshold", nms_threshold);
std::vector<float> usedAnchors(numAnchors * 2);
for (int i = 0; i < numAnchors; ++i)
@ -646,6 +648,8 @@ namespace cv {
{
int classes = getParam<int>(layer_params, "classes", -1);
int num_of_anchors = getParam<int>(layer_params, "num", -1);
float thresh = getParam<float>(layer_params, "thresh", 0.2);
float nms_threshold = getParam<float>(layer_params, "nms_threshold", 0.4);
std::string anchors_values = getParam<std::string>(layer_params, "anchors", std::string());
CV_Assert(!anchors_values.empty());
@ -658,7 +662,7 @@ namespace cv {
CV_Assert(classes > 0 && num_of_anchors > 0 && (num_of_anchors * 2) == anchors_vec.size());
setParams.setPermute(false);
setParams.setYolo(classes, mask_vec, anchors_vec);
setParams.setYolo(classes, mask_vec, anchors_vec, thresh, nms_threshold);
}
else {
CV_Error(cv::Error::StsParseError, "Unknown layer type: " + layer_type);