diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 4d56cb0e17..dc8d7dc0a0 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -2666,24 +2666,36 @@ void yoloPostProcessing( cv::transposeND(outs[0], {0, 2, 1}, outs[0]); } - // each row is [cx, cy, w, h, conf_obj, conf_class1, ..., conf_class80] + if (test_name == "yolonas"){ + // outs contains 2 elements of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84] + Mat concat_out; + // squeeze the first dimension + outs[0] = outs[0].reshape(1, outs[0].size[1]); + outs[1] = outs[1].reshape(1, outs[1].size[1]); + cv::hconcat(outs[1], outs[0], concat_out); + outs[0] = concat_out; + // remove the second element + outs.pop_back(); + // unsqueeze the first dimension + outs[0] = outs[0].reshape(0, std::vector{1, 8400, 84}); + } + for (auto preds : outs){ preds = preds.reshape(1, preds.size[1]); // [1, 8400, 85] -> [8400, 85] - for (int i = 0; i < preds.rows; ++i) { - // filter out non objects - float obj_conf = (test_name != "yolov8") ? preds.at(i, 4) : 1.0f; + // filter out non-objects + float obj_conf = (test_name == "yolov8" || test_name == "yolonas") ? 1.0f : preds.at(i, 4) ; if (obj_conf < conf_threshold) continue; - Mat scores = preds.row(i).colRange((test_name != "yolov8") ? 5 : 4, preds.cols); + Mat scores = preds.row(i).colRange((test_name == "yolov8" || test_name == "yolonas") ? 4 : 5, preds.cols); double conf; Point maxLoc; minMaxLoc(scores, 0, &conf, 0, &maxLoc); - conf = (test_name != "yolov8") ? conf * obj_conf : conf; + conf = (test_name == "yolov8" || test_name == "yolonas") ? 
conf : conf * obj_conf; if (conf < conf_threshold) continue; @@ -2694,10 +2706,15 @@ void yoloPostProcessing( double w = det[2]; double h = det[3]; + // std::cout << "cx: " << cx << " cy: " << cy << " w: " << w << " h: " << h << " conf: " << conf << " idx: " << maxLoc.x << std::endl; // [x1, y1, x2, y2] - boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h, - cx + 0.5 * w, cy + 0.5 * h)); - classIds.push_back(maxLoc.x); + if (test_name == "yolonas"){ + boxes.push_back(Rect2d(cx, cy, w, h)); + } else { + boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h, + cx + 0.5 * w, cy + 0.5 * h)); + } + classIds.push_back(maxLoc.x); confidences.push_back(conf); } } @@ -2751,6 +2768,41 @@ TEST_P(Test_ONNX_nets, YOLOX) 1.0e-4, 1.0e-4); } +TEST_P(Test_ONNX_nets, YOLONas) +{ + // model information: https://dl.opencv.org/models/yolo-nas/Readme.md + std::string weightPath = _tf("models/yolo_nas_s.onnx", false); + + Size targetSize{640, 640}; + float conf_threshold = 0.50; + float iou_threshold = 0.50; + + std::vector refClassIds{1, 16, 7}; + std::vector refScores{0.9720f, 0.9283f, 0.8990f}; + // [x1, y1, x2, y2] + std::vector refBoxes{ + Rect2d(105.516, 173.696, 471.323, 430.433), + Rect2d(109.241, 263.406, 259.872, 531.858), + Rect2d(390.153, 142.492, 574.932, 222.709) + }; + + Image2BlobParams imgParams( + Scalar::all(1/255.0), + targetSize, + Scalar::all(0), + false, + CV_32F, + DNN_LAYOUT_NCHW, + DNN_PMODE_LETTERBOX, + Scalar::all(114) + ); + + testYOLO( + weightPath, refClassIds, refScores, refBoxes, + imgParams, conf_threshold, iou_threshold, + 1.0e-4, 1.0e-4, "yolonas"); +} + TEST_P(Test_ONNX_nets, YOLOv8) { std::string weightPath = _tf("models/yolov8n.onnx", false);