Fix `yoloPostProcessing`` to handle variable number of classes (nc)

Previously, the yoloPostProcessing function assumed that the number of classes (nc) was fixed at 80. This caused incorrect behavior when a different number of classes was specified, leading to mismatched output shapes.

This update modifies the code to use the provided `nc` value dynamically, ensuring that the output shapes are correctly calculated based on the specified number of classes. This prevents issues when `nc` is not equal to 80 and allows for greater flexibility in model configurations.
pull/26618/head
KangJialiang 2 months ago
parent 1d4110884b
commit 25fe85bbbb
  1. 10
      modules/dnn/test/test_onnx_importer.cpp
  2. 10
      samples/dnn/yolo_detector.cpp

@ -2691,7 +2691,7 @@ void yoloPostProcessing(
}
if (model_name == "yolonas"){
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
Mat concat_out;
// squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]);
@ -2701,12 +2701,12 @@ void yoloPostProcessing(
// remove the second element
outs.pop_back();
// unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
}
// assert if last dim is 85 or 84
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
// assert if last dim is nc+5 or nc+4
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
for (auto preds : outs){

@ -125,7 +125,7 @@ void yoloPostProcessing(
if (model_name == "yolonas")
{
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
Mat concat_out;
// squeeze the first dimension
outs[0] = outs[0].reshape(1, outs[0].size[1]);
@ -135,12 +135,12 @@ void yoloPostProcessing(
// remove the second element
outs.pop_back();
// unsqueeze the first dimension
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, nc + 4});
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
}
// assert if last dim is 85 or 84
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
// assert if last dim is nc+5 or nc+4
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
for (auto preds : outs)
{

Loading…
Cancel
Save