mirror of
https://github.com/opencv/opencv.git
synced 2025-07-24 22:16:27 +08:00
Fix yoloPostProcessing
` to handle variable number of classes (nc)
Previously, the yoloPostProcessing function assumed that the number of classes (nc) was fixed at 80. This caused incorrect behavior when a different number of classes was specified, leading to mismatched output shapes. This update modifies the code to use the provided `nc` value dynamically, ensuring that the output shapes are correctly calculated based on the specified number of classes. This prevents issues when `nc` is not equal to 80 and allows for greater flexibility in model configurations.
This commit is contained in:
parent
1d4110884b
commit
25fe85bbbb
@ -2691,7 +2691,7 @@ void yoloPostProcessing(
|
||||
}
|
||||
|
||||
if (model_name == "yolonas"){
|
||||
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
|
||||
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
|
||||
Mat concat_out;
|
||||
// squeeze the first dimension
|
||||
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
||||
@ -2701,12 +2701,12 @@ void yoloPostProcessing(
|
||||
// remove the second element
|
||||
outs.pop_back();
|
||||
// unsqueeze the first dimension
|
||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
|
||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
|
||||
}
|
||||
|
||||
// assert if last dim is 85 or 84
|
||||
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
|
||||
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
|
||||
// assert if last dim is nc+5 or nc+4
|
||||
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
|
||||
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
|
||||
|
||||
for (auto preds : outs){
|
||||
|
||||
|
@ -125,7 +125,7 @@ void yoloPostProcessing(
|
||||
|
||||
if (model_name == "yolonas")
|
||||
{
|
||||
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
|
||||
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
|
||||
Mat concat_out;
|
||||
// squeeze the first dimension
|
||||
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
||||
@ -135,12 +135,12 @@ void yoloPostProcessing(
|
||||
// remove the second element
|
||||
outs.pop_back();
|
||||
// unsqueeze the first dimension
|
||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, nc + 4});
|
||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
|
||||
}
|
||||
|
||||
// assert if last dim is 85 or 84
|
||||
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
|
||||
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
|
||||
// assert if last dim is nc+5 or nc+4
|
||||
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
|
||||
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
|
||||
|
||||
for (auto preds : outs)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user