From 674c618471f2f7c57d6ca51a6638667864b6ebc8 Mon Sep 17 00:00:00 2001
From: CSBVision <bjoern.boeken@csb.com>
Date: Tue, 8 Aug 2023 13:31:32 +0200
Subject: [PATCH] Update dnn_utils.cpp

---
 modules/dnn/src/dnn_utils.cpp  | 65 ++++++++++++++++++++--------------
 modules/dnn/test/test_misc.cpp | 22 ++++++++++++
 2 files changed, 60 insertions(+), 27 deletions(-)

diff --git a/modules/dnn/src/dnn_utils.cpp b/modules/dnn/src/dnn_utils.cpp
index 18c7e975eb..d4d7dda008 100644
--- a/modules/dnn/src/dnn_utils.cpp
+++ b/modules/dnn/src/dnn_utils.cpp
@@ -5,6 +5,7 @@
 #include "precomp.hpp"
 
 #include <opencv2/imgproc.hpp>
+#include <opencv2/core/utils/logger.hpp>
 
 
 namespace cv {
@@ -100,13 +101,27 @@ void blobFromImagesWithParams(InputArrayOfArrays images_, OutputArray blob_, con
     images_.getMatVector(images);
     CV_Assert(!images.empty());
 
+    if (param.ddepth == CV_8U)
+    {
+        CV_Assert(param.scalefactor == Scalar::all(1.0) && "Scaling is not supported for CV_8U blob depth");
+        CV_Assert(param.mean == Scalar() && "Mean subtraction is not supported for CV_8U blob depth");
+    }
+
     int nch = images[0].channels();
     Scalar scalefactor = param.scalefactor;
+    Scalar mean = param.mean;
 
-    if (param.ddepth == CV_8U)
+    if (param.swapRB)
     {
-        CV_Assert(scalefactor == Scalar::all(1.0) && "Scaling is not supported for CV_8U blob depth");
-        CV_Assert(param.mean == Scalar() && "Mean subtraction is not supported for CV_8U blob depth");
+        if (nch > 2)
+        {
+            std::swap(mean[0], mean[2]);
+            std::swap(scalefactor[0], scalefactor[2]);
+        }
+        else
+        {
+            CV_LOG_WARNING(NULL, "Red/blue color swapping requires at least three image channels.");
+        }
     }
 
     for (size_t i = 0; i < images.size(); i++)
@@ -126,34 +141,26 @@ void blobFromImagesWithParams(InputArrayOfArrays images_, OutputArray blob_, con
                           size);
                 images[i] = images[i](crop);
             }
+            else if (param.paddingmode == DNN_PMODE_LETTERBOX)
+            {
+                float resizeFactor = std::min(size.width / (float)imgSize.width,
+                                              size.height / (float)imgSize.height);
+                int rh = int(imgSize.height * resizeFactor);
+                int rw = int(imgSize.width * resizeFactor);
+                resize(images[i], images[i], Size(rw, rh), INTER_LINEAR);
+
+                int top = (size.height - rh)/2;
+                int bottom = size.height - top - rh;
+                int left = (size.width - rw)/2;
+                int right = size.width - left - rw;
+                copyMakeBorder(images[i], images[i], top, bottom, left, right, BORDER_CONSTANT);
+            }
             else
             {
-                if (param.paddingmode == DNN_PMODE_LETTERBOX)
-                {
-                    float resizeFactor = std::min(size.width / (float)imgSize.width,
-                                                  size.height / (float)imgSize.height);
-                    int rh = int(imgSize.height * resizeFactor);
-                    int rw = int(imgSize.width * resizeFactor);
-                    resize(images[i], images[i], Size(rw, rh), INTER_LINEAR);
-
-                    int top = (size.height - rh)/2;
-                    int bottom = size.height - top - rh;
-                    int left = (size.width - rw)/2;
-                    int right = size.width - left - rw;
-                    copyMakeBorder(images[i], images[i], top, bottom, left, right, BORDER_CONSTANT);
-                }
-                else
-                    resize(images[i], images[i], size, 0, 0, INTER_LINEAR);
+                resize(images[i], images[i], size, 0, 0, INTER_LINEAR);
             }
         }
 
-        Scalar mean = param.mean;
-        if (param.swapRB)
-        {
-            std::swap(mean[0], mean[2]);
-            std::swap(scalefactor[0], scalefactor[2]);
-        }
-
         if (images[i].depth() == CV_8U && param.ddepth == CV_32F)
             images[i].convertTo(images[i], CV_32F);
 
@@ -220,18 +227,22 @@ void blobFromImagesWithParams(InputArrayOfArrays images_, OutputArray blob_, con
             CV_Assert(image.depth() == blob_.depth());
             CV_Assert(image.channels() == image0.channels());
             CV_Assert(image.size() == image0.size());
-            if (param.swapRB)
+            if (nch > 2 && param.swapRB)
             {
                 Mat tmpRB;
                 cvtColor(image, tmpRB, COLOR_BGR2RGB);
                 tmpRB.copyTo(Mat(tmpRB.rows, tmpRB.cols, subMatType, blob.ptr((int)i, 0)));
             }
             else
+            {
                 image.copyTo(Mat(image.rows, image.cols, subMatType, blob.ptr((int)i, 0)));
+            }
         }
     }
     else
+    {
         CV_Error(Error::StsUnsupportedFormat, "Unsupported data layout in blobFromImagesWithParams function.");
+    }
 }
 
 void imagesFromBlob(const cv::Mat& blob_, OutputArrayOfArrays images_)
diff --git a/modules/dnn/test/test_misc.cpp b/modules/dnn/test/test_misc.cpp
index 4ee3e013cb..0c5fb28c5d 100644
--- a/modules/dnn/test/test_misc.cpp
+++ b/modules/dnn/test/test_misc.cpp
@@ -120,6 +120,28 @@ TEST(blobFromImageWithParams_4ch, letter_box)
     EXPECT_EQ(0, cvtest::norm(targetBlob, blob, NORM_INF));
 }
 
+TEST(blobFromImagesWithParams_4ch, multi_image)
+{
+    Mat img(10, 10, CV_8UC4, cv::Scalar(0, 1, 2, 3));
+    Scalar scalefactor(0.1, 0.2, 0.3, 0.4);
+
+    Image2BlobParams param;
+    param.scalefactor = scalefactor;
+    param.datalayout = DNN_LAYOUT_NHWC;
+
+    Mat blobs = blobFromImagesWithParams(std::vector<Mat> { img, 2*img }, param);
+    vector<Range> ranges;
+    ranges.push_back(Range(0, 1));
+    ranges.push_back(Range(0, blobs.size[1]));
+    ranges.push_back(Range(0, blobs.size[2]));
+    ranges.push_back(Range(0, blobs.size[3]));
+    Mat blob0 = blobs(ranges);
+    ranges[0] = Range(1, 2);
+    Mat blob1 = blobs(ranges);
+
+    EXPECT_EQ(0, cvtest::norm(2*blob0, blob1, NORM_INF));
+}
+
 TEST(readNet, Regression)
 {
     Net net = readNet(findDataFile("dnn/squeezenet_v1.1.prototxt"),