diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 8cfc469a7e..66c33eb0a5 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -2691,7 +2691,7 @@ void yoloPostProcessing( } if (model_name == "yolonas"){ - // outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84] + // outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4] Mat concat_out; // squeeze the first dimension outs[0] = outs[0].reshape(1, outs[0].size[1]); @@ -2701,12 +2701,12 @@ void yoloPostProcessing( // remove the second element outs.pop_back(); // unsqueeze the first dimension - outs[0] = outs[0].reshape(0, std::vector{1, 8400, 84}); + outs[0] = outs[0].reshape(0, std::vector{1, outs[0].size[0], outs[0].size[1]}); } - // assert if last dim is 85 or 84 - CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]"); - CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: "); + // assert if last dim is nc+5 or nc+4 + CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]"); + CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: "); for (auto preds : outs){ diff --git a/samples/dnn/yolo_detector.cpp b/samples/dnn/yolo_detector.cpp index bd82acff4a..57493cdfb0 100644 --- a/samples/dnn/yolo_detector.cpp +++ b/samples/dnn/yolo_detector.cpp @@ -125,7 +125,7 @@ void yoloPostProcessing( if (model_name == "yolonas") { - // outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84] + // outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4] Mat concat_out; // squeeze the first dimension outs[0] = outs[0].reshape(1, outs[0].size[1]); @@ -135,12 +135,12 @@ void yoloPostProcessing( // remove the second element outs.pop_back(); // unsqueeze the first dimension - outs[0] = outs[0].reshape(0, std::vector{1, 8400, nc + 4}); + outs[0] = outs[0].reshape(0, std::vector{1, outs[0].size[0], outs[0].size[1]}); } - // assert if last dim is 85 or 84 - CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]"); - CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: "); + // assert if last dim is nc+5 or nc+4 + CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]"); + CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: "); for (auto preds : outs) {