Fix yoloPostProcessing` to handle variable number of classes (nc)

Previously, the yoloPostProcessing function assumed that the number of classes (nc) was fixed at 80. This caused incorrect behavior when a different number of classes was specified, leading to mismatched output shapes.

This update modifies the code to use the provided `nc` value dynamically, ensuring that the output shapes are correctly calculated based on the specified number of classes. This prevents issues when `nc` is not equal to 80 and allows for greater flexibility in model configurations.
This commit is contained in:
KangJialiang 2024-12-12 15:41:14 +08:00
parent 1d4110884b
commit 25fe85bbbb
2 changed files with 10 additions and 10 deletions

View File

@ -2691,7 +2691,7 @@ void yoloPostProcessing(
} }
if (model_name == "yolonas"){ if (model_name == "yolonas"){
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84] // outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
Mat concat_out; Mat concat_out;
// squeeze the first dimension // squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]); outs[0] = outs[0].reshape(1, outs[0].size[1]);
@ -2701,12 +2701,12 @@ void yoloPostProcessing(
// remove the second element // remove the second element
outs.pop_back(); outs.pop_back();
// unsqueeze the first dimension // unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84}); outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
} }
// assert if last dim is 85 or 84 // assert if last dim is nc+5 or nc+4
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]"); CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: "); CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
for (auto preds : outs){ for (auto preds : outs){

View File

@ -125,7 +125,7 @@ void yoloPostProcessing(
if (model_name == "yolonas") if (model_name == "yolonas")
{ {
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84] // outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
Mat concat_out; Mat concat_out;
// squeeze the first dimension // squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]); outs[0] = outs[0].reshape(1, outs[0].size[1]);
@ -135,12 +135,12 @@ void yoloPostProcessing(
// remove the second element // remove the second element
outs.pop_back(); outs.pop_back();
// unsqueeze the first dimension // unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, nc + 4}); outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
} }
// assert if last dim is 85 or 84 // assert if last dim is nc+5 or nc+4
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]"); CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: "); CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
for (auto preds : outs) for (auto preds : outs)
{ {