mirror of
https://github.com/opencv/opencv.git
synced 2024-11-27 20:50:25 +08:00
added test for yolo nas
This commit is contained in:
parent
84bb1cda4e
commit
d30bf1bc3c
@ -2666,24 +2666,36 @@ void yoloPostProcessing(
|
||||
cv::transposeND(outs[0], {0, 2, 1}, outs[0]);
|
||||
}
|
||||
|
||||
// each row is [cx, cy, w, h, conf_obj, conf_class1, ..., conf_class80]
|
||||
if (test_name == "yolonas"){
|
||||
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
|
||||
Mat concat_out;
|
||||
// squeeze the first dimension
|
||||
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
||||
outs[1] = outs[1].reshape(1, outs[1].size[1]);
|
||||
cv::hconcat(outs[1], outs[0], concat_out);
|
||||
outs[0] = concat_out;
|
||||
// remove the second element
|
||||
outs.pop_back();
|
||||
// unsqueeze the first dimension
|
||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
|
||||
}
|
||||
|
||||
for (auto preds : outs){
|
||||
|
||||
preds = preds.reshape(1, preds.size[1]); // [1, 8400, 85] -> [8400, 85]
|
||||
|
||||
for (int i = 0; i < preds.rows; ++i)
|
||||
{
|
||||
// filter out non objects
|
||||
float obj_conf = (test_name != "yolov8") ? preds.at<float>(i, 4) : 1.0f;
|
||||
// filter out non object
|
||||
float obj_conf = (test_name == "yolov8" || test_name == "yolonas") ? 1.0f : preds.at<float>(i, 4) ;
|
||||
if (obj_conf < conf_threshold)
|
||||
continue;
|
||||
|
||||
Mat scores = preds.row(i).colRange((test_name != "yolov8") ? 5 : 4, preds.cols);
|
||||
Mat scores = preds.row(i).colRange((test_name == "yolov8" || test_name == "yolonas") ? 4 : 5, preds.cols);
|
||||
double conf;
|
||||
Point maxLoc;
|
||||
minMaxLoc(scores, 0, &conf, 0, &maxLoc);
|
||||
|
||||
conf = (test_name != "yolov8") ? conf * obj_conf : conf;
|
||||
conf = (test_name == "yolov8" || test_name == "yolonas") ? conf : conf * obj_conf;
|
||||
if (conf < conf_threshold)
|
||||
continue;
|
||||
|
||||
@ -2694,10 +2706,15 @@ void yoloPostProcessing(
|
||||
double w = det[2];
|
||||
double h = det[3];
|
||||
|
||||
// std::cout << "cx: " << cx << " cy: " << cy << " w: " << w << " h: " << h << " conf: " << conf << " idx: " << maxLoc.x << std::endl;
|
||||
// [x1, y1, x2, y2]
|
||||
boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h,
|
||||
cx + 0.5 * w, cy + 0.5 * h));
|
||||
classIds.push_back(maxLoc.x);
|
||||
if (test_name == "yolonas"){
|
||||
boxes.push_back(Rect2d(cx, cy, w, h));
|
||||
} else {
|
||||
boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h,
|
||||
cx + 0.5 * w, cy + 0.5 * h));
|
||||
}
|
||||
classIds.push_back(maxLoc.x);
|
||||
confidences.push_back(conf);
|
||||
}
|
||||
}
|
||||
@ -2751,6 +2768,41 @@ TEST_P(Test_ONNX_nets, YOLOX)
|
||||
1.0e-4, 1.0e-4);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_nets, YOLONas)
|
||||
{
|
||||
// model information: https://dl.opencv.org/models/yolo-nas/Readme.md
|
||||
std::string weightPath = _tf("models/yolo_nas_s.onnx", false);
|
||||
|
||||
Size targetSize{640, 640};
|
||||
float conf_threshold = 0.50;
|
||||
float iou_threshold = 0.50;
|
||||
|
||||
std::vector<int> refClassIds{1, 16, 7};
|
||||
std::vector<float> refScores{0.9720f, 0.9283f, 0.8990f};
|
||||
// [x1, y1, x2, y2]
|
||||
std::vector<Rect2d> refBoxes{
|
||||
Rect2d(105.516, 173.696, 471.323, 430.433),
|
||||
Rect2d(109.241, 263.406, 259.872, 531.858),
|
||||
Rect2d(390.153, 142.492, 574.932, 222.709)
|
||||
};
|
||||
|
||||
Image2BlobParams imgParams(
|
||||
Scalar::all(1/255.0),
|
||||
targetSize,
|
||||
Scalar::all(0),
|
||||
false,
|
||||
CV_32F,
|
||||
DNN_LAYOUT_NCHW,
|
||||
DNN_PMODE_LETTERBOX,
|
||||
Scalar::all(114)
|
||||
);
|
||||
|
||||
testYOLO(
|
||||
weightPath, refClassIds, refScores, refBoxes,
|
||||
imgParams, conf_threshold, iou_threshold,
|
||||
1.0e-4, 1.0e-4, "yolonas");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_nets, YOLOv8)
|
||||
{
|
||||
std::string weightPath = _tf("models/yolov8n.onnx", false);
|
||||
|
Loading…
Reference in New Issue
Block a user