added test for yolo nas

This commit is contained in:
Abduragim 2024-01-03 12:42:10 +03:00 committed by Alexander Smorkalov
parent 84bb1cda4e
commit d30bf1bc3c

View File

@ -2666,24 +2666,36 @@ void yoloPostProcessing(
cv::transposeND(outs[0], {0, 2, 1}, outs[0]);
}
// each row is [cx, cy, w, h, conf_obj, conf_class1, ..., conf_class80]
if (test_name == "yolonas"){
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
Mat concat_out;
// squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]);
outs[1] = outs[1].reshape(1, outs[1].size[1]);
cv::hconcat(outs[1], outs[0], concat_out);
outs[0] = concat_out;
// remove the second element
outs.pop_back();
// unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
}
for (auto preds : outs){
preds = preds.reshape(1, preds.size[1]); // [1, 8400, 85] -> [8400, 85]
for (int i = 0; i < preds.rows; ++i)
{
// filter out non objects
float obj_conf = (test_name != "yolov8") ? preds.at<float>(i, 4) : 1.0f;
// filter out non object
float obj_conf = (test_name == "yolov8" || test_name == "yolonas") ? 1.0f : preds.at<float>(i, 4) ;
if (obj_conf < conf_threshold)
continue;
Mat scores = preds.row(i).colRange((test_name != "yolov8") ? 5 : 4, preds.cols);
Mat scores = preds.row(i).colRange((test_name == "yolov8" || test_name == "yolonas") ? 4 : 5, preds.cols);
double conf;
Point maxLoc;
minMaxLoc(scores, 0, &conf, 0, &maxLoc);
conf = (test_name != "yolov8") ? conf * obj_conf : conf;
conf = (test_name == "yolov8" || test_name == "yolonas") ? conf : conf * obj_conf;
if (conf < conf_threshold)
continue;
@ -2694,10 +2706,15 @@ void yoloPostProcessing(
double w = det[2];
double h = det[3];
// std::cout << "cx: " << cx << " cy: " << cy << " w: " << w << " h: " << h << " conf: " << conf << " idx: " << maxLoc.x << std::endl;
// [x1, y1, x2, y2]
boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h,
cx + 0.5 * w, cy + 0.5 * h));
classIds.push_back(maxLoc.x);
if (test_name == "yolonas"){
boxes.push_back(Rect2d(cx, cy, w, h));
} else {
boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h,
cx + 0.5 * w, cy + 0.5 * h));
}
classIds.push_back(maxLoc.x);
confidences.push_back(conf);
}
}
@ -2751,6 +2768,41 @@ TEST_P(Test_ONNX_nets, YOLOX)
1.0e-4, 1.0e-4);
}
TEST_P(Test_ONNX_nets, YOLONas)
{
// model information: https://dl.opencv.org/models/yolo-nas/Readme.md
std::string weightPath = _tf("models/yolo_nas_s.onnx", false);
Size targetSize{640, 640};
float conf_threshold = 0.50;
float iou_threshold = 0.50;
std::vector<int> refClassIds{1, 16, 7};
std::vector<float> refScores{0.9720f, 0.9283f, 0.8990f};
// [x1, y1, x2, y2]
std::vector<Rect2d> refBoxes{
Rect2d(105.516, 173.696, 471.323, 430.433),
Rect2d(109.241, 263.406, 259.872, 531.858),
Rect2d(390.153, 142.492, 574.932, 222.709)
};
Image2BlobParams imgParams(
Scalar::all(1/255.0),
targetSize,
Scalar::all(0),
false,
CV_32F,
DNN_LAYOUT_NCHW,
DNN_PMODE_LETTERBOX,
Scalar::all(114)
);
testYOLO(
weightPath, refClassIds, refScores, refBoxes,
imgParams, conf_threshold, iou_threshold,
1.0e-4, 1.0e-4, "yolonas");
}
TEST_P(Test_ONNX_nets, YOLOv8)
{
std::string weightPath = _tf("models/yolov8n.onnx", false);