Fix yoloPostProcessing` to handle variable number of classes (nc)

Previously, the yoloPostProcessing function assumed that the number of classes (nc) was fixed at 80. This caused incorrect behavior when a different number of classes was specified, leading to mismatched output shapes.

This update modifies the code to use the provided `nc` value dynamically, ensuring that the output shapes are correctly calculated based on the specified number of classes. This prevents issues when `nc` is not equal to 80 and allows for greater flexibility in model configurations.
This commit is contained in:
KangJialiang 2024-12-12 15:41:14 +08:00
parent 1d4110884b
commit 25fe85bbbb
2 changed files with 10 additions and 10 deletions

View File

@ -2691,7 +2691,7 @@ void yoloPostProcessing(
}
if (model_name == "yolonas"){
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
Mat concat_out;
// squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]);
@ -2701,12 +2701,12 @@ void yoloPostProcessing(
// remove the second element
outs.pop_back();
// unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
}
// assert if last dim is 85 or 84
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
// assert if last dim is nc+5 or nc+4
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
for (auto preds : outs){

View File

@ -125,7 +125,7 @@ void yoloPostProcessing(
if (model_name == "yolonas")
{
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
Mat concat_out;
// squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]);
@ -135,12 +135,12 @@ void yoloPostProcessing(
// remove the second element
outs.pop_back();
// unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, nc + 4});
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
}
// assert if last dim is 85 or 84
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
// assert if last dim is nc+5 or nc+4
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
for (auto preds : outs)
{