From 39533a0b1b8384b72b4796f3d16c62e20c8d9bb0 Mon Sep 17 00:00:00 2001
From: Alexey Spizhevoy <no@email>
Date: Thu, 22 Sep 2011 08:58:48 +0000
Subject: [PATCH] Added BA refinement flags into stitching_detailed

---
 .../stitching/detail/motion_estimators.hpp    |  24 +++-
 .../opencv2/stitching/detail/warpers_inl.hpp  |   2 +-
 modules/stitching/src/motion_estimators.cpp   | 113 ++++++++++++------
 samples/cpp/stitching_detailed.cpp            |  25 ++++
 4 files changed, 125 insertions(+), 39 deletions(-)

diff --git a/modules/stitching/include/opencv2/stitching/detail/motion_estimators.hpp b/modules/stitching/include/opencv2/stitching/detail/motion_estimators.hpp
index 6a913d591c..65f6d235d5 100644
--- a/modules/stitching/include/opencv2/stitching/detail/motion_estimators.hpp
+++ b/modules/stitching/include/opencv2/stitching/detail/motion_estimators.hpp
@@ -83,6 +83,13 @@ private:
 class CV_EXPORTS BundleAdjusterBase : public Estimator
 {
 public:
+    const Mat refinementMask() const { return refinement_mask_.clone(); }
+    void setRefinementMask(const Mat &mask) 
+    { 
+        CV_Assert(mask.type() == CV_8U && mask.size() == Size(3, 3));
+        refinement_mask_ = mask.clone(); 
+    }
+
     double confThresh() const { return conf_thresh_; }
     void setConfThresh(double conf_thresh) { conf_thresh_ = conf_thresh; }
 
@@ -90,7 +97,10 @@ protected:
     BundleAdjusterBase(int num_params_per_cam, int num_errs_per_measurement) 
         : num_params_per_cam_(num_params_per_cam), 
           num_errs_per_measurement_(num_errs_per_measurement) 
-        { setConfThresh(1.); }
+    {    
+        setRefinementMask(Mat::ones(3, 3, CV_8U));
+        setConfThresh(1.); 
+    }
 
     // Runs bundle adjustment
     virtual void estimate(const std::vector<ImageFeatures> &features, 
@@ -102,6 +112,9 @@ protected:
     virtual void calcError(Mat &err) = 0;
     virtual void calcJacobian(Mat &jac) = 0;
 
+    // 3x3 8U mask, where 0 means don't refine respective parameter, != 0 means refine
+    Mat refinement_mask_;
+
     int num_images_;
     int total_num_matches_;
 
@@ -122,11 +135,13 @@ protected:
 };
 
 
-// Minimizes reprojection error
+// Minimizes reprojection error.
+// It can estimate focal length, aspect ratio, principal point. 
+// You can affect only on them via the refinement mask.
 class CV_EXPORTS BundleAdjusterReproj : public BundleAdjusterBase
 {
 public:
-    BundleAdjusterReproj() : BundleAdjusterBase(6, 2) {}
+    BundleAdjusterReproj() : BundleAdjusterBase(7, 2) {}
 
 private:
     void setUpInitialCameraParams(const std::vector<CameraParams> &cameras);
@@ -138,7 +153,8 @@ private:
 };
 
 
-// Minimizes sun of ray-to-ray distances
+// Minimizes sun of ray-to-ray distances.
+// It can estimate focal length. It ignores the refinement mask for now.
 class CV_EXPORTS BundleAdjusterRay : public BundleAdjusterBase
 {
 public:
diff --git a/modules/stitching/include/opencv2/stitching/detail/warpers_inl.hpp b/modules/stitching/include/opencv2/stitching/detail/warpers_inl.hpp
index a6d6184faf..b0a9e2417a 100644
--- a/modules/stitching/include/opencv2/stitching/detail/warpers_inl.hpp
+++ b/modules/stitching/include/opencv2/stitching/detail/warpers_inl.hpp
@@ -103,7 +103,7 @@ Rect WarperBase<P>::warpRoi(const Size &sz, const Mat &K, const Mat &R)
 
 
 template <class P>
-Rect WarperBase<P>::warpRoi(const Size &sz, const Mat &K, const Mat &R, const Mat &T)
+Rect WarperBase<P>::warpRoi(const Size &/*sz*/, const Mat &/*K*/, const Mat &/*R*/, const Mat &/*T*/)
 {
     CV_Error(CV_StsNotImplemented, "translation support isn't implemented");
     return Rect();
diff --git a/modules/stitching/src/motion_estimators.cpp b/modules/stitching/src/motion_estimators.cpp
index c948209243..6b8da9204d 100644
--- a/modules/stitching/src/motion_estimators.cpp
+++ b/modules/stitching/src/motion_estimators.cpp
@@ -247,13 +247,14 @@ void BundleAdjusterBase::estimate(const vector<ImageFeatures> &features,
 
 void BundleAdjusterReproj::setUpInitialCameraParams(const vector<CameraParams> &cameras)
 {
-    cam_params_.create(num_images_ * 6, 1, CV_64F);
+    cam_params_.create(num_images_ * 7, 1, CV_64F);
     SVD svd;
     for (int i = 0; i < num_images_; ++i)
     {
-        cam_params_.at<double>(i * 6, 0) = cameras[i].focal;
-        cam_params_.at<double>(i * 6 + 1, 0) = cameras[i].ppx;
-        cam_params_.at<double>(i * 6 + 2, 0) = cameras[i].ppy;
+        cam_params_.at<double>(i * 7, 0) = cameras[i].focal;
+        cam_params_.at<double>(i * 7 + 1, 0) = cameras[i].ppx;
+        cam_params_.at<double>(i * 7 + 2, 0) = cameras[i].ppy;
+        cam_params_.at<double>(i * 7 + 3, 0) = cameras[i].aspect;
 
         svd(cameras[i].R, SVD::FULL_UV);
         Mat R = svd.u * svd.vt;
@@ -263,9 +264,9 @@ void BundleAdjusterReproj::setUpInitialCameraParams(const vector<CameraParams> &
         Mat rvec;
         Rodrigues(R, rvec);
         CV_Assert(rvec.type() == CV_32F);
-        cam_params_.at<double>(i * 6 + 3, 0) = rvec.at<float>(0, 0);
-        cam_params_.at<double>(i * 6 + 4, 0) = rvec.at<float>(1, 0);
-        cam_params_.at<double>(i * 6 + 5, 0) = rvec.at<float>(2, 0);
+        cam_params_.at<double>(i * 7 + 4, 0) = rvec.at<float>(0, 0);
+        cam_params_.at<double>(i * 7 + 5, 0) = rvec.at<float>(1, 0);
+        cam_params_.at<double>(i * 7 + 6, 0) = rvec.at<float>(2, 0);
     }
 }
 
@@ -274,14 +275,15 @@ void BundleAdjusterReproj::obtainRefinedCameraParams(vector<CameraParams> &camer
 {
     for (int i = 0; i < num_images_; ++i)
     {
-        cameras[i].focal = cam_params_.at<double>(i * 6, 0);
-        cameras[i].ppx = cam_params_.at<double>(i * 6 + 1, 0);
-        cameras[i].ppy = cam_params_.at<double>(i * 6 + 2, 0);
+        cameras[i].focal = cam_params_.at<double>(i * 7, 0);
+        cameras[i].ppx = cam_params_.at<double>(i * 7 + 1, 0);
+        cameras[i].ppy = cam_params_.at<double>(i * 7 + 2, 0);
+        cameras[i].aspect = cam_params_.at<double>(i * 7 + 3, 0);
 
         Mat rvec(3, 1, CV_64F);
-        rvec.at<double>(0, 0) = cam_params_.at<double>(i * 6 + 3, 0);
-        rvec.at<double>(1, 0) = cam_params_.at<double>(i * 6 + 4, 0);
-        rvec.at<double>(2, 0) = cam_params_.at<double>(i * 6 + 5, 0);
+        rvec.at<double>(0, 0) = cam_params_.at<double>(i * 7 + 4, 0);
+        rvec.at<double>(1, 0) = cam_params_.at<double>(i * 7 + 5, 0);
+        rvec.at<double>(2, 0) = cam_params_.at<double>(i * 7 + 6, 0);
         Rodrigues(rvec, cameras[i].R);
 
         Mat tmp;
@@ -300,26 +302,28 @@ void BundleAdjusterReproj::calcError(Mat &err)
     {
         int i = edges_[edge_idx].first;
         int j = edges_[edge_idx].second;
-        double f1 = cam_params_.at<double>(i * 6, 0);
-        double f2 = cam_params_.at<double>(j * 6, 0);
-        double ppx1 = cam_params_.at<double>(i * 6 + 1, 0);
-        double ppx2 = cam_params_.at<double>(j * 6 + 1, 0);
-        double ppy1 = cam_params_.at<double>(i * 6 + 2, 0);
-        double ppy2 = cam_params_.at<double>(j * 6 + 2, 0);
+        double f1 = cam_params_.at<double>(i * 7, 0);
+        double f2 = cam_params_.at<double>(j * 7, 0);
+        double ppx1 = cam_params_.at<double>(i * 7 + 1, 0);
+        double ppx2 = cam_params_.at<double>(j * 7 + 1, 0);
+        double ppy1 = cam_params_.at<double>(i * 7 + 2, 0);
+        double ppy2 = cam_params_.at<double>(j * 7 + 2, 0);
+        double a1 = cam_params_.at<double>(i * 7 + 3, 0);
+        double a2 = cam_params_.at<double>(j * 7 + 3, 0);
 
         double R1[9];
         Mat R1_(3, 3, CV_64F, R1);
         Mat rvec(3, 1, CV_64F);
-        rvec.at<double>(0, 0) = cam_params_.at<double>(i * 6 + 3, 0);
-        rvec.at<double>(1, 0) = cam_params_.at<double>(i * 6 + 4, 0);
-        rvec.at<double>(2, 0) = cam_params_.at<double>(i * 6 + 5, 0);
+        rvec.at<double>(0, 0) = cam_params_.at<double>(i * 7 + 4, 0);
+        rvec.at<double>(1, 0) = cam_params_.at<double>(i * 7 + 5, 0);
+        rvec.at<double>(2, 0) = cam_params_.at<double>(i * 7 + 6, 0);
         Rodrigues(rvec, R1_);
 
         double R2[9];
         Mat R2_(3, 3, CV_64F, R2);
-        rvec.at<double>(0, 0) = cam_params_.at<double>(j * 6 + 3, 0);
-        rvec.at<double>(1, 0) = cam_params_.at<double>(j * 6 + 4, 0);
-        rvec.at<double>(2, 0) = cam_params_.at<double>(j * 6 + 5, 0);
+        rvec.at<double>(0, 0) = cam_params_.at<double>(j * 7 + 4, 0);
+        rvec.at<double>(1, 0) = cam_params_.at<double>(j * 7 + 5, 0);
+        rvec.at<double>(2, 0) = cam_params_.at<double>(j * 7 + 6, 0);
         Rodrigues(rvec, R2_);
 
         const ImageFeatures& features1 = features_[i];
@@ -328,11 +332,11 @@ void BundleAdjusterReproj::calcError(Mat &err)
 
         Mat_<double> K1 = Mat::eye(3, 3, CV_64F);
         K1(0,0) = f1; K1(0,2) = ppx1;
-        K1(1,1) = f1; K1(1,2) = ppy1;
+        K1(1,1) = f1*a1; K1(1,2) = ppy1;
 
         Mat_<double> K2 = Mat::eye(3, 3, CV_64F);
         K2(0,0) = f2; K2(0,2) = ppx2;
-        K2(1,1) = f2; K2(1,2) = ppy2;
+        K2(1,1) = f2*a2; K2(1,2) = ppy2;
 
         Mat_<double> H = K2 * R2_.inv() * R1_ * K1.inv();
 
@@ -358,22 +362,63 @@ void BundleAdjusterReproj::calcError(Mat &err)
 
 void BundleAdjusterReproj::calcJacobian(Mat &jac)
 {
-    jac.create(total_num_matches_ * 2, num_images_ * 6, CV_64F);
+    jac.create(total_num_matches_ * 2, num_images_ * 7, CV_64F);
+    jac.setTo(0);
 
     double val;
     const double step = 1e-4;
 
     for (int i = 0; i < num_images_; ++i)
     {
-        for (int j = 0; j < 6; ++j)
+        if (refinement_mask_.at<uchar>(0, 0))
         {
-            val = cam_params_.at<double>(i * 6 + j, 0);
-            cam_params_.at<double>(i * 6 + j, 0) = val - step;
+            val = cam_params_.at<double>(i * 7, 0);
+            cam_params_.at<double>(i * 7, 0) = val - step;
             calcError(err1_);
-            cam_params_.at<double>(i * 6 + j, 0) = val + step;
+            cam_params_.at<double>(i * 7, 0) = val + step;
             calcError(err2_);
-            calcDeriv(err1_, err2_, 2 * step, jac.col(i * 6 + j));
-            cam_params_.at<double>(i * 6 + j, 0) = val;
+            calcDeriv(err1_, err2_, 2 * step, jac.col(i * 7));
+            cam_params_.at<double>(i * 7, 0) = val;
+        }
+        if (refinement_mask_.at<uchar>(0, 2))        
+        {
+            val = cam_params_.at<double>(i * 7 + 1, 0);
+            cam_params_.at<double>(i * 7 + 1, 0) = val - step;
+            calcError(err1_);
+            cam_params_.at<double>(i * 7 + 1, 0) = val + step;
+            calcError(err2_);
+            calcDeriv(err1_, err2_, 2 * step, jac.col(i * 7 + 1));
+            cam_params_.at<double>(i * 7 + 1, 0) = val;
+        }
+        if (refinement_mask_.at<uchar>(1, 2))        
+        {
+            val = cam_params_.at<double>(i * 7 + 2, 0);
+            cam_params_.at<double>(i * 7 + 2, 0) = val - step;
+            calcError(err1_);
+            cam_params_.at<double>(i * 7 + 2, 0) = val + step;
+            calcError(err2_);
+            calcDeriv(err1_, err2_, 2 * step, jac.col(i * 7 + 2));
+            cam_params_.at<double>(i * 7 + 2, 0) = val;
+        }
+        if (refinement_mask_.at<uchar>(1, 1))
+        {
+            val = cam_params_.at<double>(i * 7 + 3, 0);
+            cam_params_.at<double>(i * 7 + 3, 0) = val - step;
+            calcError(err1_);
+            cam_params_.at<double>(i * 7 + 3, 0) = val + step;
+            calcError(err2_);
+            calcDeriv(err1_, err2_, 2 * step, jac.col(i * 7 + 3));
+            cam_params_.at<double>(i * 7 + 3, 0) = val;
+        }
+        for (int j = 4; j < 7; ++j)
+        {
+            val = cam_params_.at<double>(i * 7 + j, 0);
+            cam_params_.at<double>(i * 7 + j, 0) = val - step;
+            calcError(err1_);
+            cam_params_.at<double>(i * 7 + j, 0) = val + step;
+            calcError(err2_);
+            calcDeriv(err1_, err2_, 2 * step, jac.col(i * 7 + j));
+            cam_params_.at<double>(i * 7 + j, 0) = val;
         }
     }
 }
diff --git a/samples/cpp/stitching_detailed.cpp b/samples/cpp/stitching_detailed.cpp
index 8d6c05593d..8e780caf43 100644
--- a/samples/cpp/stitching_detailed.cpp
+++ b/samples/cpp/stitching_detailed.cpp
@@ -81,6 +81,13 @@ void printUsage()
         "      The default is 1.0.\n"
         "  --ba (reproj|ray)\n"
         "      Bundle adjustment cost function. The default is ray.\n"
+        "  --ba_refine_mask (mask)\n"
+        "      Set refinement mask for bundle adjustment. It looks like 'x_xxx',\n"
+        "      where 'x' means refine respective parameter and '_' means don't\n"
+        "      refine one, and has the following format:\n"
+        "      <fx><skew><ppx><aspect><ppy>. The default mask is 'xxxxx'. If bundle\n"
+        "      adjustment doesn't support estimation of selected parameter then\n"
+        "      the respective flag is ignored.\n"
         "  --wave_correct (no|yes)\n"
         "      Perform wave effect correction. The default is 'yes'.\n"
         "  --save_graph <file_name>\n"
@@ -117,6 +124,7 @@ double seam_megapix = 0.1;
 double compose_megapix = -1;
 float conf_thresh = 1.f;
 string ba_cost_func = "ray";
+string ba_refine_mask = "xxxxx";
 bool wave_correct = true;
 bool save_graph = false;
 std::string save_graph_to;
@@ -194,6 +202,16 @@ int parseCmdArgs(int argc, char** argv)
             ba_cost_func = argv[i + 1];
             i++;
         }
+        else if (string(argv[i]) == "--ba_refine_mask")
+        {
+            ba_refine_mask = argv[i + 1];
+            if (ba_refine_mask.size() != 5)
+            {
+                cout << "Incorrect refinement mask length.\n";
+                return -1;
+            }
+            i++;
+        }
         else if (string(argv[i]) == "--wave_correct")
         {
             if (string(argv[i + 1]) == "no")
@@ -430,6 +448,13 @@ int main(int argc, char* argv[])
         return -1; 
     }
     adjuster->setConfThresh(conf_thresh);
+    Mat_<uchar> refine_mask = Mat::zeros(3, 3, CV_8U);
+    if (ba_refine_mask[0] == 'x') refine_mask(0,0) = 1;
+    if (ba_refine_mask[1] == 'x') refine_mask(0,1) = 1;
+    if (ba_refine_mask[2] == 'x') refine_mask(0,2) = 1;
+    if (ba_refine_mask[3] == 'x') refine_mask(1,1) = 1;
+    if (ba_refine_mask[4] == 'x') refine_mask(1,2) = 1;
+    adjuster->setRefinementMask(refine_mask);
     (*adjuster)(features, pairwise_matches, cameras);
 
     // Find median focal length