Add CUDA Stereo Semi Global Matching

pull/2772/head
Shingo Otsuka 6 years ago committed by shingo.otsuka
parent 960714d467
commit 7883ec8b98
  1. 2
      modules/cudastereo/CMakeLists.txt
  2. 10
      modules/cudastereo/doc/cudastereo.bib
  3. 47
      modules/cudastereo/include/opencv2/cudastereo.hpp
  4. 34
      modules/cudastereo/perf/perf_stereo.cpp
  5. 2081
      modules/cudastereo/src/cuda/stereosgm.cu
  6. 61
      modules/cudastereo/src/cuda/stereosgm.hpp
  7. 153
      modules/cudastereo/src/stereosgm.cpp
  8. 444
      modules/cudastereo/test/test_sgm_funcs.cpp
  9. 755
      modules/cudastereo/test/test_stereo.cpp

@ -6,4 +6,4 @@ set(the_description "CUDA-accelerated Stereo Correspondence")
ocv_warnings_disable(CMAKE_CXX_FLAGS /wd4127 /wd4324 /wd4512 -Wundef -Wmissing-declarations -Wshadow)
ocv_define_module(cudastereo opencv_calib3d WRAP python)
ocv_define_module(cudastereo opencv_calib3d OPTIONAL opencv_cudev WRAP python)

@ -0,0 +1,10 @@
@InProceedings{Spangenberg2013,
author = {Spangenberg, Robert and Langner, Tobias and Rojas, Ra{\'u}l},
title = {Weighted Semi-Global Matching and Center-Symmetric Census Transform for Robust Driver Assistance},
booktitle = {Computer Analysis of Images and Patterns},
year = {2013},
pages = {34--41},
publisher = {Springer Berlin Heidelberg},
abstract = {Automotive applications based on stereo vision require robust and fast matching algorithms, which makes semi-global matching (SGM) a popular method in this field. Typically the Census transform is used as a cost function, since it is advantageous for outdoor scenes. We propose an extension based on center-symmetric local binary patterns, which allows better efficiency and higher matching quality. Our second contribution exploits knowledge about the three-dimensional structure of the scene to selectively enforce the smoothness constraints of SGM. It is shown that information about surface normals can be easily integrated by weighing the paths according to the gradient of the disparity. The different approaches are evaluated on the KITTI benchmark, which provides real imagery with LIDAR ground truth. The results indicate improved performance compared to state-of-the-art SGM based algorithms.},
url = {https://www.mi.fu-berlin.de/inf/groups/ag-ki/publications/Semi-Global_Matching/caip2013rsp_fu.pdf}
}

@ -241,6 +241,53 @@ public:
CV_EXPORTS_W Ptr<cuda::StereoConstantSpaceBP>
createStereoConstantSpaceBP(int ndisp = 128, int iters = 8, int levels = 4, int nr_plane = 4, int msg_type = CV_32F);
/////////////////////////////////////////
// StereoSGM
/** @brief The class implements the modified H. Hirschmuller algorithm @cite HH08.
Limitations and differences are as follows:
- By default, the algorithm uses only 4 directions (the horizontal and vertical paths) instead of 8.
Set mode=StereoSGM::MODE_HH in createStereoSGM to run the full variant of the algorithm.
- Mutual Information cost function is not implemented.
Instead, Center-Symmetric Census Transform with \f$9 \times 7\f$ window size from @cite Spangenberg2013
is used for robustness.
@sa cv::StereoSGBM
*/
class CV_EXPORTS_W StereoSGM : public cv::StereoSGBM
{
public:
/** @brief Computes disparity map for the specified stereo pair
@param left Left 8-bit or 16-bit unsigned single-channel image.
@param right Right image of the same size and the same type as the left one.
@param disparity Output disparity map. It has the same size as the input images.
StereoSGM computes 16-bit fixed-point disparity map (where each disparity value has 4 fractional bits).
*/
CV_WRAP virtual void compute(InputArray left, InputArray right, OutputArray disparity) CV_OVERRIDE = 0;
/** @brief Computes disparity map with specified CUDA Stream
@param left Left 8-bit or 16-bit unsigned single-channel image.
@param right Right image of the same size and the same type as the left one.
@param disparity Output disparity map. It has the same size as the input images.
@param stream CUDA stream that is handed to every processing stage of the algorithm.
@sa compute
*/
CV_WRAP_AS(compute_with_stream) virtual void compute(InputArray left, InputArray right, OutputArray disparity, Stream& stream) = 0;
};
/** @brief Creates StereoSGM object.
@param minDisparity Minimum possible disparity value. Normally, it is zero but sometimes rectification algorithms can shift images, so this parameter needs to be adjusted accordingly.
@param numDisparities Maximum disparity minus minimum disparity. The value must be 64, 128 or 256.
@param P1 The first parameter controlling the disparity smoothness. This parameter is used for the case of slanted surfaces (not fronto-parallel).
@param P2 The second parameter controlling the disparity smoothness. This parameter is used for "solving" the depth discontinuities problem.
@param uniquenessRatio Margin in percentage by which the best (minimum) computed cost function
value should "win" the second best value to consider the found match correct. Normally, a value
within the 5-15 range is good enough.
@param mode Set it to StereoSGM::MODE_HH to run the full-scale two-pass dynamic programming algorithm.
It will consume O(W\*H\*numDisparities) bytes. By default, it is set to StereoSGM::MODE_HH4.
*/
CV_EXPORTS_W Ptr<cuda::StereoSGM> createStereoSGM(int minDisparity = 0, int numDisparities = 128, int P1 = 10, int P2 = 120, int uniquenessRatio = 5, int mode = cv::cuda::StereoSGM::MODE_HH4);
/////////////////////////////////////////
// DisparityBilateralFilter

@ -252,4 +252,38 @@ PERF_TEST_P(Sz_Depth, DrawColorDisp,
}
}
//////////////////////////////////////////////////////////////////////
// StereoSGM
// Performance test: runs cv::cuda::StereoSGM::compute on one grayscale stereo pair
// with 128 disparities. There is no CPU counterpart, so the non-CUDA branch fails.
PERF_TEST_P(ImagePair, StereoSGM,
Values(pair_string("gpu/perf/aloe.png", "gpu/perf/aloeR.png")))
{
// NOTE(review): presumably the time budget granted to this test -- confirm the unit in the perf framework.
declare.time(300.0);
// Both images are loaded as single-channel grayscale; a missing file aborts the test.
const cv::Mat imgLeft = readImage(GET_PARAM(0), cv::IMREAD_GRAYSCALE);
ASSERT_FALSE(imgLeft.empty());
const cv::Mat imgRight = readImage(GET_PARAM(1), cv::IMREAD_GRAYSCALE);
ASSERT_FALSE(imgRight.empty());
const int ndisp = 128;
if (PERF_RUN_CUDA())
{
// minDisparity = 0, numDisparities = ndisp, remaining parameters keep their defaults.
cv::Ptr<cv::cuda::StereoSGM> d_sgm = cv::cuda::createStereoSGM(0, ndisp);
// Inputs are uploaded once, outside the measured cycle.
const cv::cuda::GpuMat d_imgLeft(imgLeft);
const cv::cuda::GpuMat d_imgRight(imgRight);
cv::cuda::GpuMat dst;
TEST_CYCLE() d_sgm->compute(d_imgLeft, d_imgRight, dst);
CUDA_SANITY_CHECK(dst);
}
else
{
FAIL_NO_CPU();
}
}
}} // namespace

File diff suppressed because it is too large Load Diff

@ -0,0 +1,61 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Author: The "adaskit Team" at Fixstars Corporation
#ifndef OPENCV_CUDASTEREO_SGM_HPP
#define OPENCV_CUDASTEREO_SGM_HPP
#include "opencv2/core/cuda.hpp"
namespace cv { namespace cuda { namespace device {
// Declarations of the individual StereoSGM processing stages. The definitions are not
// part of this header (presumably they live in stereosgm.cu -- confirm).
namespace stereosgm
{
namespace census_transform
{
// Census-transforms src into dest using the given stream.
// StereoSGMImpl::compute passes a CV_8UC1/CV_16UC1 image as src and a CV_32SC1 buffer of the same size as dest.
CV_EXPORTS void censusTransform(const GpuMat& src, GpuMat& dest, cv::cuda::Stream& stream);
}
namespace path_aggregation
{
// Functor performing the cost aggregation along the scan paths.
// It owns one Stream, Event and GpuMat per path (at most MAX_NUM_PATHS of each).
// NOTE(review): presumably these let the path directions run concurrently and address
// per-path slices of the destination buffer -- confirm against the implementation.
class PathAggregation
{
private:
static constexpr unsigned int MAX_NUM_PATHS = 8;
std::array<Stream, MAX_NUM_PATHS> streams;
std::array<Event, MAX_NUM_PATHS> events;
std::array<GpuMat, MAX_NUM_PATHS> subs;
public:
// MAX_DISPARITY is the compile-time number of disparities (StereoSGMImpl uses 64, 128 and 256).
// left/right are the census-transformed images, dest is the aggregated cost buffer,
// mode is the StereoSGM mode (MODE_HH4 or MODE_HH), p1/p2 are the smoothness penalties,
// min_disp is the minimum disparity and stream is the caller's stream.
template <size_t MAX_DISPARITY>
void operator() (const GpuMat& left, const GpuMat& right, GpuMat& dest, int mode, int p1, int p2, int min_disp, Stream& stream);
};
}
namespace winner_takes_all
{
// Selects disparities from the aggregated cost buffer src and writes them to left and right.
// StereoSGMImpl passes uniqueness = (100 - uniquenessRatio) / 100 and subpixel = true.
template <size_t MAX_DISPARITY>
void winnerTakesAll(const GpuMat& src, GpuMat& left, GpuMat& right, float uniqueness, bool subpixel, int mode, cv::cuda::Stream& stream);
}
namespace median_filter
{
// Median-filters src into dst; StereoSGMImpl applies it to both CV_16SC1 disparity maps.
// NOTE(review): the window size is not visible from this header.
void medianFilter(const GpuMat& src, GpuMat& dst, Stream& stream);
}
namespace check_consistency
{
// Left/right consistency check; left_disp is modified in place.
// src_left is the original left input image as passed by StereoSGMImpl::compute.
void checkConsistency(GpuMat& left_disp, const GpuMat& right_disp, const GpuMat& src_left, bool subpixel, Stream& stream);
}
namespace correct_disparity_range
{
// Adjusts disp in place according to min_disp (last stage of StereoSGMImpl::compute).
void correctDisparityRange(GpuMat& disp, bool subpixel, int min_disp, Stream& stream);
}
} // namespace stereosgm
}}} // namespace cv { namespace cuda { namespace device {
#endif /* OPENCV_CUDASTEREO_SGM_HPP */

@ -0,0 +1,153 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Author: The "adaskit Team" at Fixstars Corporation
#include "precomp.hpp"
#include <climits>
using namespace cv;
using namespace cv::cuda;
#if !defined (HAVE_CUDA) || defined (CUDA_DISABLER)
// Stub used when OpenCV is built without CUDA: reports the missing CUDA support via throw_no_cuda().
Ptr<cuda::StereoSGM> cv::cuda::createStereoSGM(int, int, int, int, int, int) { throw_no_cuda(); return Ptr<cuda::StereoSGM>(); }
#else /* !defined (HAVE_CUDA) */
#include "cuda/stereosgm.hpp"
namespace
{
// Plain value bundle with the user-tunable StereoSGM settings.
// The defaults mirror those of cv::cuda::createStereoSGM.
struct StereoSGMParams
{
    int minDisparity;     // minimum disparity
    int numDisparities;   // number of disparities searched
    int P1;               // first smoothness penalty
    int P2;               // second smoothness penalty
    int uniquenessRatio;  // uniqueness margin in percent
    int mode;             // StereoSGM::MODE_HH4 or StereoSGM::MODE_HH

    StereoSGMParams(int minDisparity_ = 0,
                    int numDisparities_ = 128,
                    int P1_ = 10,
                    int P2_ = 120,
                    int uniquenessRatio_ = 5,
                    int mode_ = StereoSGM::MODE_HH4)
        : minDisparity(minDisparity_)
        , numDisparities(numDisparities_)
        , P1(P1_)
        , P2(P2_)
        , uniquenessRatio(uniquenessRatio_)
        , mode(mode_)
    {
    }
};
// CUDA implementation of cuda::StereoSGM.
// Only minDisparity, numDisparities, P1, P2, uniquenessRatio and mode are stored;
// the remaining StereoSGBM/StereoMatcher properties are not used by this implementation:
// their setters silently ignore the value and their getters return a fixed constant.
class StereoSGMImpl CV_FINAL : public StereoSGM
{
public:
StereoSGMImpl(int minDisparity, int numDisparities, int P1, int P2, int uniquenessRatio, int mode);
void compute(InputArray left, InputArray right, OutputArray disparity) CV_OVERRIDE;
void compute(InputArray left, InputArray right, OutputArray disparity, Stream& stream) CV_OVERRIDE;
// Block size is not configurable here: getter returns -1, setter is a no-op.
int getBlockSize() const CV_OVERRIDE { return -1; }
void setBlockSize(int /*blockSize*/) CV_OVERRIDE {}
// Always reports 1; the setter is a no-op.
int getDisp12MaxDiff() const CV_OVERRIDE { return 1; }
void setDisp12MaxDiff(int /*disp12MaxDiff*/) CV_OVERRIDE {}
int getMinDisparity() const CV_OVERRIDE { return params.minDisparity; }
void setMinDisparity(int minDisparity) CV_OVERRIDE { params.minDisparity = minDisparity; }
// The value is stored unchecked; compute() rejects anything other than 64, 128 or 256.
int getNumDisparities() const CV_OVERRIDE { return params.numDisparities; }
void setNumDisparities(int numDisparities) CV_OVERRIDE { params.numDisparities = numDisparities; }
// Speckle settings are not used: getters return 0, setters are no-ops.
int getSpeckleWindowSize() const CV_OVERRIDE { return 0; }
void setSpeckleWindowSize(int /*speckleWindowSize*/) CV_OVERRIDE {}
int getSpeckleRange() const CV_OVERRIDE { return 0; }
void setSpeckleRange(int /*speckleRange*/) CV_OVERRIDE {}
int getP1() const CV_OVERRIDE { return params.P1; }
void setP1(int P1) CV_OVERRIDE { params.P1 = P1; }
int getP2() const CV_OVERRIDE { return params.P2; }
void setP2(int P2) CV_OVERRIDE { params.P2 = P2; }
int getUniquenessRatio() const CV_OVERRIDE { return params.uniquenessRatio; }
void setUniquenessRatio(int uniquenessRatio) CV_OVERRIDE { params.uniquenessRatio = uniquenessRatio; }
// The value is stored unchecked; compute() rejects anything other than MODE_HH and MODE_HH4.
int getMode() const CV_OVERRIDE { return params.mode; }
void setMode(int mode) CV_OVERRIDE { params.mode = mode; }
// Pre-filter cap is not used: getter returns -1, setter is a no-op.
int getPreFilterCap() const CV_OVERRIDE { return -1; }
void setPreFilterCap(int /*preFilterCap*/) CV_OVERRIDE {}
private:
StereoSGMParams params;
device::stereosgm::path_aggregation::PathAggregation pathAggregation;
// Scratch buffers kept as members and resized with ensureSizeIsEnough in compute().
GpuMat censused_left, censused_right;
GpuMat aggregated;
GpuMat left_disp_tmp, right_disp_tmp;
GpuMat right_disp;
};
// Stores the settings as given; no validation happens here (compute() checks mode and numDisparities).
StereoSGMImpl::StereoSGMImpl(int minDisparity, int numDisparities, int P1, int P2, int uniquenessRatio, int mode)
: params(minDisparity, numDisparities, P1, P2, uniquenessRatio, mode)
{
}
// Convenience overload: forwards to the stream-aware compute() using Stream::Null().
void StereoSGMImpl::compute(InputArray left, InputArray right, OutputArray disparity)
{
compute(left, right, disparity, Stream::Null());
}
// Runs the StereoSGM pipeline on the given stream:
// census transform -> path aggregation -> winner-takes-all -> median filter ->
// left/right consistency check -> disparity range correction.
//
// _left / _right : CV_8UC1 or CV_16UC1 GpuMat images of identical size and type.
// _disparity     : receives the CV_16SC1 disparity map (same size as the inputs).
// _stream        : stream handed to every stage.
//
// Raises Error::StsBadArg for an unsupported mode or number of disparities, and fails a
// CV_Assert for empty/mismatched inputs or when the aggregated cost buffer would not fit
// into the int-sized extent of a GpuMat.
void StereoSGMImpl::compute(InputArray _left, InputArray _right, OutputArray _disparity, Stream& _stream)
{
    using namespace device::stereosgm;

    GpuMat left = _left.getGpuMat();
    GpuMat right = _right.getGpuMat();
    const Size size = left.size();

    // Validate every setting up front so that a bad parameter is reported before any
    // buffer is (re)allocated and before any kernel has been queued on the stream.
    if (params.mode != MODE_HH && params.mode != MODE_HH4)
    {
        CV_Error(Error::StsBadArg, "Unsupported mode");
    }
    if (params.numDisparities != 64 && params.numDisparities != 128 && params.numDisparities != 256)
    {
        CV_Error(Error::StsBadArg, "Unsupported num of disparities");
    }
    const unsigned int num_paths = params.mode == MODE_HH4 ? 4 : 8;

    CV_Assert(!left.empty());
    CV_Assert(left.type() == CV_8UC1 || left.type() == CV_16UC1);
    CV_Assert(size == right.size() && left.type() == right.type());

    // The aggregated cost volume is stored as a single row; compute its length in 64 bits
    // because width * height * numDisparities * num_paths can exceed the int range
    // (e.g. 1920x1080 with 256 disparities and 8 paths).
    const long long aggregated_size =
        static_cast<long long>(size.width) * size.height * params.numDisparities * num_paths;
    CV_Assert(aggregated_size <= INT_MAX);

    _disparity.create(size, CV_16SC1);
    ensureSizeIsEnough(size, CV_16SC1, right_disp);
    GpuMat left_disp = _disparity.getGpuMat();

    ensureSizeIsEnough(size, CV_32SC1, censused_left);
    ensureSizeIsEnough(size, CV_32SC1, censused_right);
    census_transform::censusTransform(left, censused_left, _stream);
    census_transform::censusTransform(right, censused_right, _stream);

    ensureSizeIsEnough(1, static_cast<int>(aggregated_size), CV_8UC1, aggregated);
    ensureSizeIsEnough(size, CV_16SC1, left_disp_tmp);
    ensureSizeIsEnough(size, CV_16SC1, right_disp_tmp);

    // uniquenessRatio is a percentage; the device code expects the complementary fraction.
    const float uniqueness = (float)(100 - params.uniquenessRatio) / 100;

    // The number of disparities is a template parameter of the device code,
    // hence the dispatch over the supported values.
    switch (params.numDisparities)
    {
    case 64:
        pathAggregation.operator()<64>(censused_left, censused_right, aggregated, params.mode, params.P1, params.P2, params.minDisparity, _stream);
        winner_takes_all::winnerTakesAll<64>(aggregated, left_disp_tmp, right_disp_tmp, uniqueness, true, params.mode, _stream);
        break;
    case 128:
        pathAggregation.operator()<128>(censused_left, censused_right, aggregated, params.mode, params.P1, params.P2, params.minDisparity, _stream);
        winner_takes_all::winnerTakesAll<128>(aggregated, left_disp_tmp, right_disp_tmp, uniqueness, true, params.mode, _stream);
        break;
    case 256:
        pathAggregation.operator()<256>(censused_left, censused_right, aggregated, params.mode, params.P1, params.P2, params.minDisparity, _stream);
        winner_takes_all::winnerTakesAll<256>(aggregated, left_disp_tmp, right_disp_tmp, uniqueness, true, params.mode, _stream);
        break;
    default:
        // Unreachable after the check above; kept as a safety net.
        CV_Error(Error::StsBadArg, "Unsupported num of disparities");
    }

    median_filter::medianFilter(left_disp_tmp, left_disp, _stream);
    median_filter::medianFilter(right_disp_tmp, right_disp, _stream);
    check_consistency::checkConsistency(left_disp, right_disp, left, true, _stream);
    correct_disparity_range::correctDisparityRange(left_disp, true, params.minDisparity, _stream);
}
} // anonymous namespace
// Factory for the CUDA StereoSGM implementation; the arguments are forwarded unchanged
// to StereoSGMImpl and are only validated later, inside compute().
Ptr<cuda::StereoSGM> cv::cuda::createStereoSGM(int minDisparity, int numDisparities, int P1, int P2, int uniquenessRatio, int mode)
{
return makePtr<StereoSGMImpl>(minDisparity, numDisparities, P1, P2, uniquenessRatio, mode);
}
#endif /* !defined (HAVE_CUDA) */

@ -0,0 +1,444 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Author: The "adaskit Team" at Fixstars Corporation
#include "test_precomp.hpp"
#ifdef HAVE_CUDA
#ifdef _WIN32
#define popcnt64 __popcnt64
#else
#define popcnt64 __builtin_popcountll
#endif
#include "opencv2/core/cuda.hpp"
// Forward declarations of the internal device-layer entry points exercised by the tests
// below. They are re-declared here instead of including the module's private header;
// the signatures must stay in sync with the definitions in the cudastereo module.
namespace cv { namespace cuda { namespace device {
namespace stereosgm
{
namespace census_transform
{
void censusTransform(const GpuMat& src, GpuMat& dest, cv::cuda::Stream& stream);
}
namespace path_aggregation
{
// One aggregation routine per scan direction. All share the same signature:
// left/right are the census images, dest is the cost buffer, p1/p2 the penalties,
// min_disp the minimum disparity.
namespace horizontal
{
template <unsigned int MAX_DISPARITY>
void aggregateLeft2RightPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
template <unsigned int MAX_DISPARITY>
void aggregateRight2LeftPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
}
namespace vertical
{
template <unsigned int MAX_DISPARITY>
void aggregateUp2DownPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
template <unsigned int MAX_DISPARITY>
void aggregateDown2UpPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
}
namespace oblique
{
template <unsigned int MAX_DISPARITY>
void aggregateUpleft2DownrightPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
template <unsigned int MAX_DISPARITY>
void aggregateUpright2DownleftPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
template <unsigned int MAX_DISPARITY>
void aggregateDownright2UpleftPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
template <unsigned int MAX_DISPARITY>
void aggregateDownleft2UprightPath(
const GpuMat& left,
const GpuMat& right,
GpuMat& dest,
unsigned int p1,
unsigned int p2,
int min_disp,
Stream& stream);
}
} // namespace path_aggregation
namespace winner_takes_all
{
// src is the aggregated cost buffer; left/right receive the disparity maps.
template <size_t MAX_DISPARITY>
void winnerTakesAll(const GpuMat& src, GpuMat& left, GpuMat& right, float uniqueness, bool subpixel, int mode, cv::cuda::Stream& stream);
}
} // namespace stereosgm
}}} // namespace cv { namespace cuda { namespace device {
namespace opencv_test { namespace {
// CPU reference of the center-symmetric census transform with a 9x7 window.
// For every pixel whose full window fits inside the image, each pixel of the upper
// half of the window (plus the left half of the center row) is compared with its
// point-mirrored counterpart; the comparison results are packed MSB-first into a
// 32-bit code. Border pixels of dst stay 0. src is read as 8-bit unsigned.
void census_transform(const cv::Mat& src, cv::Mat& dst)
{
    const int half_w = 9 / 2;
    const int half_h = 7 / 2;

    dst.create(src.size(), CV_32SC1);
    dst.setTo(cv::Scalar::all(0));

    const int row_end = src.rows - half_h;  // exclusive
    const int col_end = src.cols - half_w;  // exclusive

    for (int cy = half_h; cy < row_end; ++cy)
    {
        for (int cx = half_w; cx < col_end; ++cx)
        {
            int32_t code = 0;
            for (int oy = -half_h; oy <= 0; ++oy)
            {
                // On the center row only the pixels left of the center are visited;
                // their mirrors cover the right part.
                const int ox_last = (oy == 0) ? -1 : half_w;
                for (int ox = -half_w; ox <= ox_last; ++ox)
                {
                    const uint8_t sample = src.at<uint8_t>(cy + oy, cx + ox);
                    const uint8_t mirrored = src.at<uint8_t>(cy - oy, cx - ox);
                    code = (code << 1) | (sample > mirrored ? 1 : 0);
                }
            }
            dst.at<int32_t>(cy, cx) = code;
        }
    }
}
// Fixture: compares the device census transform against the CPU reference on image files.
// Parameters: device, image path, whether the input is uploaded as a sub-matrix (ROI).
PARAM_TEST_CASE(StereoSGM_CensusTransformImage, cv::cuda::DeviceInfo, std::string, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
std::string path;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
path = GET_PARAM(1);
useRoi = GET_PARAM(2);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(StereoSGM_CensusTransformImage, Image)
{
cv::Mat image = readImage(path, cv::IMREAD_GRAYSCALE);
cv::Mat dst_gold;
census_transform(image, dst_gold);
cv::cuda::GpuMat g_dst;
g_dst.create(image.size(), CV_32SC1);
cv::cuda::device::stereosgm::census_transform::censusTransform(loadMat(image, useRoi), g_dst, cv::cuda::Stream::Null());
cv::Mat dst;
g_dst.download(dst);
// The transform is integer-only, so the result must match the reference exactly.
EXPECT_MAT_NEAR(dst_gold, dst, 0);
}
INSTANTIATE_TEST_CASE_P(CUDA_StereoSGM_funcs, StereoSGM_CensusTransformImage, testing::Combine(
ALL_DEVICES,
testing::Values("stereobm/aloe-L.png", "stereobm/aloe-R.png"),
WHOLE_SUBMAT));
// Fixture: same check as StereoSGM_CensusTransformImage, but on random 8-bit images.
// Parameters: device, image size, whether the input is uploaded as a sub-matrix (ROI).
PARAM_TEST_CASE(StereoSGM_CensusTransformRandom, cv::cuda::DeviceInfo, cv::Size, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
useRoi = GET_PARAM(2);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(StereoSGM_CensusTransformRandom, Random)
{
cv::Mat image = randomMat(size, CV_8UC1);
cv::Mat dst_gold;
census_transform(image, dst_gold);
cv::cuda::GpuMat g_dst;
g_dst.create(image.size(), CV_32SC1);
cv::cuda::device::stereosgm::census_transform::censusTransform(loadMat(image, useRoi), g_dst, cv::cuda::Stream::Null());
cv::Mat dst;
g_dst.download(dst);
// Exact match with the CPU reference is required.
EXPECT_MAT_NEAR(dst_gold, dst, 0);
}
INSTANTIATE_TEST_CASE_P(CUDA_StereoSGM_funcs, StereoSGM_CensusTransformRandom, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
WHOLE_SUBMAT));
// CPU reference for the cost aggregation along a single scan path.
//
// left/right     : census images (CV_32SC1).
// dst            : output, one row of width*height*max_disparity bytes; the costs of
//                  pixel (i, j) start at offset (j + i*width) * max_disparity.
// max_disparity  : number of disparities evaluated per pixel.
// min_disparity  : offset added to every disparity when addressing the right image.
// p1, p2         : penalties for a disparity change of one / of more than one.
// dx, dy         : direction of the path; pixel (i - dy, j - dx) is the predecessor.
//
// The image is traversed in the order that guarantees the predecessor has already been
// written to dst. Right-image pixels left of column 0 are treated as census code 0.
static void path_aggregation(
    const cv::Mat& left,
    const cv::Mat& right,
    cv::Mat& dst,
    int max_disparity, int min_disparity, int p1, int p2,
    int dx, int dy)
{
    const int width = left.cols;
    const int height = left.rows;
    dst.create(cv::Size(width * height * max_disparity, 1), CV_8UC1);
    std::vector<int> before(max_disparity);
    for (int i = (dy < 0 ? height - 1 : 0); 0 <= i && i < height; i += (dy < 0 ? -1 : 1)) {
        for (int j = (dx < 0 ? width - 1 : 0); 0 <= j && j < width; j += (dx < 0 ? -1 : 1)) {
            const int i2 = i - dy, j2 = j - dx;
            const bool inside = (0 <= i2 && i2 < height && 0 <= j2 && j2 < width);
            // Costs of the predecessor on the path (all zero at the start of a path).
            for (int k = 0; k < max_disparity; ++k) {
                before[k] = inside ? dst.at<uint8_t>(0, k + (j2 + i2 * width) * max_disparity) : 0;
            }
            const int min_cost = *std::min_element(before.begin(), before.end());
            const int32_t l = left.at<int32_t>(i, j);
            for (int k = 0; k < max_disparity; ++k) {
                const int32_t r = (k + min_disparity > j ? 0 : right.at<int32_t>(i, j - k - min_disparity));
                int cost = std::min(before[k] - min_cost, p2);
                if (k > 0) {
                    cost = std::min(cost, before[k - 1] - min_cost + p1);
                }
                if (k + 1 < max_disparity) {
                    cost = std::min(cost, before[k + 1] - min_cost + p1);
                }
                // XOR as unsigned 32-bit: a signed result with the top bit set would be
                // sign-extended to 64 bits and the popcount would include 32 bogus bits.
                const uint32_t diff = static_cast<uint32_t>(l) ^ static_cast<uint32_t>(r);
                cost += static_cast<int>(popcnt64(diff));
                dst.at<uint8_t>(0, k + (j + i * width) * max_disparity) = static_cast<uint8_t>(cost);
            }
        }
    }
}
// Number of disparities and penalties shared by the path-aggregation and
// winner-takes-all tests below.
static constexpr size_t DISPARITY = 128;
static constexpr int P1 = 10;
static constexpr int P2 = 120;
// Fixture for the per-direction path aggregation tests.
// Parameters: device, image size, whether inputs are uploaded as sub-matrices, minimum disparity.
PARAM_TEST_CASE(StereoSGM_PathAggregation, cv::cuda::DeviceInfo, cv::Size, UseRoi, int)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
bool useRoi;
int minDisp;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
useRoi = GET_PARAM(2);
minDisp = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
// Runs the device routine `func` and the CPU reference for direction (dx, dy) on the
// same random census images and requires identical cost buffers.
template<typename T>
void test_path_aggregation(T func, int dx, int dy)
{
// Random values are drawn from [0, INT32_MAX) and therefore never negative.
cv::Mat left_image = randomMat(size, CV_32SC1, 0.0, static_cast<double>(std::numeric_limits<int32_t>::max()));
cv::Mat right_image = randomMat(size, CV_32SC1, 0.0, static_cast<double>(std::numeric_limits<int32_t>::max()));
cv::Mat dst_gold;
path_aggregation(left_image, right_image, dst_gold, DISPARITY, minDisp, P1, P2, dx, dy);
cv::cuda::GpuMat g_dst;
g_dst.create(cv::Size(left_image.cols * left_image.rows * DISPARITY, 1), CV_8UC1);
func(loadMat(left_image, useRoi), loadMat(right_image, useRoi), g_dst, P1, P2, minDisp, cv::cuda::Stream::Null());
cv::Mat dst;
g_dst.download(dst);
EXPECT_MAT_NEAR(dst_gold, dst, 0);
}
};
// One test per scan direction. The trailing (dx, dy) pair is the step of the path in
// (column, row) as expected by the CPU reference path_aggregation().
CUDA_TEST_P(StereoSGM_PathAggregation, RandomLeft2Right)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::horizontal::aggregateLeft2RightPath<DISPARITY>, 1, 0);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomRight2Left)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::horizontal::aggregateRight2LeftPath<DISPARITY>, -1, 0);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomUp2Down)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::vertical::aggregateUp2DownPath<DISPARITY>, 0, 1);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomDown2Up)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::vertical::aggregateDown2UpPath<DISPARITY>, 0, -1);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomUpLeft2DownRight)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::oblique::aggregateUpleft2DownrightPath<DISPARITY>, 1, 1);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomUpRight2DownLeft)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::oblique::aggregateUpright2DownleftPath<DISPARITY>, -1, 1);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomDownRight2UpLeft)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::oblique::aggregateDownright2UpleftPath<DISPARITY>, -1, -1);
}
CUDA_TEST_P(StereoSGM_PathAggregation, RandomDownLeft2UpRight)
{
test_path_aggregation(cv::cuda::device::stereosgm::path_aggregation::oblique::aggregateDownleft2UprightPath<DISPARITY>, 1, -1);
}
// Every direction is run on all devices and sizes, with and without ROI,
// for minimum disparities 0, 1 and 10.
INSTANTIATE_TEST_CASE_P(CUDA_StereoSGM_funcs, StereoSGM_PathAggregation, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
WHOLE_SUBMAT,
testing::Values(0, 1, 10)));
// CPU reference of the winner-takes-all step for the left disparity map.
//
// src is a single row holding num_paths cost volumes back to back; inside a volume the
// costs of pixel (i, j) start at offset (i*width + j) * disparity. For every pixel the
// per-path costs are summed and the first disparity with the smallest sum wins.
// With subpixel enabled the result is scaled by 2^DISP_SHIFT and, for winners that are
// not at either end of the range, refined from the two neighbouring sums.
// A pixel is marked invalid (-1, stored as uint16_t) when some disparity more than one
// step away from the winner has cost * uniqueness below the best cost.
void winner_takes_all_left(
    const cv::Mat& src,
    cv::Mat& dst,
    int width, int height, int disparity, int num_paths,
    float uniqueness, bool subpixel)
{
    dst.create(cv::Size(width, height), CV_16UC1);

    std::vector<int> sums;
    for (int row = 0; row < height; ++row)
    {
        for (int col = 0; col < width; ++col)
        {
            // Sum the costs over all paths for every candidate disparity.
            sums.clear();
            for (int d = 0; d < disparity; ++d)
            {
                int total = 0;
                for (int path = 0; path < num_paths; ++path)
                {
                    const int offset =
                        path * disparity * width * height +
                        row * disparity * width +
                        col * disparity +
                        d;
                    total += static_cast<int>(src.at<uint8_t>(0, offset));
                }
                sums.push_back(total);
            }
            assert(!sums.empty());

            // Strict '<' keeps the lowest disparity among equal minima.
            int best_disp = 0;
            for (int d = 1; d < disparity; ++d)
            {
                if (sums[d] < sums[best_disp])
                    best_disp = d;
            }
            const int best_cost = sums[best_disp];

            int result = best_disp;
            if (subpixel)
            {
                result <<= StereoMatcher::DISP_SHIFT;
                const bool has_both_neighbours = best_disp > 0 && best_disp < disparity - 1;
                if (has_both_neighbours)
                {
                    const int lower = sums[best_disp - 1];
                    const int upper = sums[best_disp + 1];
                    const int numer = lower - upper;
                    const int denom = lower - 2 * best_cost + upper;
                    result += ((numer << StereoMatcher::DISP_SHIFT) + denom) / (2 * denom);
                }
            }

            // Uniqueness test against all non-adjacent disparities.
            bool rejected = false;
            for (int d = 0; d < disparity && !rejected; ++d)
            {
                rejected = sums[d] * uniqueness < best_cost && abs(d - best_disp) > 1;
            }
            if (rejected)
                result = -1;

            dst.at<uint16_t>(row, col) = static_cast<uint16_t>(result);
        }
    }
}
// Fixture for the winner-takes-all test.
// Parameters: device, image size, subpixel flag, StereoSGM mode (selects 4 or 8 paths).
PARAM_TEST_CASE(StereoSGM_WinnerTakesAll, cv::cuda::DeviceInfo, cv::Size, bool, int)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
bool subpixel;
int mode;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
subpixel = GET_PARAM(2);
mode = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
};
// Feeds random aggregated costs to the device routine and compares the LEFT disparity
// map with the CPU reference; the right map is computed but not verified here.
CUDA_TEST_P(StereoSGM_WinnerTakesAll, RandomLeft)
{
int num_paths = mode == cv::cuda::StereoSGM::MODE_HH4 ? 4 : 8;
// Per-path costs are drawn from [0, 32).
cv::Mat aggregated = randomMat(cv::Size(size.width * size.height * DISPARITY * num_paths, 1), CV_8UC1, 0.0, 32.0);
cv::Mat dst_gold;
winner_takes_all_left(aggregated, dst_gold, size.width, size.height, DISPARITY, num_paths, 0.95f, subpixel);
cv::cuda::GpuMat g_src, g_dst, g_dst_right;
g_src.upload(aggregated);
g_dst.create(size, CV_16UC1);
g_dst_right.create(size, CV_16UC1);
cv::cuda::device::stereosgm::winner_takes_all::winnerTakesAll<DISPARITY>(g_src, g_dst, g_dst_right, 0.95f, subpixel, mode, cv::cuda::Stream::Null());
cv::Mat dst;
g_dst.download(dst);
EXPECT_MAT_NEAR(dst_gold, dst, 0);
}
INSTANTIATE_TEST_CASE_P(CUDA_StereoSGM_funcs, StereoSGM_WinnerTakesAll, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(false, true),
testing::Values(cv::cuda::StereoSGM::MODE_HH4, cv::cuda::StereoSGM::MODE_HH)));
}} // namespace
#endif // HAVE_CUDA

@ -209,6 +209,761 @@ INSTANTIATE_TEST_CASE_P(CUDA_Stereo, ReprojectImageTo3D, testing::Combine(
testing::Values(MatDepth(CV_8U), MatDepth(CV_16S)),
WHOLE_SUBMAT));
////////////////////////////////////////////////////////////////////////////////
// StereoSGM
/*
This is a regression test for stereo matching algorithms. This test gets some quality metrics
described in "A Taxonomy and Evaluation of Dense Two-Frame Stereo Correspondence Algorithms".
Daniel Scharstein, Richard Szeliski
*/
// Default parameters of the quality metrics below.
// NOTE(review): EVAL_BAD_THRESH and ERROR_KINDS_COUNT are not referenced in the part of
// the file visible here; presumably the bad-pixel threshold and the number of reported
// error categories -- confirm against their users.
const float EVAL_BAD_THRESH = 1.f;
// Window width and threshold used by computeTextureBasedMasks.
const int EVAL_TEXTURELESS_WIDTH = 3;
const float EVAL_TEXTURELESS_THRESH = 4.f;
// Tolerance used by computeOcclusionBasedMasks.
const float EVAL_DISP_THRESH = 1.f;
// Disparity gap and dilation width used by computeDepthDiscontMask.
const float EVAL_DISP_GAP = 2.f;
const int EVAL_DISCONT_WIDTH = 9;
// Border width used by getBorderedMask.
const int EVAL_IGNORE_BORDER = 10;
const int ERROR_KINDS_COUNT = 6;
//============================== quality measuring functions =================================================
/*
Calculate textureless regions of image (regions where the squared horizontal intensity gradient averaged over
a square window of size=evalTexturelessWidth is below a threshold=evalTexturelessThresh) and textured regions.
*/
void computeTextureBasedMasks(const Mat& _img, Mat* texturelessMask, Mat* texturedMask,
int texturelessWidth = EVAL_TEXTURELESS_WIDTH, float texturelessThresh = EVAL_TEXTURELESS_THRESH)
{
if (!texturelessMask && !texturedMask)
return;
if (_img.empty())
CV_Error(Error::StsBadArg, "img is empty");
Mat img = _img;
if (_img.channels() > 1)
{
Mat tmp; cvtColor(_img, tmp, COLOR_BGR2GRAY); img = tmp;
}
Mat dxI; Sobel(img, dxI, CV_32FC1, 1, 0, 3);
Mat dxI2; pow(dxI / 8.f/*normalize*/, 2, dxI2);
Mat avgDxI2; boxFilter(dxI2, avgDxI2, CV_32FC1, Size(texturelessWidth, texturelessWidth));
if (texturelessMask)
*texturelessMask = avgDxI2 < texturelessThresh;
if (texturedMask)
*texturedMask = avgDxI2 >= texturelessThresh;
}
// Validates a disparity map: it must be non-empty, of type CV_32FC1 and, when sz is
// given, exactly of that size. Raises Error::StsBadArg otherwise.
void checkTypeAndSizeOfDisp(const Mat& dispMap, const Size* sz)
{
    if (dispMap.empty())
    {
        CV_Error(Error::StsBadArg, "dispMap is empty");
    }
    if (dispMap.type() != CV_32FC1)
    {
        CV_Error(Error::StsBadArg, "dispMap must have CV_32FC1 type");
    }
    if (sz)
    {
        const bool sameSize = dispMap.rows == sz->height && dispMap.cols == sz->width;
        if (!sameSize)
        {
            CV_Error(Error::StsBadArg, "dispMap has incorrect size");
        }
    }
}
// Validates a mask: it must be non-empty, of type CV_8UC1 and exactly of size sz.
// Raises Error::StsBadArg otherwise.
void checkTypeAndSizeOfMask(const Mat& mask, Size sz)
{
    if (mask.empty())
    {
        CV_Error(Error::StsBadArg, "mask is empty");
    }
    if (mask.type() != CV_8UC1)
    {
        CV_Error(Error::StsBadArg, "mask must have CV_8UC1 type");
    }
    const bool sameSize = mask.rows == sz.height && mask.cols == sz.width;
    if (!sameSize)
    {
        CV_Error(Error::StsBadArg, "mask has incorrect size");
    }
}
// Validates a pair of disparity maps together with their optional "unknown disparity"
// masks. The left map is mandatory; the right map and both masks may be empty.
// Raises Error::StsBadArg on any type/size mismatch or when a known disparity is negative.
void checkDispMapsAndUnknDispMasks(const Mat& leftDispMap, const Mat& rightDispMap,
const Mat& leftUnknDispMask, const Mat& rightUnknDispMask)
{
// check type and size of disparity maps
checkTypeAndSizeOfDisp(leftDispMap, 0);
if (!rightDispMap.empty())
{
Size sz = leftDispMap.size();
checkTypeAndSizeOfDisp(rightDispMap, &sz);
}
// check size and type of unknown disparity maps
if (!leftUnknDispMask.empty())
checkTypeAndSizeOfMask(leftUnknDispMask, leftDispMap.size());
// A right mask is compared against the size of the right map, so a right mask
// supplied without a right map is rejected as having an incorrect size.
if (!rightUnknDispMask.empty())
checkTypeAndSizeOfMask(rightUnknDispMask, rightDispMap.size());
// check values of disparity maps (known disparity values must not be negative)
double leftMinVal = 0, rightMinVal = 0;
// Pixels flagged in an unknown-disparity mask are excluded from the minimum search.
if (leftUnknDispMask.empty())
minMaxLoc(leftDispMap, &leftMinVal);
else
minMaxLoc(leftDispMap, &leftMinVal, 0, 0, 0, ~leftUnknDispMask);
if (!rightDispMap.empty())
{
if (rightUnknDispMask.empty())
minMaxLoc(rightDispMap, &rightMinVal);
else
minMaxLoc(rightDispMap, &rightMinVal, 0, 0, 0, ~rightUnknDispMask);
}
// Zero is accepted; only strictly negative values trigger the error below.
if (leftMinVal < 0 || rightMinVal < 0)
CV_Error(Error::StsBadArg, "known disparity values must be positive");
}
/*
Calculate occluded regions of reference image (left image) (regions that are occluded in the matching image (right image),
i.e., where the forward-mapped disparity lands at a location with a larger (nearer) disparity) and non occluded regions.
*/
void computeOcclusionBasedMasks(const Mat& leftDisp, const Mat& _rightDisp,
Mat* occludedMask, Mat* nonOccludedMask,
const Mat& leftUnknDispMask = Mat(), const Mat& rightUnknDispMask = Mat(),
float dispThresh = EVAL_DISP_THRESH)
{
if (!occludedMask && !nonOccludedMask)
return;
checkDispMapsAndUnknDispMasks(leftDisp, _rightDisp, leftUnknDispMask, rightUnknDispMask);
Mat rightDisp;
if (_rightDisp.empty())
{
if (!rightUnknDispMask.empty())
CV_Error(Error::StsBadArg, "rightUnknDispMask must be empty if _rightDisp is empty");
rightDisp.create(leftDisp.size(), CV_32FC1);
rightDisp.setTo(Scalar::all(0));
for (int leftY = 0; leftY < leftDisp.rows; leftY++)
{
for (int leftX = 0; leftX < leftDisp.cols; leftX++)
{
if (!leftUnknDispMask.empty() && leftUnknDispMask.at<uchar>(leftY, leftX))
continue;
float leftDispVal = leftDisp.at<float>(leftY, leftX);
int rightX = leftX - cvRound(leftDispVal), rightY = leftY;
if (rightX >= 0)
rightDisp.at<float>(rightY, rightX) = max(rightDisp.at<float>(rightY, rightX), leftDispVal);
}
}
}
else
_rightDisp.copyTo(rightDisp);
if (occludedMask)
{
occludedMask->create(leftDisp.size(), CV_8UC1);
occludedMask->setTo(Scalar::all(0));
}
if (nonOccludedMask)
{
nonOccludedMask->create(leftDisp.size(), CV_8UC1);
nonOccludedMask->setTo(Scalar::all(0));
}
for (int leftY = 0; leftY < leftDisp.rows; leftY++)
{
for (int leftX = 0; leftX < leftDisp.cols; leftX++)
{
if (!leftUnknDispMask.empty() && leftUnknDispMask.at<uchar>(leftY, leftX))
continue;
float leftDispVal = leftDisp.at<float>(leftY, leftX);
int rightX = leftX - cvRound(leftDispVal), rightY = leftY;
if (rightX < 0 && occludedMask)
occludedMask->at<uchar>(leftY, leftX) = 255;
else
{
if (!rightUnknDispMask.empty() && rightUnknDispMask.at<uchar>(rightY, rightX))
continue;
float rightDispVal = rightDisp.at<float>(rightY, rightX);
if (rightDispVal > leftDispVal + dispThresh)
{
if (occludedMask)
occludedMask->at<uchar>(leftY, leftX) = 255;
}
else
{
if (nonOccludedMask)
nonOccludedMask->at<uchar>(leftY, leftX) = 255;
}
}
}
}
}
/*
  Calculate depth discontinuity regions: pixels whose neighboring disparities differ by more than
  dispGap, dilated by window of width discontWidth.

  disp              disparity map (CV_32FC1, non-empty).
  depthDiscontMask  output mask, non-zero inside the (dilated) discontinuity regions.
  unknDispMask      optional CV_8UC1 mask of pixels with unknown disparity; such pixels
                    neither influence their neighbours nor become discontinuities themselves.
*/
void computeDepthDiscontMask(const Mat& disp, Mat& depthDiscontMask, const Mat& unknDispMask = Mat(),
    float dispGap = EVAL_DISP_GAP, int discontWidth = EVAL_DISCONT_WIDTH)
{
    if (disp.empty())
        CV_Error(Error::StsBadArg, "disp is empty");
    if (disp.type() != CV_32FC1)
        CV_Error(Error::StsBadArg, "disp must have CV_32FC1 type");
    if (!unknDispMask.empty())
        checkTypeAndSizeOfMask(unknDispMask, disp.size());

    Mat curDisp; disp.copyTo(curDisp);

    // Neighbourhood maximum: unknown pixels must never win, so they get the most negative
    // float. numeric_limits<float>::min() would be wrong here -- it is the smallest
    // POSITIVE value and would still beat zero or negative disparities.
    if (!unknDispMask.empty())
        curDisp.setTo(Scalar(std::numeric_limits<float>::lowest()), unknDispMask);
    Mat maxNeighbDisp; dilate(curDisp, maxNeighbDisp, Mat(3, 3, CV_8UC1, Scalar(1)));

    // Neighbourhood minimum: unknown pixels get the largest float for the same reason.
    if (!unknDispMask.empty())
        curDisp.setTo(Scalar(std::numeric_limits<float>::max()), unknDispMask);
    Mat minNeighbDisp; erode(curDisp, minNeighbDisp, Mat(3, 3, CV_8UC1, Scalar(1)));

    depthDiscontMask = max((Mat)(maxNeighbDisp - disp), (Mat)(disp - minNeighbDisp)) > dispGap;
    if (!unknDispMask.empty())
        depthDiscontMask &= ~unknDispMask;
    dilate(depthDiscontMask, depthDiscontMask, Mat(discontWidth, discontWidth, CV_8UC1, Scalar(1)));
}
/*
  Build an evaluation mask that excludes a border: pixels closer than `border` to any image
  edge are 0, all remaining (inner) pixels are 255. If the border swallows the whole image
  the returned mask is entirely zero.
*/
Mat getBorderedMask(Size maskSize, int border = EVAL_IGNORE_BORDER)
{
    CV_Assert(border >= 0);

    // Start from an all-zero mask; only the inner rectangle (if any) gets switched on.
    Mat mask(maskSize, CV_8UC1, Scalar(0));

    const int innerWidth = maskSize.width - 2 * border;
    const int innerHeight = maskSize.height - 2 * border;
    if (innerWidth >= 0 && innerHeight >= 0)
    {
        Mat inner = mask(Rect(border, border, innerWidth, innerHeight));
        inner.setTo(Scalar(255));
    }
    return mask;
}
/*
  Calculate root-mean-squared error between the computed disparity map (computedDisp) and
  ground truth map (groundTruthDisp).

  computedDisp    - computed disparities; validated against the ground-truth size.
  groundTruthDisp - reference disparities.
  mask            - optional evaluation mask; when non-empty only its non-zero pixels contribute.

  Returns L2-norm(computedDisp - groundTruthDisp over the evaluated pixels) / sqrt(pixel count).
  When the mask selects no pixels at all the error is reported as 0 instead of the NaN that
  0 * (1 / sqrt(0)) would otherwise produce.
*/
float dispRMS(const Mat& computedDisp, const Mat& groundTruthDisp, const Mat& mask)
{
    checkTypeAndSizeOfDisp(groundTruthDisp, 0);
    Size sz = groundTruthDisp.size();
    checkTypeAndSizeOfDisp(computedDisp, &sz);

    // Without a mask every pixel is evaluated.
    int pointsCount = sz.height*sz.width;
    if (!mask.empty())
    {
        checkTypeAndSizeOfMask(mask, sz);
        pointsCount = countNonZero(mask);
    }

    // Nothing to evaluate: avoid dividing by zero.
    if (pointsCount == 0)
        return 0.f;

    return 1.f / sqrt((float)pointsCount) * (float)cvtest::norm(computedDisp, groundTruthDisp, NORM_L2, mask);
}
/*
  Calculate fraction of bad matching pixels.

  computedDisp    - computed disparities; validated against the ground-truth size.
  groundTruthDisp - reference disparities.
  mask            - optional evaluation mask; when non-empty only its non-zero pixels are counted.
  _badThresh      - a pixel is "bad" when |computed - groundTruth| exceeds this threshold.
                    Note that the threshold is rounded to the nearest integer before use.

  Returns (number of bad pixels) / (number of evaluated pixels), or 0 when the mask selects
  no pixels at all (instead of the NaN produced by 0 * (1 / 0)).
*/
float badMatchPxlsFraction(const Mat& computedDisp, const Mat& groundTruthDisp, const Mat& mask,
    float _badThresh = EVAL_BAD_THRESH)
{
    int badThresh = cvRound(_badThresh);
    checkTypeAndSizeOfDisp(groundTruthDisp, 0);
    Size sz = groundTruthDisp.size();
    checkTypeAndSizeOfDisp(computedDisp, &sz);

    // Per-pixel absolute error, thresholded into a 0/255 map of bad pixels.
    Mat badPxlsMap;
    absdiff(computedDisp, groundTruthDisp, badPxlsMap);
    badPxlsMap = badPxlsMap > badThresh;

    // Without a mask every pixel is evaluated.
    int pointsCount = sz.height*sz.width;
    if (!mask.empty())
    {
        checkTypeAndSizeOfMask(mask, sz);
        badPxlsMap = badPxlsMap & mask;
        pointsCount = countNonZero(mask);
    }

    // Nothing to evaluate: avoid dividing by zero.
    if (pointsCount == 0)
        return 0.f;

    return 1.f / pointsCount * countNonZero(badPxlsMap);
}
//===================== regression test for stereo matching algorithms ==============================
// The directory constants below are relative locations that CV_StereoMatchingTest::run() resolves
// through findDataFile()/findDataDirectory().

// Directory holding the per-algorithm files "<name>" + RUN_PARAMS_FILE and "<name>" + RESULT_FILE.
const string ALGORITHMS_DIR = "stereomatching/algorithms/";
// Directory holding DATASETS_FILE and one sub-directory per dataset.
const string DATASETS_DIR = "stereomatching/datasets/";
// File (inside DATASETS_DIR) parsed by readDatasetsParams().
const string DATASETS_FILE = "datasets.xml";
// Suffix appended to the algorithm name to get the run-parameters file (parsed by readRunParams()).
const string RUN_PARAMS_FILE = "_params.xml";
// Suffix appended to the algorithm name to get the reference-results file (read, or written if absent).
const string RESULT_FILE = "_res.xml";
// File names looked up inside each dataset directory: the stereo pair and its ground-truth disparities.
const string LEFT_IMG_NAME = "im2.png";
const string RIGHT_IMG_NAME = "im6.png";
const string TRUE_LEFT_DISP_NAME = "disp2.png";
const string TRUE_RIGHT_DISP_NAME = "disp6.png";
// Key prefixes for the individual error kinds; a full key is prefix + RMS_STR or
// prefix + BAD_PXLS_FRACTION_STR. The order matches the indices filled in by calcErrors().
string ERROR_PREFIXES[] = { "borderedAll",
"borderedNoOccl",
"borderedOccl",
"borderedTextured",
"borderedTextureless",
"borderedDepthDiscont" }; // must contain ERROR_KINDS_COUNT entries
// Keys under which the components of the calculated ROI are stored (x, y, width, height order).
string ROI_PREFIXES[] = { "roiX",
"roiY",
"roiWidth",
"roiHeight" };
// Error-name suffixes and section labels used when writing/logging results.
const string RMS_STR = "RMS";
const string BAD_PXLS_FRACTION_STR = "BadPxlsFraction";
const string ROI_STR = "ValidDisparityROI";
/*
  Thresholds and window sizes used while evaluating the quality of a disparity map.
  Every field starts out at its EVAL_* default; the one-argument constructor additionally
  overrides the width of the ignored image border.
*/
class QualityEvalParams
{
public:
    QualityEvalParams() { setDefaults(); }

    QualityEvalParams(int _ignoreBorder)
    {
        setDefaults();
        ignoreBorder = _ignoreBorder;
    }

    // Reset every parameter to its EVAL_* default.
    void setDefaults()
    {
        ignoreBorder = EVAL_IGNORE_BORDER;
        discontWidth = EVAL_DISCONT_WIDTH;
        dispGap = EVAL_DISP_GAP;
        dispThresh = EVAL_DISP_THRESH;
        texturelessThresh = EVAL_TEXTURELESS_THRESH;
        texturelessWidth = EVAL_TEXTURELESS_WIDTH;
        badThresh = EVAL_BAD_THRESH;
    }

    float badThresh;          // bad-pixel threshold passed to badMatchPxlsFraction()
    int texturelessWidth;     // window width passed to computeTextureBasedMasks()
    float texturelessThresh;  // threshold passed to computeTextureBasedMasks()
    float dispThresh;         // disparity threshold passed to computeOcclusionBasedMasks()
    float dispGap;            // disparity gap passed to computeDepthDiscontMask()
    int discontWidth;         // dilation width passed to computeDepthDiscontMask()
    int ignoreBorder;         // border width passed to getBorderedMask()
};
/*
  Base class of the stereo matching regression tests. A concrete test supplies the algorithm
  through runStereoMatchingAlgorithm(); run() evaluates it on every configured case and either
  records the resulting errors or compares them with previously recorded ones.
*/
class CV_StereoMatchingTest : public cvtest::BaseTest
{
public:
    // Default tolerances: one entry per error kind.
    CV_StereoMatchingTest()
        : rmsEps(ERROR_KINDS_COUNT, 0.01f), fracEps(ERROR_KINDS_COUNT, 1.e-6f)
    {
    }

protected:
    // assumed that left image is a reference image
    virtual int runStereoMatchingAlgorithm(const Mat& leftImg, const Mat& rightImg,
        Rect& calcROI, Mat& leftDisp, Mat& rightDisp, int caseIdx) = 0; // return ignored border width

    // Parse the dataset descriptions into datasetsParams; returns a cvtest::TS code.
    int readDatasetsParams(FileStorage& fs);
    // Parse the algorithm run parameters (case list); returns a cvtest::TS code.
    virtual int readRunParams(FileStorage& fs);

    // Emit the per-kind errors to the storage if fs is given, otherwise to the test log.
    void writeErrors(const string& errName, const vector<float>& errors, FileStorage* fs = 0);
    // Emit the ROI to the storage if fs is given, otherwise to the test log.
    void writeROI(const Rect& calcROI, FileStorage* fs = 0);
    // Load previously stored per-kind errors / ROI from a results node.
    void readErrors(FileNode& fn, const string& errName, vector<float>& errors);
    void readROI(FileNode& fn, Rect& trueROI);

    // Compare calculated values with stored ones; return TS::OK or TS::FAIL_BAD_ACCURACY.
    int compareErrors(const vector<float>& calcErrors, const vector<float>& validErrors,
        const vector<float>& eps, const string& errName);
    int compareROI(const Rect& calcROI, const Rect& validROI);

    // Evaluate one case and either store (isWrite) or verify its results.
    int processStereoMatchingResults(FileStorage& fs, int caseIdx, bool isWrite,
        const Mat& leftImg, const Mat& rightImg,
        const Rect& calcROI,
        const Mat& trueLeftDisp, const Mat& trueRightDisp,
        const Mat& leftDisp, const Mat& rightDisp,
        const QualityEvalParams& qualityEvalParams);

    void run(int);

    vector<float> rmsEps;   // allowed RMS increase per error kind
    vector<float> fracEps;  // allowed bad-pixel-fraction increase per error kind

    struct DatasetParams
    {
        int dispScaleFactor;  // ground-truth disparities are divided by this factor
        int dispUnknVal;      // ground-truth value that marks an unknown disparity
    };
    map<string, DatasetParams> datasetsParams;  // dataset name -> its parameters

    vector<string> caseNames;     // name of every test case
    vector<string> caseDatasets;  // dataset used by the case with the same index
};
void CV_StereoMatchingTest::run(int)
{
addDataSearchSubDirectory("cv");
string algorithmName = name;
assert(!algorithmName.empty());
FileStorage datasetsFS(findDataFile(DATASETS_DIR + DATASETS_FILE), FileStorage::READ);
int code = readDatasetsParams(datasetsFS);
if (code != cvtest::TS::OK)
{
ts->set_failed_test_info(code);
return;
}
FileStorage runParamsFS(findDataFile(ALGORITHMS_DIR + algorithmName + RUN_PARAMS_FILE), FileStorage::READ);
code = readRunParams(runParamsFS);
if (code != cvtest::TS::OK)
{
ts->set_failed_test_info(code);
return;
}
string fullResultFilename = findDataDirectory(ALGORITHMS_DIR) + algorithmName + RESULT_FILE;
FileStorage resFS(fullResultFilename, FileStorage::READ);
bool isWrite = true; // write or compare results
if (resFS.isOpened())
isWrite = false;
else
{
resFS.open(fullResultFilename, FileStorage::WRITE);
if (!resFS.isOpened())
{
ts->printf(cvtest::TS::LOG, "file %s can not be read or written\n", fullResultFilename.c_str());
ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ARG_CHECK);
return;
}
resFS << "stereo_matching" << "{";
}
int progress = 0, caseCount = (int)caseNames.size();
for (int ci = 0; ci < caseCount; ci++)
{
progress = update_progress(progress, ci, caseCount, 0);
printf("progress: %d%%\n", progress);
fflush(stdout);
string datasetName = caseDatasets[ci];
string datasetFullDirName = findDataDirectory(DATASETS_DIR) + datasetName + "/";
Mat leftImg = imread(datasetFullDirName + LEFT_IMG_NAME);
Mat rightImg = imread(datasetFullDirName + RIGHT_IMG_NAME);
Mat trueLeftDisp = imread(datasetFullDirName + TRUE_LEFT_DISP_NAME, 0);
Mat trueRightDisp = imread(datasetFullDirName + TRUE_RIGHT_DISP_NAME, 0);
Rect calcROI;
if (leftImg.empty() || rightImg.empty() || trueLeftDisp.empty())
{
ts->printf(cvtest::TS::LOG, "images or left ground-truth disparities of dataset %s can not be read", datasetName.c_str());
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
continue;
}
int dispScaleFactor = datasetsParams[datasetName].dispScaleFactor;
Mat tmp;
trueLeftDisp.convertTo(tmp, CV_32FC1, 1.f / dispScaleFactor);
trueLeftDisp = tmp;
tmp.release();
if (!trueRightDisp.empty())
{
trueRightDisp.convertTo(tmp, CV_32FC1, 1.f / dispScaleFactor);
trueRightDisp = tmp;
tmp.release();
}
Mat leftDisp, rightDisp;
int ignBorder = max(runStereoMatchingAlgorithm(leftImg, rightImg, calcROI, leftDisp, rightDisp, ci), EVAL_IGNORE_BORDER);
leftDisp.convertTo(tmp, CV_32FC1);
leftDisp = tmp;
tmp.release();
rightDisp.convertTo(tmp, CV_32FC1);
rightDisp = tmp;
tmp.release();
int tempCode = processStereoMatchingResults(resFS, ci, isWrite,
leftImg, rightImg, calcROI, trueLeftDisp, trueRightDisp, leftDisp, rightDisp, QualityEvalParams(ignBorder));
code = tempCode == cvtest::TS::OK ? code : tempCode;
}
if (isWrite)
resFS << "}"; // "stereo_matching"
ts->set_failed_test_info(code);
}
void calcErrors(const Mat& leftImg, const Mat& /*rightImg*/,
const Mat& trueLeftDisp, const Mat& trueRightDisp,
const Mat& trueLeftUnknDispMask, const Mat& trueRightUnknDispMask,
const Mat& calcLeftDisp, const Mat& /*calcRightDisp*/,
vector<float>& rms, vector<float>& badPxlsFractions,
const QualityEvalParams& qualityEvalParams)
{
Mat texturelessMask, texturedMask;
computeTextureBasedMasks(leftImg, &texturelessMask, &texturedMask,
qualityEvalParams.texturelessWidth, qualityEvalParams.texturelessThresh);
Mat occludedMask, nonOccludedMask;
computeOcclusionBasedMasks(trueLeftDisp, trueRightDisp, &occludedMask, &nonOccludedMask,
trueLeftUnknDispMask, trueRightUnknDispMask, qualityEvalParams.dispThresh);
Mat depthDiscontMask;
computeDepthDiscontMask(trueLeftDisp, depthDiscontMask, trueLeftUnknDispMask,
qualityEvalParams.dispGap, qualityEvalParams.discontWidth);
Mat borderedKnownMask = getBorderedMask(leftImg.size(), qualityEvalParams.ignoreBorder) & ~trueLeftUnknDispMask;
nonOccludedMask &= borderedKnownMask;
occludedMask &= borderedKnownMask;
texturedMask &= nonOccludedMask; // & borderedKnownMask
texturelessMask &= nonOccludedMask; // & borderedKnownMask
depthDiscontMask &= nonOccludedMask; // & borderedKnownMask
rms.resize(ERROR_KINDS_COUNT);
rms[0] = dispRMS(calcLeftDisp, trueLeftDisp, borderedKnownMask);
rms[1] = dispRMS(calcLeftDisp, trueLeftDisp, nonOccludedMask);
rms[2] = dispRMS(calcLeftDisp, trueLeftDisp, occludedMask);
rms[3] = dispRMS(calcLeftDisp, trueLeftDisp, texturedMask);
rms[4] = dispRMS(calcLeftDisp, trueLeftDisp, texturelessMask);
rms[5] = dispRMS(calcLeftDisp, trueLeftDisp, depthDiscontMask);
badPxlsFractions.resize(ERROR_KINDS_COUNT);
badPxlsFractions[0] = badMatchPxlsFraction(calcLeftDisp, trueLeftDisp, borderedKnownMask, qualityEvalParams.badThresh);
badPxlsFractions[1] = badMatchPxlsFraction(calcLeftDisp, trueLeftDisp, nonOccludedMask, qualityEvalParams.badThresh);
badPxlsFractions[2] = badMatchPxlsFraction(calcLeftDisp, trueLeftDisp, occludedMask, qualityEvalParams.badThresh);
badPxlsFractions[3] = badMatchPxlsFraction(calcLeftDisp, trueLeftDisp, texturedMask, qualityEvalParams.badThresh);
badPxlsFractions[4] = badMatchPxlsFraction(calcLeftDisp, trueLeftDisp, texturelessMask, qualityEvalParams.badThresh);
badPxlsFractions[5] = badMatchPxlsFraction(calcLeftDisp, trueLeftDisp, depthDiscontMask, qualityEvalParams.badThresh);
}
/*
  Evaluate the disparities of one test case and either store the results (isWrite) or compare
  them with the ones already present in `fs`.

  Returns TS::OK, or the code of the last failed comparison.
*/
int CV_StereoMatchingTest::processStereoMatchingResults(FileStorage& fs, int caseIdx, bool isWrite,
    const Mat& leftImg, const Mat& rightImg,
    const Rect& calcROI,
    const Mat& trueLeftDisp, const Mat& trueRightDisp,
    const Mat& leftDisp, const Mat& rightDisp,
    const QualityEvalParams& qualityEvalParams)
{
    // rightDisp is not used in current test version
    assert(fs.isOpened());
    assert(trueLeftDisp.type() == CV_32FC1);
    assert(trueRightDisp.empty() || trueRightDisp.type() == CV_32FC1);
    assert(leftDisp.type() == CV_32FC1 && (rightDisp.empty() || rightDisp.type() == CV_32FC1));

    const string& caseName = caseNames[caseIdx];

    // Masks of unknown ground truth: pixels whose disparity equals the dataset's "unknown" value.
    const DatasetParams datasetParams = datasetsParams[caseDatasets[caseIdx]];
    Mat leftUnknMask, rightUnknMask;
    absdiff(trueLeftDisp, Scalar(datasetParams.dispUnknVal), leftUnknMask);
    leftUnknMask = leftUnknMask < std::numeric_limits<float>::epsilon();
    assert(leftUnknMask.type() == CV_8UC1);
    if (!trueRightDisp.empty())
    {
        absdiff(trueRightDisp, Scalar(datasetParams.dispUnknVal), rightUnknMask);
        rightUnknMask = rightUnknMask < std::numeric_limits<float>::epsilon();
        assert(rightUnknMask.type() == CV_8UC1);
    }

    // Error measures of this case.
    vector<float> rmss, badPxlsFractions;
    calcErrors(leftImg, rightImg, trueLeftDisp, trueRightDisp, leftUnknMask, rightUnknMask,
        leftDisp, rightDisp, rmss, badPxlsFractions, qualityEvalParams);

    if (isWrite)
    {
        // Record the results as a new node named after the case.
        fs << caseName << "{";
        fs.writeComment(RMS_STR, 0);
        writeErrors(RMS_STR, rmss, &fs);
        fs.writeComment(BAD_PXLS_FRACTION_STR, 0);
        writeErrors(BAD_PXLS_FRACTION_STR, badPxlsFractions, &fs);
        fs.writeComment(ROI_STR, 0);
        writeROI(calcROI, &fs);
        fs << "}"; // datasetName
        return cvtest::TS::OK;
    }

    // Compare mode: log the calculated values first...
    ts->printf(cvtest::TS::LOG, "\nquality of case named %s\n", caseName.c_str());
    ts->printf(cvtest::TS::LOG, "%s\n", RMS_STR.c_str());
    writeErrors(RMS_STR, rmss);
    ts->printf(cvtest::TS::LOG, "%s\n", BAD_PXLS_FRACTION_STR.c_str());
    writeErrors(BAD_PXLS_FRACTION_STR, badPxlsFractions);
    ts->printf(cvtest::TS::LOG, "%s\n", ROI_STR.c_str());
    writeROI(calcROI);

    // ...then load the stored reference values of this case...
    FileNode fn = fs.getFirstTopLevelNode()[caseName];
    vector<float> validRmss, validBadPxlsFractions;
    Rect validROI;
    readErrors(fn, RMS_STR, validRmss);
    readErrors(fn, BAD_PXLS_FRACTION_STR, validBadPxlsFractions);
    readROI(fn, validROI);

    // ...and check all three; the last failure determines the returned code.
    int code = cvtest::TS::OK;
    const int rmsCode = compareErrors(rmss, validRmss, rmsEps, RMS_STR);
    if (rmsCode != cvtest::TS::OK)
        code = rmsCode;
    const int fracCode = compareErrors(badPxlsFractions, validBadPxlsFractions, fracEps, BAD_PXLS_FRACTION_STR);
    if (fracCode != cvtest::TS::OK)
        code = fracCode;
    const int roiCode = compareROI(calcROI, validROI);
    if (roiCode != cvtest::TS::OK)
        code = roiCode;
    return code;
}
/*
  Read the dataset descriptions from `fs` into datasetsParams.

  The first top-level node must be a flat sequence of string triples:
  dataset name, disparity scale factor, value marking an unknown disparity.

  Returns TS::OK on success and TS::FAIL_INVALID_TEST_DATA if the storage is not opened or the
  node is not such a sequence (datasetsParams is left empty in the latter case).
*/
int CV_StereoMatchingTest::readDatasetsParams(FileStorage& fs)
{
    if (!fs.isOpened())
    {
        ts->printf(cvtest::TS::LOG, "datasetsParams can not be read ");
        return cvtest::TS::FAIL_INVALID_TEST_DATA;
    }
    datasetsParams.clear();

    FileNode fn = fs.getFirstTopLevelNode();

    // Validate the layout at run time: an assert disappears in release builds, and a truncated
    // sequence would make the fn[i + 1] / fn[i + 2] accesses below run past its end.
    const int fieldsPerDataset = 3;
    const int fieldCount = fn.isSeq() ? (int)fn.size() : -1;
    if (fieldCount < 0 || fieldCount % fieldsPerDataset != 0)
    {
        ts->printf(cvtest::TS::LOG, "datasetsParams must be a sequence of (name, scale factor, unknown value) triples\n");
        return cvtest::TS::FAIL_INVALID_TEST_DATA;
    }

    for (int i = 0; i < fieldCount; i += fieldsPerDataset)
    {
        String _name = fn[i];
        DatasetParams params;
        String sf = fn[i + 1]; params.dispScaleFactor = atoi(sf.c_str());
        String uv = fn[i + 2]; params.dispUnknVal = atoi(uv.c_str());
        datasetsParams[_name] = params;
    }
    return cvtest::TS::OK;
}
/*
  Base part of reading the run parameters: verifies that the storage is opened and resets the
  case lists. Derived classes fill the lists with the actual cases.
  Returns TS::OK, or TS::FAIL_INVALID_TEST_DATA when the storage is not opened.
*/
int CV_StereoMatchingTest::readRunParams(FileStorage& fs)
{
    if (fs.isOpened())
    {
        caseDatasets.clear();
        caseNames.clear();
        return cvtest::TS::OK;
    }

    ts->printf(cvtest::TS::LOG, "runParams can not be read ");
    return cvtest::TS::FAIL_INVALID_TEST_DATA;
}
void CV_StereoMatchingTest::writeErrors(const string& errName, const vector<float>& errors, FileStorage* fs)
{
assert((int)errors.size() == ERROR_KINDS_COUNT);
vector<float>::const_iterator it = errors.begin();
if (fs)
for (int i = 0; i < ERROR_KINDS_COUNT; i++, ++it)
*fs << ERROR_PREFIXES[i] + errName << *it;
else
for (int i = 0; i < ERROR_KINDS_COUNT; i++, ++it)
ts->printf(cvtest::TS::LOG, "%s = %f\n", string(ERROR_PREFIXES[i] + errName).c_str(), *it);
}
/*
  Output the four ROI components (x, y, width, height) under the ROI_PREFIXES keys:
  into the file storage when `fs` is given, into the test log otherwise.
*/
void CV_StereoMatchingTest::writeROI(const Rect& calcROI, FileStorage* fs)
{
    // Same order as ROI_PREFIXES.
    const int roiFields[] = { calcROI.x, calcROI.y, calcROI.width, calcROI.height };
    const int roiFieldCount = 4;

    for (int k = 0; k < roiFieldCount; k++)
    {
        if (fs)
            *fs << ROI_PREFIXES[k] << roiFields[k];
        else
            ts->printf(cvtest::TS::LOG, "%s = %d\n", ROI_PREFIXES[k].c_str(), roiFields[k]);
    }
}
/*
  Load the stored value of every error kind from node `fn`; the key of each value is
  ERROR_PREFIXES[kind] + errName. `errors` is resized to ERROR_KINDS_COUNT.
*/
void CV_StereoMatchingTest::readErrors(FileNode& fn, const string& errName, vector<float>& errors)
{
    errors.resize(ERROR_KINDS_COUNT);
    for (int kind = 0; kind < ERROR_KINDS_COUNT; kind++)
    {
        const string key = ERROR_PREFIXES[kind] + errName;
        fn[key] >> errors[kind];
    }
}
/*
  Load the stored ROI (x, y, width, height under the ROI_PREFIXES keys) from node `fn`.
*/
void CV_StereoMatchingTest::readROI(FileNode& fn, Rect& validROI)
{
    // Destinations in the same order as ROI_PREFIXES.
    int* const roiFields[] = { &validROI.x, &validROI.y, &validROI.width, &validROI.height };
    const int roiFieldCount = 4;

    for (int k = 0; k < roiFieldCount; k++)
        fn[ROI_PREFIXES[k]] >> *roiFields[k];
}
/*
  Check the calculated errors against the stored ones. The check is one-sided: a kind fails
  only when the calculated error exceeds the stored one by more than its tolerance, so an
  improvement never fails. Every failing kind is logged.
  Returns TS::OK if all kinds pass, TS::FAIL_BAD_ACCURACY otherwise.
*/
int CV_StereoMatchingTest::compareErrors(const vector<float>& calcErrors, const vector<float>& validErrors,
    const vector<float>& eps, const string& errName)
{
    assert((int)calcErrors.size() == ERROR_KINDS_COUNT);
    assert((int)validErrors.size() == ERROR_KINDS_COUNT);
    assert((int)eps.size() == ERROR_KINDS_COUNT);

    int result = cvtest::TS::OK;
    for (int kind = 0; kind < ERROR_KINDS_COUNT; kind++)
    {
        const float calc = calcErrors[kind];
        const float valid = validErrors[kind];
        if (calc - valid > eps[kind])
        {
            ts->printf(cvtest::TS::LOG, "bad accuracy of %s (valid=%f; calc=%f)\n", string(ERROR_PREFIXES[kind] + errName).c_str(), valid, calc);
            result = cvtest::TS::FAIL_BAD_ACCURACY;
        }
    }
    return result;
}
/*
  Check that the calculated ROI is exactly equal to the stored one. Every differing component
  is logged. Returns TS::OK on equality, TS::FAIL_BAD_ACCURACY otherwise.
*/
int CV_StereoMatchingTest::compareROI(const Rect& calcROI, const Rect& validROI)
{
    // Components in the same order as ROI_PREFIXES.
    const int calcFields[] = { calcROI.x, calcROI.y, calcROI.width, calcROI.height };
    const int validFields[] = { validROI.x, validROI.y, validROI.width, validROI.height };
    const int roiFieldCount = 4;

    int result = cvtest::TS::OK;
    for (int k = 0; k < roiFieldCount; k++)
    {
        if (calcFields[k] == validFields[k])
            continue;
        ts->printf(cvtest::TS::LOG, "bad accuracy of %s (valid=%d; calc=%d)\n", ROI_PREFIXES[k].c_str(), validFields[k], calcFields[k]);
        result = cvtest::TS::FAIL_BAD_ACCURACY;
    }
    return result;
}
//----------------------------------- StereoSGM test -----------------------------------------------------
class CV_Cuda_StereoSGMTest : public CV_StereoMatchingTest
{
public:
CV_Cuda_StereoSGMTest()
{
name = "cuda_stereosgm";
fill(rmsEps.begin(), rmsEps.end(), 0.25f);
fill(fracEps.begin(), fracEps.end(), 0.01f);
}
protected:
struct RunParams
{
int ndisp;
int mode;
};
vector<RunParams> caseRunParams;
virtual int readRunParams(FileStorage& fs)
{
int code = CV_StereoMatchingTest::readRunParams(fs);
FileNode fn = fs.getFirstTopLevelNode();
assert(fn.isSeq());
for (int i = 0; i < (int)fn.size(); i += 4)
{
String caseName = fn[i], datasetName = fn[i + 1];
RunParams params;
String ndisp = fn[i + 2]; params.ndisp = atoi(ndisp.c_str());
String mode = fn[i + 3]; params.mode = atoi(mode.c_str());
caseNames.push_back(caseName);
caseDatasets.push_back(datasetName);
caseRunParams.push_back(params);
}
return code;
}
virtual int runStereoMatchingAlgorithm(const Mat& leftImg, const Mat& rightImg,
Rect& calcROI, Mat& leftDisp, Mat& /*rightDisp*/, int caseIdx)
{
RunParams params = caseRunParams[caseIdx];
assert(params.ndisp % 16 == 0);
Ptr<StereoMatcher> sgm = createStereoSGM(0, params.ndisp, 10, 120, 5, params.mode);
cv::Mat G1, G2;
cv::cvtColor(leftImg, G1, cv::COLOR_RGB2GRAY);
cv::cvtColor(rightImg, G2, cv::COLOR_RGB2GRAY);
cv::cuda::GpuMat d_leftImg, d_rightImg, d_leftDisp;
d_leftImg.upload(G1);
d_rightImg.upload(G2);
sgm->compute(d_leftImg, d_rightImg, d_leftDisp);
d_leftDisp.download(leftDisp);
CV_Assert(leftDisp.type() == CV_16SC1);
leftDisp.convertTo(leftDisp, CV_32FC1, 1.0 / StereoMatcher::DISP_SCALE);
calcROI.x = calcROI.y = 0;
calcROI.width = leftImg.cols;
calcROI.height = leftImg.rows;
return 0;
}
};
// Runs the CUDA StereoSGM regression through the legacy cvtest::BaseTest driver.
TEST(CudaStereo_StereoSGM, regression)
{
    CV_Cuda_StereoSGMTest sgmRegressionTest;
    sgmRegressionTest.safe_run();
}
}} // namespace
#endif // HAVE_CUDA

Loading…
Cancel
Save