Merge branch 4.x

pull/3772/head
Alexander Smorkalov 8 months ago
commit 1af009236a
  1. 53
      modules/cudaarithm/include/opencv2/cudaarithm.hpp
  2. 242
      modules/cudaarithm/src/cuda/polar_cart.cu
  3. 231
      modules/cudaarithm/test/test_element_operations.cpp
  4. 6
      modules/cudawarping/include/opencv2/cudawarping.hpp
  5. 88
      modules/cudawarping/src/cuda/remap.cu
  6. 9
      modules/cudawarping/src/remap.cpp
  7. 77
      modules/cudawarping/test/test_remap.cpp
  8. 50
      modules/cudev/include/opencv2/cudev/functional/functional.hpp

@ -433,6 +433,17 @@ CV_EXPORTS_W void magnitudeSqr(InputArray x, InputArray y, OutputArray magnitude
*/
CV_EXPORTS_W void phase(InputArray x, InputArray y, OutputArray angle, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Computes polar angles of complex matrix elements.
@param xy Source matrix containing real and imaginary components ( CV_32FC2 ).
@param angle Destination matrix of angles ( CV_32FC1 ).
@param angleInDegrees Flag for angles that must be evaluated in degrees.
@param stream Stream for the asynchronous version.
@sa phase
*/
CV_EXPORTS_W void phase(InputArray xy, OutputArray angle, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Converts Cartesian coordinates into polar.
@param x Source matrix containing real components ( CV_32FC1 ).
@ -446,6 +457,29 @@ CV_EXPORTS_W void phase(InputArray x, InputArray y, OutputArray angle, bool angl
*/
CV_EXPORTS_W void cartToPolar(InputArray x, InputArray y, OutputArray magnitude, OutputArray angle, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Converts Cartesian coordinates into polar.
@param xy Source matrix containing real and imaginary components ( CV_32FC2 ).
@param magnitude Destination matrix of float magnitudes ( CV_32FC1 ).
@param angle Destination matrix of angles ( CV_32FC1 ).
@param angleInDegrees Flag for angles that must be evaluated in degrees.
@param stream Stream for the asynchronous version.
@sa cartToPolar
*/
CV_EXPORTS_W void cartToPolar(InputArray xy, OutputArray magnitude, OutputArray angle, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Converts Cartesian coordinates into polar.
@param xy Source matrix containing real and imaginary components ( CV_32FC2 ).
@param magnitudeAngle Destination matrix of float magnitudes and angles ( CV_32FC2 ).
@param angleInDegrees Flag for angles that must be evaluated in degrees.
@param stream Stream for the asynchronous version.
@sa cartToPolar
*/
CV_EXPORTS_W void cartToPolar(InputArray xy, OutputArray magnitudeAngle, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Converts polar coordinates into Cartesian.
@param magnitude Source matrix containing magnitudes ( CV_32FC1 or CV_64FC1 ).
@ -457,6 +491,25 @@ CV_EXPORTS_W void cartToPolar(InputArray x, InputArray y, OutputArray magnitude,
*/
CV_EXPORTS_W void polarToCart(InputArray magnitude, InputArray angle, OutputArray x, OutputArray y, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Converts polar coordinates into Cartesian.
@param magnitude Source matrix containing magnitudes ( CV_32FC1 or CV_64FC1 ).
@param angle Source matrix containing angles ( same type as magnitude ).
@param xy Destination matrix of real and imaginary components ( same depth as magnitude, i.e. CV_32FC2 or CV_64FC2 ).
@param angleInDegrees Flag that indicates angles in degrees.
@param stream Stream for the asynchronous version.
*/
CV_EXPORTS_W void polarToCart(InputArray magnitude, InputArray angle, OutputArray xy, bool angleInDegrees = false, Stream& stream = Stream::Null());
/** @brief Converts polar coordinates into Cartesian.
@param magnitudeAngle Source matrix containing magnitudes and angles ( CV_32FC2 or CV_64FC2 ).
@param xy Destination matrix of real and imaginary components ( same depth as source ).
@param angleInDegrees Flag that indicates angles in degrees.
@param stream Stream for the asynchronous version.
*/
CV_EXPORTS_W void polarToCart(InputArray magnitudeAngle, OutputArray xy, bool angleInDegrees = false, Stream& stream = Stream::Null());
//! @} cudaarithm_elem
//! @addtogroup cudaarithm_core

@ -52,8 +52,10 @@
#include "opencv2/cudev.hpp"
#include "opencv2/core/private.cuda.hpp"
using namespace cv;
using namespace cv::cuda;
//do not use implicit cv::cuda to avoid clash of tuples from ::cuda::std
/*using namespace cv;
using namespace cv::cuda;*/
using namespace cv::cudev;
void cv::cuda::magnitude(InputArray _x, InputArray _y, OutputArray _dst, Stream& stream)
@ -66,11 +68,7 @@ void cv::cuda::magnitude(InputArray _x, InputArray _y, OutputArray _dst, Stream&
GpuMat dst = getOutputMat(_dst, x.size(), CV_32FC1, stream);
GpuMat_<float> xc(x.reshape(1));
GpuMat_<float> yc(y.reshape(1));
GpuMat_<float> magc(dst.reshape(1));
gridTransformBinary(xc, yc, magc, magnitude_func<float>(), stream);
gridTransformBinary(globPtr<float>(x), globPtr<float>(y), globPtr<float>(dst), magnitude_func<float>(), stream);
syncOutput(dst, _dst, stream);
}
@ -85,11 +83,7 @@ void cv::cuda::magnitudeSqr(InputArray _x, InputArray _y, OutputArray _dst, Stre
GpuMat dst = getOutputMat(_dst, x.size(), CV_32FC1, stream);
GpuMat_<float> xc(x.reshape(1));
GpuMat_<float> yc(y.reshape(1));
GpuMat_<float> magc(dst.reshape(1));
gridTransformBinary(xc, yc, magc, magnitude_sqr_func<float>(), stream);
gridTransformBinary(globPtr<float>(x), globPtr<float>(y), globPtr<float>(dst), magnitude_sqr_func<float>(), stream);
syncOutput(dst, _dst, stream);
}
@ -104,14 +98,26 @@ void cv::cuda::phase(InputArray _x, InputArray _y, OutputArray _dst, bool angleI
GpuMat dst = getOutputMat(_dst, x.size(), CV_32FC1, stream);
GpuMat_<float> xc(x.reshape(1));
GpuMat_<float> yc(y.reshape(1));
GpuMat_<float> anglec(dst.reshape(1));
if (angleInDegrees)
gridTransformBinary(globPtr<float>(x), globPtr<float>(y), globPtr<float>(dst), direction_func<float, true>(), stream);
else
gridTransformBinary(globPtr<float>(x), globPtr<float>(y), globPtr<float>(dst), direction_func<float, false>(), stream);
syncOutput(dst, _dst, stream);
}
void cv::cuda::phase(InputArray _xy, OutputArray _dst, bool angleInDegrees, Stream& stream)
{
GpuMat xy = getInputMat(_xy, stream);
CV_Assert( xy.type() == CV_32FC2 );
GpuMat dst = getOutputMat(_dst, xy.size(), CV_32FC1, stream);
if (angleInDegrees)
gridTransformBinary(xc, yc, anglec, direction_func<float, true>(), stream);
gridTransformUnary(globPtr<float2>(xy), globPtr<float>(dst), direction_interleaved_func<float2, true>(), stream);
else
gridTransformBinary(xc, yc, anglec, direction_func<float, false>(), stream);
gridTransformUnary(globPtr<float2>(xy), globPtr<float>(dst), direction_interleaved_func<float2, false>(), stream);
syncOutput(dst, _dst, stream);
}
@ -127,10 +133,10 @@ void cv::cuda::cartToPolar(InputArray _x, InputArray _y, OutputArray _mag, Outpu
GpuMat mag = getOutputMat(_mag, x.size(), CV_32FC1, stream);
GpuMat angle = getOutputMat(_angle, x.size(), CV_32FC1, stream);
GpuMat_<float> xc(x.reshape(1));
GpuMat_<float> yc(y.reshape(1));
GpuMat_<float> magc(mag.reshape(1));
GpuMat_<float> anglec(angle.reshape(1));
GpuMat_<float> xc(x);
GpuMat_<float> yc(y);
GpuMat_<float> magc(mag);
GpuMat_<float> anglec(angle);
if (angleInDegrees)
gridTransformBinary(xc, yc, magc, anglec, magnitude_func<float>(), direction_func<float, true>(), stream);
@ -141,6 +147,69 @@ void cv::cuda::cartToPolar(InputArray _x, InputArray _y, OutputArray _mag, Outpu
syncOutput(angle, _angle, stream);
}
void cv::cuda::cartToPolar(InputArray _xy, OutputArray _mag, OutputArray _angle, bool angleInDegrees, Stream& stream)
{
GpuMat xy = getInputMat(_xy, stream);
CV_Assert( xy.type() == CV_32FC2 );
GpuMat mag = getOutputMat(_mag, xy.size(), CV_32FC1, stream);
GpuMat angle = getOutputMat(_angle, xy.size(), CV_32FC1, stream);
GpuMat_<float> magc(mag);
GpuMat_<float> anglec(angle);
if (angleInDegrees)
{
auto f1 = magnitude_interleaved_func<float2>();
auto f2 = direction_interleaved_func<float2, true>();
cv::cudev::tuple<decltype(f1), decltype(f2)> f12 = cv::cudev::make_tuple(f1, f2);
gridTransformTuple(globPtr<float2>(xy),
tie(magc, anglec),
f12,
stream);
}
else
{
auto f1 = magnitude_interleaved_func<float2>();
auto f2 = direction_interleaved_func<float2, false>();
cv::cudev::tuple<decltype(f1), decltype(f2)> f12 = cv::cudev::make_tuple(f1, f2);
gridTransformTuple(globPtr<float2>(xy),
tie(magc, anglec),
f12,
stream);
}
syncOutput(mag, _mag, stream);
syncOutput(angle, _angle, stream);
}
void cv::cuda::cartToPolar(InputArray _xy, OutputArray _magAngle, bool angleInDegrees, Stream& stream)
{
GpuMat xy = getInputMat(_xy, stream);
CV_Assert( xy.type() == CV_32FC2 );
GpuMat magAngle = getOutputMat(_magAngle, xy.size(), CV_32FC2, stream);
if (angleInDegrees)
{
gridTransformUnary(globPtr<float2>(xy),
globPtr<float2>(magAngle),
magnitude_direction_interleaved_func<float2, true>(),
stream);
}
else
{
gridTransformUnary(globPtr<float2>(xy),
globPtr<float2>(magAngle),
magnitude_direction_interleaved_func<float2, false>(),
stream);
}
syncOutput(magAngle, _magAngle, stream);
}
namespace
{
template <typename T> struct sincos_op
@ -159,12 +228,12 @@ namespace
};
template <typename T, bool useMag>
__global__ void polarToCartImpl_(const GlobPtr<T> mag, const GlobPtr<T> angle, GlobPtr<T> xmat, GlobPtr<T> ymat, const T scale, const int rows, const int cols)
__global__ void polarToCartImpl_(const PtrStep<T> mag, const PtrStepSz<T> angle, PtrStep<T> xmat, PtrStep<T> ymat, const T scale)
{
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x >= cols || y >= rows)
if (x >= angle.cols || y >= angle.rows)
return;
const T mag_val = useMag ? mag(y, x) : static_cast<T>(1.0);
@ -178,23 +247,90 @@ namespace
ymat(y, x) = mag_val * sin_a;
}
template <typename T, bool useMag>
__global__ void polarToCartDstInterleavedImpl_(const PtrStep<T> mag, const PtrStepSz<T> angle, PtrStep<typename MakeVec<T, 2>::type > xymat, const T scale)
{
typedef typename MakeVec<T, 2>::type T2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x >= angle.cols || y >= angle.rows)
return;
const T mag_val = useMag ? mag(y, x) : static_cast<T>(1.0);
const T angle_val = angle(y, x);
T sin_a, cos_a;
sincos_op<T> op;
op(scale * angle_val, &sin_a, &cos_a);
const T2 xy = {mag_val * cos_a, mag_val * sin_a};
xymat(y, x) = xy;
}
template <typename T>
__global__ void polarToCartInterleavedImpl_(const PtrStepSz<typename MakeVec<T, 2>::type > magAngle, PtrStep<typename MakeVec<T, 2>::type > xymat, const T scale)
{
typedef typename MakeVec<T, 2>::type T2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x >= magAngle.cols || y >= magAngle.rows)
return;
const T2 magAngle_val = magAngle(y, x);
const T mag_val = magAngle_val.x;
const T angle_val = magAngle_val.y;
T sin_a, cos_a;
sincos_op<T> op;
op(scale * angle_val, &sin_a, &cos_a);
const T2 xy = {mag_val * cos_a, mag_val * sin_a};
xymat(y, x) = xy;
}
template <typename T>
void polarToCartImpl(const GpuMat& mag, const GpuMat& angle, GpuMat& x, GpuMat& y, bool angleInDegrees, cudaStream_t& stream)
{
GpuMat_<T> xc(x.reshape(1));
GpuMat_<T> yc(y.reshape(1));
GpuMat_<T> magc(mag.reshape(1));
GpuMat_<T> anglec(angle.reshape(1));
const dim3 block(32, 8);
const dim3 grid(divUp(angle.cols, block.x), divUp(angle.rows, block.y));
const T scale = angleInDegrees ? static_cast<T>(CV_PI / 180.0) : static_cast<T>(1.0);
if (mag.empty())
polarToCartImpl_<T, false> << <grid, block, 0, stream >> >(mag, angle, x, y, scale);
else
polarToCartImpl_<T, true> << <grid, block, 0, stream >> >(mag, angle, x, y, scale);
}
template <typename T>
void polarToCartDstInterleavedImpl(const GpuMat& mag, const GpuMat& angle, GpuMat& xy, bool angleInDegrees, cudaStream_t& stream)
{
typedef typename MakeVec<T, 2>::type T2;
const dim3 block(32, 8);
const dim3 grid(divUp(anglec.cols, block.x), divUp(anglec.rows, block.y));
const dim3 grid(divUp(angle.cols, block.x), divUp(angle.rows, block.y));
const T scale = angleInDegrees ? static_cast<T>(CV_PI / 180.0) : static_cast<T>(1.0);
if (magc.empty())
polarToCartImpl_<T, false> << <grid, block, 0, stream >> >(shrinkPtr(magc), shrinkPtr(anglec), shrinkPtr(xc), shrinkPtr(yc), scale, anglec.rows, anglec.cols);
if (mag.empty())
polarToCartDstInterleavedImpl_<T, false> << <grid, block, 0, stream >> >(mag, angle, xy, scale);
else
polarToCartImpl_<T, true> << <grid, block, 0, stream >> >(shrinkPtr(magc), shrinkPtr(anglec), shrinkPtr(xc), shrinkPtr(yc), scale, anglec.rows, anglec.cols);
polarToCartDstInterleavedImpl_<T, true> << <grid, block, 0, stream >> >(mag, angle, xy, scale);
}
template <typename T>
void polarToCartInterleavedImpl(const GpuMat& magAngle, GpuMat& xy, bool angleInDegrees, cudaStream_t& stream)
{
typedef typename MakeVec<T, 2>::type T2;
const dim3 block(32, 8);
const dim3 grid(divUp(magAngle.cols, block.x), divUp(magAngle.rows, block.y));
const T scale = angleInDegrees ? static_cast<T>(CV_PI / 180.0) : static_cast<T>(1.0);
polarToCartInterleavedImpl_<T> << <grid, block, 0, stream >> >(magAngle, xy, scale);
}
}
@ -223,4 +359,48 @@ void cv::cuda::polarToCart(InputArray _mag, InputArray _angle, OutputArray _x, O
CV_CUDEV_SAFE_CALL( cudaDeviceSynchronize() );
}
void cv::cuda::polarToCart(InputArray _mag, InputArray _angle, OutputArray _xy, bool angleInDegrees, Stream& _stream)
{
typedef void(*func_t)(const GpuMat& mag, const GpuMat& angle, GpuMat& xy, bool angleInDegrees, cudaStream_t& stream);
static const func_t funcs[7] = { 0, 0, 0, 0, 0, polarToCartDstInterleavedImpl<float>, polarToCartDstInterleavedImpl<double> };
GpuMat mag = getInputMat(_mag, _stream);
GpuMat angle = getInputMat(_angle, _stream);
CV_Assert(angle.depth() == CV_32F || angle.depth() == CV_64F);
CV_Assert( mag.empty() || (mag.type() == angle.type() && mag.size() == angle.size()) );
GpuMat xy = getOutputMat(_xy, angle.size(), CV_MAKETYPE(angle.depth(), 2), _stream);
cudaStream_t stream = StreamAccessor::getStream(_stream);
funcs[angle.depth()](mag, angle, xy, angleInDegrees, stream);
CV_CUDEV_SAFE_CALL( cudaGetLastError() );
syncOutput(xy, _xy, _stream);
if (stream == 0)
CV_CUDEV_SAFE_CALL( cudaDeviceSynchronize() );
}
void cv::cuda::polarToCart(InputArray _magAngle, OutputArray _xy, bool angleInDegrees, Stream& _stream)
{
typedef void(*func_t)(const GpuMat& magAngle, GpuMat& xy, bool angleInDegrees, cudaStream_t& stream);
static const func_t funcs[7] = { 0, 0, 0, 0, 0, polarToCartInterleavedImpl<float>, polarToCartInterleavedImpl<double> };
GpuMat magAngle = getInputMat(_magAngle, _stream);
CV_Assert(magAngle.type() == CV_32FC2 || magAngle.type() == CV_64FC2);
GpuMat xy = getOutputMat(_xy, magAngle.size(), magAngle.type(), _stream);
cudaStream_t stream = StreamAccessor::getStream(_stream);
funcs[magAngle.depth()](magAngle, xy, angleInDegrees, stream);
CV_CUDEV_SAFE_CALL( cudaGetLastError() );
syncOutput(xy, _xy, _stream);
if (stream == 0)
CV_CUDEV_SAFE_CALL( cudaDeviceSynchronize() );
}
#endif

@ -2765,6 +2765,47 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, Phase, testing::Combine(
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
PARAM_TEST_CASE(PhaseInterleaved, cv::cuda::DeviceInfo, cv::Size, AngleInDegrees, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
bool angleInDegrees;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
angleInDegrees = GET_PARAM(2);
useRoi = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(PhaseInterleaved, Accuracy)
{
cv::Mat x = randomMat(size, CV_32FC1);
cv::Mat y = randomMat(size, CV_32FC1);
cv::Mat xy;
std::vector<cv::Mat> xyChannels = {x, y};
cv::merge(xyChannels, xy);
cv::cuda::GpuMat dst = createMat(size, CV_32FC1, useRoi);
cv::cuda::phase(loadMat(xy, useRoi), dst, angleInDegrees);
cv::Mat dst_gold;
cv::phase(x, y, dst_gold, angleInDegrees);
EXPECT_MAT_NEAR(dst_gold, dst, angleInDegrees ? 1e-2 : 1e-3);
}
INSTANTIATE_TEST_CASE_P(CUDA_Arithm, PhaseInterleaved, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
////////////////////////////////////////////////////////////////////////////////
// CartToPolar
@ -2809,6 +2850,97 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, CartToPolar, testing::Combine(
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
PARAM_TEST_CASE(CartToPolarInterleavedXY, cv::cuda::DeviceInfo, cv::Size, AngleInDegrees, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
bool angleInDegrees;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
angleInDegrees = GET_PARAM(2);
useRoi = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(CartToPolarInterleavedXY, Accuracy)
{
cv::Mat x = randomMat(size, CV_32FC1);
cv::Mat y = randomMat(size, CV_32FC1);
cv::Mat xy;
std::vector<cv::Mat> xyChannels = {x, y};
cv::merge(xyChannels, xy);
cv::cuda::GpuMat mag = createMat(size, CV_32FC1, useRoi);
cv::cuda::GpuMat angle = createMat(size, CV_32FC1, useRoi);
cv::cuda::cartToPolar(loadMat(xy, useRoi), mag, angle, angleInDegrees);
cv::Mat mag_gold;
cv::Mat angle_gold;
cv::cartToPolar(x, y, mag_gold, angle_gold, angleInDegrees);
EXPECT_MAT_NEAR(mag_gold, mag, 1e-4);
EXPECT_MAT_NEAR(angle_gold, angle, angleInDegrees ? 1e-2 : 1e-3);
}
INSTANTIATE_TEST_CASE_P(CUDA_Arithm, CartToPolarInterleavedXY, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
PARAM_TEST_CASE(CartToPolarInterleavedXYMagAngle, cv::cuda::DeviceInfo, cv::Size, AngleInDegrees, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
bool angleInDegrees;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
angleInDegrees = GET_PARAM(2);
useRoi = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(CartToPolarInterleavedXYMagAngle, Accuracy)
{
cv::Mat x = randomMat(size, CV_32FC1);
cv::Mat y = randomMat(size, CV_32FC1);
cv::Mat xy;
std::vector<cv::Mat> xyChannels = {x, y};
cv::merge(xyChannels, xy);
cv::cuda::GpuMat magAngle = createMat(size, CV_32FC2, useRoi);
cv::cuda::cartToPolar(loadMat(xy, useRoi), magAngle, angleInDegrees);
std::vector<cv::cuda::GpuMat> magAngleChannels;
cv::cuda::split(magAngle, magAngleChannels);
cv::cuda::GpuMat& mag = magAngleChannels[0];
cv::cuda::GpuMat& angle = magAngleChannels[1];
cv::Mat mag_gold;
cv::Mat angle_gold;
cv::cartToPolar(x, y, mag_gold, angle_gold, angleInDegrees);
EXPECT_MAT_NEAR(mag_gold, mag, 1e-4);
EXPECT_MAT_NEAR(angle_gold, angle, angleInDegrees ? 1e-2 : 1e-3);
}
INSTANTIATE_TEST_CASE_P(CUDA_Arithm, CartToPolarInterleavedXYMagAngle, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
////////////////////////////////////////////////////////////////////////////////
// polarToCart
@ -2857,5 +2989,104 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, PolarToCart, testing::Combine(
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
PARAM_TEST_CASE(PolarToCartInterleaveXY, cv::cuda::DeviceInfo, cv::Size, MatType, AngleInDegrees, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
int type;
bool angleInDegrees;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
type = GET_PARAM(2);
angleInDegrees = GET_PARAM(3);
useRoi = GET_PARAM(4);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(PolarToCartInterleaveXY, Accuracy)
{
cv::Mat magnitude = randomMat(size, type);
cv::Mat angle = randomMat(size, type);
const double tol = (type == CV_32FC1 ? 1.6e-4 : 1e-4) * (angleInDegrees ? 1.0 : 19.47);
cv::cuda::GpuMat xy = createMat(size, CV_MAKETYPE(CV_MAT_DEPTH(type), 2), useRoi);
cv::cuda::polarToCart(loadMat(magnitude, useRoi), loadMat(angle, useRoi), xy, angleInDegrees);
std::vector<cv::cuda::GpuMat> xyChannels;
cv::cuda::split(xy, xyChannels);
cv::cuda::GpuMat& x = xyChannels[0];
cv::cuda::GpuMat& y = xyChannels[1];
cv::Mat x_gold;
cv::Mat y_gold;
cv::polarToCart(magnitude, angle, x_gold, y_gold, angleInDegrees);
EXPECT_MAT_NEAR(x_gold, x, tol);
EXPECT_MAT_NEAR(y_gold, y, tol);
}
INSTANTIATE_TEST_CASE_P(CUDA_Arithm, PolarToCartInterleaveXY, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(CV_32FC1, CV_64FC1),
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
PARAM_TEST_CASE(PolarToCartInterleaveMagAngleXY, cv::cuda::DeviceInfo, cv::Size, MatType, AngleInDegrees, UseRoi)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
int type;
bool angleInDegrees;
bool useRoi;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
type = GET_PARAM(2);
angleInDegrees = GET_PARAM(3);
useRoi = GET_PARAM(4);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(PolarToCartInterleaveMagAngleXY, Accuracy)
{
cv::Mat magnitude = randomMat(size, type);
cv::Mat angle = randomMat(size, type);
std::vector<cv::Mat> magAngleChannels = {magnitude, angle};
cv::Mat magAngle;
cv::merge(magAngleChannels, magAngle);
const double tol = (type == CV_32FC1 ? 1.6e-4 : 1e-4) * (angleInDegrees ? 1.0 : 19.47);
cv::cuda::GpuMat xy = createMat(size, CV_MAKETYPE(CV_MAT_DEPTH(type), 2), useRoi);
cv::cuda::polarToCart(loadMat(magAngle, useRoi), xy, angleInDegrees);
std::vector<cv::cuda::GpuMat> xyChannels;
cv::cuda::split(xy, xyChannels);
cv::cuda::GpuMat& x = xyChannels[0];
cv::cuda::GpuMat& y = xyChannels[1];
cv::Mat x_gold;
cv::Mat y_gold;
cv::polarToCart(magnitude, angle, x_gold, y_gold, angleInDegrees);
EXPECT_MAT_NEAR(x_gold, x, tol);
EXPECT_MAT_NEAR(y_gold, y, tol);
}
INSTANTIATE_TEST_CASE_P(CUDA_Arithm, PolarToCartInterleaveMagAngleXY, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(CV_32FC1, CV_64FC1),
testing::Values(AngleInDegrees(false), AngleInDegrees(true)),
WHOLE_SUBMAT));
}} // namespace
#endif // HAVE_CUDA

@ -70,6 +70,8 @@ namespace cv { namespace cuda {
@param ymap Y values. Only CV_32FC1 type is supported.
@param interpolation Interpolation method (see resize ). INTER_NEAREST , INTER_LINEAR and
INTER_CUBIC are supported for now.
The extra flag WARP_RELATIVE_MAP can be ORed to the interpolation method
(e.g. INTER_LINEAR | WARP_RELATIVE_MAP)
@param borderMode Pixel extrapolation method (see borderInterpolate ). BORDER_REFLECT101 ,
BORDER_REPLICATE , BORDER_CONSTANT , BORDER_REFLECT and BORDER_WRAP are supported for now.
@param borderValue Value used in case of a constant border. By default, it is 0.
@ -79,6 +81,10 @@ The function transforms the source image using the specified map:
\f[\texttt{dst} (x,y) = \texttt{src} (xmap(x,y), ymap(x,y))\f]
with the WARP_RELATIVE_MAP flag :
\f[\texttt{dst} (x,y) = \texttt{src} (x+map_x(x,y),y+map_y(x,y))\f]
Values of pixels with non-integer coordinates are computed using the bilinear interpolation.
@sa remap

@ -68,9 +68,23 @@ namespace cv { namespace cuda { namespace device
}
}
template <typename Ptr2D, typename T> __global__ void remap_relative(const Ptr2D src, const PtrStepf mapx, const PtrStepf mapy, PtrStepSz<T> dst)
{
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x < dst.cols && y < dst.rows)
{
const float xcoo = x+mapx.ptr(y)[x];
const float ycoo = y+mapy.ptr(y)[x];
dst.ptr(y)[x] = saturate_cast<T>(src(ycoo, xcoo));
}
}
template <template <typename> class Filter, template <typename> class B, typename T> struct RemapDispatcherStream
{
static void call(PtrStepSz<T> src, PtrStepSzf mapx, PtrStepSzf mapy, PtrStepSz<T> dst, const float* borderValue, cudaStream_t stream, bool)
static void call(PtrStepSz<T> src, PtrStepSzf mapx, PtrStepSzf mapy, PtrStepSz<T> dst, const float* borderValue, cudaStream_t stream, bool, bool isRelative)
{
typedef typename TypeVec<float, VecTraits<T>::cn>::vec_type work_type;
@ -81,14 +95,17 @@ namespace cv { namespace cuda { namespace device
BorderReader<PtrStep<T>, B<work_type>> brdSrc(src, brd);
Filter<BorderReader<PtrStep<T>, B<work_type>>> filter_src(brdSrc);
remap<<<grid, block, 0, stream>>>(filter_src, mapx, mapy, dst);
if (isRelative)
remap_relative<<<grid, block, 0, stream>>>(filter_src, mapx, mapy, dst);
else
remap<<<grid, block, 0, stream>>>(filter_src, mapx, mapy, dst);
cudaSafeCall( cudaGetLastError() );
}
};
template <template <typename> class Filter, template <typename> class B, typename T> struct RemapDispatcherNonStream
{
static void call(PtrStepSz<T> src, PtrStepSz<T> srcWhole, int xoff, int yoff, PtrStepSzf mapx, PtrStepSzf mapy, PtrStepSz<T> dst, const float* borderValue, bool)
static void call(PtrStepSz<T> src, PtrStepSz<T> srcWhole, int xoff, int yoff, PtrStepSzf mapx, PtrStepSzf mapy, PtrStepSz<T> dst, const float* borderValue, bool, bool isRelative)
{
CV_UNUSED(srcWhole);
CV_UNUSED(xoff);
@ -102,7 +119,10 @@ namespace cv { namespace cuda { namespace device
BorderReader<PtrStep<T>, B<work_type>> brdSrc(src, brd);
Filter<BorderReader<PtrStep<T>, B<work_type>>> filter_src(brdSrc);
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
if (isRelative)
remap_relative<<<grid, block>>>(filter_src, mapx, mapy, dst);
else
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
@ -112,7 +132,7 @@ namespace cv { namespace cuda { namespace device
template <template <typename> class Filter, template <typename> class B, typename T> struct RemapDispatcherNonStreamTex
{
static void call(PtrStepSz< T > src, PtrStepSz< T > srcWhole, int xoff, int yoff, PtrStepSzf mapx, PtrStepSzf mapy,
PtrStepSz< T > dst, const float* borderValue, bool cc20)
PtrStepSz< T > dst, const float* borderValue, bool cc20, bool isRelative)
{
typedef typename TypeVec<float, VecTraits< T >::cn>::vec_type work_type;
dim3 block(32, cc20 ? 8 : 4);
@ -123,7 +143,10 @@ namespace cv { namespace cuda { namespace device
B<work_type> brd(src.rows, src.cols, VecTraits<work_type>::make(borderValue));
BorderReader<cudev::TexturePtr<T>, B<work_type>> brdSrc(texSrcWhole, brd);
Filter<BorderReader<cudev::TexturePtr<T>, B<work_type>>> filter_src(brdSrc);
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
if (isRelative)
remap_relative<<<grid, block>>>(filter_src, mapx, mapy, dst);
else
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
}
else {
@ -131,7 +154,10 @@ namespace cv { namespace cuda { namespace device
B<work_type> brd(src.rows, src.cols, VecTraits<work_type>::make(borderValue));
BorderReader<cudev::TextureOffPtr<T>, B<work_type>> brdSrc(texSrcWhole, brd);
Filter<BorderReader<cudev::TextureOffPtr<T>, B<work_type>>> filter_src(brdSrc);
remap<<<grid, block >>>(filter_src, mapx, mapy, dst);
if (isRelative)
remap_relative<<<grid, block>>>(filter_src, mapx, mapy, dst);
else
remap<<<grid, block >>>(filter_src, mapx, mapy, dst);
}
cudaSafeCall( cudaGetLastError() );
@ -142,7 +168,7 @@ namespace cv { namespace cuda { namespace device
template <template <typename> class Filter, typename T> struct RemapDispatcherNonStreamTex<Filter, BrdReplicate, T>
{
static void call(PtrStepSz< T > src, PtrStepSz< T > srcWhole, int xoff, int yoff, PtrStepSzf mapx, PtrStepSzf mapy,
PtrStepSz< T > dst, const float*, bool)
PtrStepSz< T > dst, const float*, bool, bool isRelative)
{
dim3 block(32, 8);
dim3 grid(divUp(dst.cols, block.x), divUp(dst.rows, block.y));
@ -150,7 +176,10 @@ namespace cv { namespace cuda { namespace device
{
cudev::Texture<T> texSrcWhole(srcWhole);
Filter<cudev::TexturePtr<T>> filter_src(texSrcWhole);
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
if (isRelative)
remap_relative<<<grid, block>>>(filter_src, mapx, mapy, dst);
else
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
}
else
{
@ -158,7 +187,10 @@ namespace cv { namespace cuda { namespace device
BrdReplicate<T> brd(src.rows, src.cols);
BorderReader<cudev::TextureOffPtr<T>, BrdReplicate<T>> brdSrc(texSrcWhole, brd);
Filter<BorderReader<cudev::TextureOffPtr<T>, BrdReplicate<T>>> filter_src(brdSrc);
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
if (isRelative)
remap_relative<<<grid, block>>>(filter_src, mapx, mapy, dst);
else
remap<<<grid, block>>>(filter_src, mapx, mapy, dst);
}
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
@ -203,20 +235,20 @@ namespace cv { namespace cuda { namespace device
template <template <typename> class Filter, template <typename> class B, typename T> struct RemapDispatcher
{
static void call(PtrStepSz<T> src, PtrStepSz<T> srcWhole, int xoff, int yoff, PtrStepSzf mapx, PtrStepSzf mapy,
PtrStepSz<T> dst, const float* borderValue, cudaStream_t stream, bool cc20)
PtrStepSz<T> dst, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative)
{
if (stream == 0)
RemapDispatcherNonStream<Filter, B, T>::call(src, srcWhole, xoff, yoff, mapx, mapy, dst, borderValue, cc20);
RemapDispatcherNonStream<Filter, B, T>::call(src, srcWhole, xoff, yoff, mapx, mapy, dst, borderValue, cc20, isRelative);
else
RemapDispatcherStream<Filter, B, T>::call(src, mapx, mapy, dst, borderValue, stream, cc20);
RemapDispatcherStream<Filter, B, T>::call(src, mapx, mapy, dst, borderValue, stream, cc20, isRelative);
}
};
template <typename T> void remap_gpu(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap,
PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20)
PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative)
{
typedef void (*caller_t)(PtrStepSz<T> src, PtrStepSz<T> srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap,
PtrStepSz<T> dst, const float* borderValue, cudaStream_t stream, bool cc20);
PtrStepSz<T> dst, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
static const caller_t callers[3][5] =
{
@ -244,24 +276,24 @@ namespace cv { namespace cuda { namespace device
};
callers[interpolation][borderMode](static_cast<PtrStepSz<T>>(src), static_cast<PtrStepSz<T>>(srcWhole), xoff, yoff, xmap, ymap,
static_cast<PtrStepSz<T>>(dst), borderValue, stream, cc20);
static_cast<PtrStepSz<T>>(dst), borderValue, stream, cc20, isRelative);
}
template void remap_gpu<uchar >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<uchar3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<uchar4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<uchar >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<uchar3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<uchar4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<ushort >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<ushort3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<ushort4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<ushort >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<ushort3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<ushort4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<short >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<short3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<short4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<short >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<short3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<short4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<float >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<float3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<float4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
template void remap_gpu<float >(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<float3>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
template void remap_gpu<float4>(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
} // namespace imgproc
}}} // namespace cv { namespace cuda { namespace cudev

@ -54,7 +54,7 @@ namespace cv { namespace cuda { namespace device
{
template <typename T>
void remap_gpu(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst,
int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
int interpolation, int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
}
}}}
@ -62,8 +62,11 @@ void cv::cuda::remap(InputArray _src, OutputArray _dst, InputArray _xmap, InputA
{
using namespace cv::cuda::device::imgproc;
const bool hasRelativeFlag = ((interpolation & WARP_RELATIVE_MAP) != 0);
interpolation &= ~WARP_RELATIVE_MAP;
typedef void (*func_t)(PtrStepSzb src, PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzf xmap, PtrStepSzf ymap, PtrStepSzb dst, int interpolation,
int borderMode, const float* borderValue, cudaStream_t stream, bool cc20);
int borderMode, const float* borderValue, cudaStream_t stream, bool cc20, bool isRelative);
static const func_t funcs[6][4] =
{
{remap_gpu<uchar> , 0 /*remap_gpu<uchar2>*/ , remap_gpu<uchar3> , remap_gpu<uchar4> },
@ -98,7 +101,7 @@ void cv::cuda::remap(InputArray _src, OutputArray _dst, InputArray _xmap, InputA
src.locateROI(wholeSize, ofs);
func(src, PtrStepSzb(wholeSize.height, wholeSize.width, src.datastart, src.step), ofs.x, ofs.y, xmap, ymap,
dst, interpolation, borderMode, borderValueFloat.val, StreamAccessor::getStream(stream), deviceSupports(FEATURE_SET_COMPUTE_20));
dst, interpolation, borderMode, borderValueFloat.val, StreamAccessor::getStream(stream), deviceSupports(FEATURE_SET_COMPUTE_20), hasRelativeFlag);
}
#endif // HAVE_CUDA

@ -204,5 +204,82 @@ INSTANTIATE_TEST_CASE_P(CUDA_Warping, RemapOutOfScope, testing::Combine(
testing::Values(BorderType(cv::BORDER_CONSTANT)),
WHOLE_SUBMAT));
PARAM_TEST_CASE(RemapRelative, cv::cuda::DeviceInfo, MatType, Interpolation, BorderType)
{
cv::cuda::DeviceInfo devInfo;
int type;
int interpolation;
int borderType;
cv::cuda::GpuMat gSrc;
cv::cuda::GpuMat gMapRelativeX32F;
cv::cuda::GpuMat gMapRelativeY32F;
cv::cuda::GpuMat gMapAbsoluteX32F;
cv::cuda::GpuMat gMapAbsoluteY32F;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
type = GET_PARAM(1);
interpolation = GET_PARAM(2);
borderType = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
const int nChannels = CV_MAT_CN(type);
const cv::Size size(127, 61);
cv::Mat data64FC1(1, size.area()*nChannels, CV_64FC1);
data64FC1.forEach<double>([&](double& pixel, const int* position) {pixel = static_cast<double>(position[1]);});
cv::Mat src;
data64FC1.reshape(nChannels, size.height).convertTo(src, type);
cv::Mat mapRelativeX32F(size, CV_32FC1);
mapRelativeX32F.setTo(cv::Scalar::all(-0.33));
cv::Mat mapRelativeY32F(size, CV_32FC1);
mapRelativeY32F.setTo(cv::Scalar::all(-0.33));
cv::Mat mapAbsoluteX32F = mapRelativeX32F.clone();
mapAbsoluteX32F.forEach<float>([&](float& pixel, const int* position) {
pixel += static_cast<float>(position[1]);
});
cv::Mat mapAbsoluteY32F = mapRelativeY32F.clone();
mapAbsoluteY32F.forEach<float>([&](float& pixel, const int* position) {
pixel += static_cast<float>(position[0]);
});
gSrc.upload(src);
gMapRelativeX32F.upload(mapRelativeX32F);
gMapRelativeY32F.upload(mapRelativeY32F);
gMapAbsoluteX32F.upload(mapAbsoluteX32F);
gMapAbsoluteY32F.upload(mapAbsoluteY32F);
}
};
CUDA_TEST_P(RemapRelative, RemapRelative_Validity)
{
cv::cuda::GpuMat gDstAbsolute;
cv::cuda::remap(gSrc, gDstAbsolute, gMapAbsoluteX32F, gMapAbsoluteY32F, interpolation, borderType);
cv::cuda::GpuMat gDstRelative;
cv::cuda::remap(gSrc, gDstRelative, gMapRelativeX32F, gMapRelativeY32F, interpolation | WARP_RELATIVE_MAP, borderType);
cv::Mat dstAbsolute;
gDstAbsolute.download(dstAbsolute);
cv::Mat dstRelative;
gDstRelative.download(dstRelative);
EXPECT_MAT_NEAR(dstAbsolute, dstRelative, (dstAbsolute.depth() == CV_32F) ? 1e-3 : 1.0);
}
INSTANTIATE_TEST_CASE_P(CUDA_RemapRelative, RemapRelative, testing::Combine(
ALL_DEVICES,
testing::Values(MatType(CV_8UC1), MatType(CV_8UC3), MatType(CV_8UC4),
MatType(CV_16UC1), MatType(CV_16UC3), MatType(CV_16UC4),
MatType(CV_16SC1), MatType(CV_16SC3), MatType(CV_16SC4),
MatType(CV_32FC1), MatType(CV_32FC3), MatType(CV_32FC4)),
testing::Values(Interpolation(cv::INTER_NEAREST), Interpolation(cv::INTER_LINEAR), Interpolation(cv::INTER_CUBIC)),
testing::Values(BorderType(cv::BORDER_REFLECT101), BorderType(cv::BORDER_REPLICATE), BorderType(cv::BORDER_CONSTANT))));
}} // namespace
#endif // HAVE_CUDA

@ -619,6 +619,17 @@ template <typename T> struct magnitude_func : binary_function<T, T, typename fun
}
};
template <typename T2> struct magnitude_interleaved_func : unary_function<T2, typename VecTraits<T2>::elem_type>
{
typedef typename VecTraits<T2>::elem_type T;
typedef typename TypeTraits<T2>::parameter_type parameter_type;
__device__ __forceinline__ T operator ()(parameter_type ab) const
{
sqrt_func<typename functional_detail::FloatType<T>::type> f;
return f(ab.x * ab.x + ab.y * ab.y);
}
};
template <typename T> struct magnitude_sqr_func : binary_function<T, T, typename functional_detail::FloatType<T>::type>
{
__device__ __forceinline__ typename functional_detail::FloatType<T>::type operator ()(typename TypeTraits<T>::parameter_type a, typename TypeTraits<T>::parameter_type b) const
@ -627,6 +638,16 @@ template <typename T> struct magnitude_sqr_func : binary_function<T, T, typename
}
};
template <typename T2> struct magnitude_sqr_interleaved_func : unary_function<T2, typename VecTraits<T2>::elem_type>
{
typedef typename VecTraits<T2>::elem_type T;
typedef typename TypeTraits<T2>::parameter_type parameter_type;
__device__ __forceinline__ T operator ()(parameter_type ab) const
{
return ab.x * ab.x + ab.y * ab.y;
}
};
template <typename T, bool angleInDegrees> struct direction_func : binary_function<T, T, T>
{
__device__ T operator ()(T x, T y) const
@ -643,6 +664,35 @@ template <typename T, bool angleInDegrees> struct direction_func : binary_functi
}
};
template <typename T2, bool angleInDegrees> struct direction_interleaved_func : unary_function<T2, typename VecTraits<T2>::elem_type>
{
typedef typename VecTraits<T2>::elem_type T;
__device__ T operator ()(T2 xy) const
{
atan2_func<T> f;
typename atan2_func<T>::result_type angle = f(xy.y, xy.x);
angle += (angle < 0) * (2.0f * CV_PI_F);
if (angleInDegrees)
angle *= (180.0f / CV_PI_F);
return saturate_cast<T>(angle);
}
};
template <typename T2, bool angleInDegrees> struct magnitude_direction_interleaved_func : unary_function<T2, typename VecTraits<T2>::elem_type>
{
typedef typename VecTraits<T2>::elem_type T;
__device__ T2 operator ()(T2 xy) const
{
const T mag = magnitude_interleaved_func<T2>()(xy);
const T angle = direction_interleaved_func<T2, angleInDegrees>()(xy);
const T2 magAngle = {saturate_cast<T>(mag), saturate_cast<T>(angle)};
return magAngle;
}
};
template <typename T> struct pow_func : binary_function<T, float, float>
{
__device__ __forceinline__ float operator ()(T val, float power) const

Loading…
Cancel
Save