mirror of
https://github.com/opencv/opencv.git
synced 2025-07-25 22:57:53 +08:00
Fix yoloPostProcessing
` to handle variable number of classes (nc)
Previously, the yoloPostProcessing function assumed that the number of classes (nc) was fixed at 80. This caused incorrect behavior when a different number of classes was specified, leading to mismatched output shapes. This update modifies the code to use the provided `nc` value dynamically, ensuring that the output shapes are correctly calculated based on the specified number of classes. This prevents issues when `nc` is not equal to 80 and allows for greater flexibility in model configurations.
This commit is contained in:
parent
1d4110884b
commit
25fe85bbbb
@ -2691,7 +2691,7 @@ void yoloPostProcessing(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (model_name == "yolonas"){
|
if (model_name == "yolonas"){
|
||||||
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
|
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
|
||||||
Mat concat_out;
|
Mat concat_out;
|
||||||
// squeeze the first dimension
|
// squeeze the first dimension
|
||||||
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
||||||
@ -2701,12 +2701,12 @@ void yoloPostProcessing(
|
|||||||
// remove the second element
|
// remove the second element
|
||||||
outs.pop_back();
|
outs.pop_back();
|
||||||
// unsqueeze the first dimension
|
// unsqueeze the first dimension
|
||||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
|
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
|
||||||
}
|
}
|
||||||
|
|
||||||
// assert if last dim is 85 or 84
|
// assert if last dim is nc+5 or nc+4
|
||||||
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
|
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
|
||||||
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
|
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
|
||||||
|
|
||||||
for (auto preds : outs){
|
for (auto preds : outs){
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ void yoloPostProcessing(
|
|||||||
|
|
||||||
if (model_name == "yolonas")
|
if (model_name == "yolonas")
|
||||||
{
|
{
|
||||||
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
|
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
|
||||||
Mat concat_out;
|
Mat concat_out;
|
||||||
// squeeze the first dimension
|
// squeeze the first dimension
|
||||||
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
outs[0] = outs[0].reshape(1, outs[0].size[1]);
|
||||||
@ -135,12 +135,12 @@ void yoloPostProcessing(
|
|||||||
// remove the second element
|
// remove the second element
|
||||||
outs.pop_back();
|
outs.pop_back();
|
||||||
// unsqueeze the first dimension
|
// unsqueeze the first dimension
|
||||||
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, nc + 4});
|
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
|
||||||
}
|
}
|
||||||
|
|
||||||
// assert if last dim is 85 or 84
|
// assert if last dim is nc+5 or nc+4
|
||||||
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
|
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
|
||||||
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
|
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
|
||||||
|
|
||||||
for (auto preds : outs)
|
for (auto preds : outs)
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user