mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
added test for yolo nas
This commit is contained in:
parent
84bb1cda4e
commit
d30bf1bc3c
@ -2666,24 +2666,36 @@ void yoloPostProcessing(
|
|||||||
cv::transposeND(outs[0], {0, 2, 1}, outs[0]);
|
cv::transposeND(outs[0], {0, 2, 1}, outs[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// each row is [cx, cy, w, h, conf_obj, conf_class1, ..., conf_class80]
|
if (test_name == "yolonas"){
|
||||||
|
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
|
||||||
|
Mat concat_out;
|
||||||
|
// squeeze the first dimension
|
||||||
|
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
||||||
|
outs[1] = outs[1].reshape(1, outs[1].size[1]);
|
||||||
|
cv::hconcat(outs[1], outs[0], concat_out);
|
||||||
|
outs[0] = concat_out;
|
||||||
|
// remove the second element
|
||||||
|
outs.pop_back();
|
||||||
|
// unsqueeze the first dimension
|
||||||
|
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
|
||||||
|
}
|
||||||
|
|
||||||
for (auto preds : outs){
|
for (auto preds : outs){
|
||||||
|
|
||||||
preds = preds.reshape(1, preds.size[1]); // [1, 8400, 85] -> [8400, 85]
|
preds = preds.reshape(1, preds.size[1]); // [1, 8400, 85] -> [8400, 85]
|
||||||
|
|
||||||
for (int i = 0; i < preds.rows; ++i)
|
for (int i = 0; i < preds.rows; ++i)
|
||||||
{
|
{
|
||||||
// filter out non objects
|
// filter out non object
|
||||||
float obj_conf = (test_name != "yolov8") ? preds.at<float>(i, 4) : 1.0f;
|
float obj_conf = (test_name == "yolov8" || test_name == "yolonas") ? 1.0f : preds.at<float>(i, 4) ;
|
||||||
if (obj_conf < conf_threshold)
|
if (obj_conf < conf_threshold)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
Mat scores = preds.row(i).colRange((test_name != "yolov8") ? 5 : 4, preds.cols);
|
Mat scores = preds.row(i).colRange((test_name == "yolov8" || test_name == "yolonas") ? 4 : 5, preds.cols);
|
||||||
double conf;
|
double conf;
|
||||||
Point maxLoc;
|
Point maxLoc;
|
||||||
minMaxLoc(scores, 0, &conf, 0, &maxLoc);
|
minMaxLoc(scores, 0, &conf, 0, &maxLoc);
|
||||||
|
|
||||||
conf = (test_name != "yolov8") ? conf * obj_conf : conf;
|
conf = (test_name == "yolov8" || test_name == "yolonas") ? conf : conf * obj_conf;
|
||||||
if (conf < conf_threshold)
|
if (conf < conf_threshold)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
@ -2694,10 +2706,15 @@ void yoloPostProcessing(
|
|||||||
double w = det[2];
|
double w = det[2];
|
||||||
double h = det[3];
|
double h = det[3];
|
||||||
|
|
||||||
|
// std::cout << "cx: " << cx << " cy: " << cy << " w: " << w << " h: " << h << " conf: " << conf << " idx: " << maxLoc.x << std::endl;
|
||||||
// [x1, y1, x2, y2]
|
// [x1, y1, x2, y2]
|
||||||
boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h,
|
if (test_name == "yolonas"){
|
||||||
cx + 0.5 * w, cy + 0.5 * h));
|
boxes.push_back(Rect2d(cx, cy, w, h));
|
||||||
classIds.push_back(maxLoc.x);
|
} else {
|
||||||
|
boxes.push_back(Rect2d(cx - 0.5 * w, cy - 0.5 * h,
|
||||||
|
cx + 0.5 * w, cy + 0.5 * h));
|
||||||
|
}
|
||||||
|
classIds.push_back(maxLoc.x);
|
||||||
confidences.push_back(conf);
|
confidences.push_back(conf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2751,6 +2768,41 @@ TEST_P(Test_ONNX_nets, YOLOX)
|
|||||||
1.0e-4, 1.0e-4);
|
1.0e-4, 1.0e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_nets, YOLONas)
|
||||||
|
{
|
||||||
|
// model information: https://dl.opencv.org/models/yolo-nas/Readme.md
|
||||||
|
std::string weightPath = _tf("models/yolo_nas_s.onnx", false);
|
||||||
|
|
||||||
|
Size targetSize{640, 640};
|
||||||
|
float conf_threshold = 0.50;
|
||||||
|
float iou_threshold = 0.50;
|
||||||
|
|
||||||
|
std::vector<int> refClassIds{1, 16, 7};
|
||||||
|
std::vector<float> refScores{0.9720f, 0.9283f, 0.8990f};
|
||||||
|
// [x1, y1, x2, y2]
|
||||||
|
std::vector<Rect2d> refBoxes{
|
||||||
|
Rect2d(105.516, 173.696, 471.323, 430.433),
|
||||||
|
Rect2d(109.241, 263.406, 259.872, 531.858),
|
||||||
|
Rect2d(390.153, 142.492, 574.932, 222.709)
|
||||||
|
};
|
||||||
|
|
||||||
|
Image2BlobParams imgParams(
|
||||||
|
Scalar::all(1/255.0),
|
||||||
|
targetSize,
|
||||||
|
Scalar::all(0),
|
||||||
|
false,
|
||||||
|
CV_32F,
|
||||||
|
DNN_LAYOUT_NCHW,
|
||||||
|
DNN_PMODE_LETTERBOX,
|
||||||
|
Scalar::all(114)
|
||||||
|
);
|
||||||
|
|
||||||
|
testYOLO(
|
||||||
|
weightPath, refClassIds, refScores, refBoxes,
|
||||||
|
imgParams, conf_threshold, iou_threshold,
|
||||||
|
1.0e-4, 1.0e-4, "yolonas");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_nets, YOLOv8)
|
TEST_P(Test_ONNX_nets, YOLOv8)
|
||||||
{
|
{
|
||||||
std::string weightPath = _tf("models/yolov8n.onnx", false);
|
std::string weightPath = _tf("models/yolov8n.onnx", false);
|
||||||
|
Loading…
Reference in New Issue
Block a user