Added gpu add, subtract, multiply, divide and absdiff overloads with Scalar.

Added gpu exp, log and magnitude based on NPP.
Updated setTo with new NPP functions.
Minor fixes in tests and comments.
pull/13383/head
Vladislav Vinogradov 15 years ago
parent 037002d3c1
commit 51d5959aca
  1. modules/gpu/include/opencv2/gpu/gpu.hpp (58 changes)
  2. modules/gpu/src/arithm.cpp (158 changes)
  3. modules/gpu/src/beliefpropagation_gpu.cpp (2 changes)
  4. modules/gpu/src/constantspacebp_gpu.cpp (2 changes)
  5. modules/gpu/src/imgproc_gpu.cpp (4 changes)
  6. modules/gpu/src/matrix_operations.cpp (146 changes)
  7. tests/gpu/src/arithm.cpp (97 changes)
  8. tests/gpu/src/gputest_main.cpp (29 changes)
  9. tests/gpu/src/operator_convert_to.cpp (8 changes)
  10. tests/gpu/src/operator_set_to.cpp (70 changes)

@@ -348,48 +348,74 @@ namespace cv
////////////////////////////// Arithmetics ///////////////////////////////////
//! adds one matrix to another (c = a + b)
//! supports CV_8UC1, CV_8UC4, CV_32SC1, CV_32FC1 types
CV_EXPORTS void add(const GpuMat& a, const GpuMat& b, GpuMat& c);
//! adds scalar to a matrix (c = a + s)
//! supports only CV_32FC1 type
CV_EXPORTS void add(const GpuMat& a, const Scalar& sc, GpuMat& c);
//! subtracts one matrix from another (c = a - b)
//! supports CV_8UC1, CV_8UC4, CV_32SC1, CV_32FC1 types
CV_EXPORTS void subtract(const GpuMat& a, const GpuMat& b, GpuMat& c);
//! subtracts scalar from a matrix (c = a - s)
//! supports only CV_32FC1 type
CV_EXPORTS void subtract(const GpuMat& a, const Scalar& sc, GpuMat& c);
//! computes element-wise product of the two arrays (c = a * b)
//! supports CV_8UC1, CV_8UC4, CV_32SC1, CV_32FC1 types
CV_EXPORTS void multiply(const GpuMat& a, const GpuMat& b, GpuMat& c);
//! multiplies a matrix by a scalar (c = a * s)
//! supports only CV_32FC1 type
CV_EXPORTS void multiply(const GpuMat& a, const Scalar& sc, GpuMat& c);
//! computes element-wise quotient of the two arrays (c = a / b)
//! supports CV_8UC1, CV_8UC4, CV_32SC1, CV_32FC1 types
CV_EXPORTS void divide(const GpuMat& a, const GpuMat& b, GpuMat& c);
//! computes element-wise quotient of matrix and scalar (c = a / s)
//! supports only CV_32FC1 type
CV_EXPORTS void divide(const GpuMat& a, const Scalar& sc, GpuMat& c);
//! transposes the matrix
//! supports only CV_8UC1 type
CV_EXPORTS void transpose(const GpuMat& src1, GpuMat& dst);
//! computes element-wise absolute difference of two arrays (c = abs(a - b))
//! supports CV_8UC1, CV_8UC4, CV_32SC1, CV_32FC1 types
CV_EXPORTS void absdiff(const GpuMat& a, const GpuMat& b, GpuMat& c);
//! computes element-wise absolute difference of array and scalar (c = abs(a - s))
//! supports only CV_32FC1 type
CV_EXPORTS void absdiff(const GpuMat& a, const Scalar& s, GpuMat& c);
//! compares elements of two arrays (c = a <cmpop> b)
//! Now doesn't support CMP_NE.
//! supports CV_8UC4, CV_32FC1 types
CV_EXPORTS void compare(const GpuMat& a, const GpuMat& b, GpuMat& c, int cmpop);
//! computes mean value and standard deviation of all or selected array elements
//! supports only CV_8UC1 type
CV_EXPORTS void meanStdDev(const GpuMat& mtx, Scalar& mean, Scalar& stddev);
//! computes norm of array
//! Supports NORM_INF, NORM_L1, NORM_L2
//! supports NORM_INF, NORM_L1, NORM_L2
//! supports only CV_8UC1 type
CV_EXPORTS double norm(const GpuMat& src1, int normType=NORM_L2);
//! computes norm of the difference between two arrays
//! Supports NORM_INF, NORM_L1, NORM_L2
//! supports NORM_INF, NORM_L1, NORM_L2
//! supports only CV_8UC1 type
CV_EXPORTS double norm(const GpuMat& src1, const GpuMat& src2, int normType=NORM_L2);
//! reverses the order of the rows, columns or both in a matrix
//! supports CV_8UC1, CV_8UC4 types
CV_EXPORTS void flip(const GpuMat& a, GpuMat& b, int flipCode);
//! computes sum of array elements
//! supports CV_8UC1, CV_8UC4 types
CV_EXPORTS Scalar sum(const GpuMat& m);
//! finds global minimum and maximum array elements and returns their values
//! supports only CV_8UC1 type
CV_EXPORTS void minMax(const GpuMat& src, double* minVal, double* maxVal = 0);
//! transforms 8-bit unsigned integers using lookup table: dst(i)=lut(src(i))
//! supports only single-channel source
//! destination array will have the same type as source
//! lut must have CV_32S depth and the same number of channels as in the source array
//! destination array will have the same depth as lut and the same number of channels as source
//! supports CV_8UC1, CV_8UC3 types
CV_EXPORTS void LUT(const GpuMat& src, const Mat& lut, GpuMat& dst);
//! makes multi-channel array out of several single-channel arrays
@@ -416,10 +442,21 @@ namespace cv
//! copies each plane of a multi-channel array to a dedicated array (async version)
CV_EXPORTS void split(const GpuMat& src, vector<GpuMat>& dst, const Stream& stream);
//! computes exponent of each matrix element (b = e**a)
//! supports only CV_32FC1 type
CV_EXPORTS void exp(const GpuMat& a, GpuMat& b);
//! computes natural logarithm of absolute value of each matrix element: b = log(abs(a))
//! supports only CV_32FC1 type
CV_EXPORTS void log(const GpuMat& a, GpuMat& b);
//! computes magnitude of each (x(i), y(i)) vector: magnitude(i) = sqrt(x(i)^2 + y(i)^2)
CV_EXPORTS void magnitude(const GpuMat& x, const GpuMat& y, GpuMat& magnitude);
////////////////////////////// Image processing //////////////////////////////
//! DST[x,y] = SRC[xmap[x,y],ymap[x,y]] with bilinear interpolation.
//! xmap.type() == ymap.type() == CV_32FC1
//! supports CV_8UC1, CV_8UC3 source types and CV_32FC1 map type
CV_EXPORTS void remap(const GpuMat& src, GpuMat& dst, const GpuMat& xmap, const GpuMat& ymap);
//! Does mean shift filtering on GPU.
@@ -452,7 +489,8 @@ namespace cv
CV_EXPORTS double threshold(const GpuMat& src, GpuMat& dst, double thresh);
//! resizes the image
//! Supports INTER_NEAREST, INTER_LINEAR, INTER_CUBIC, INTER_LANCZOS4
//! Supports INTER_NEAREST, INTER_LINEAR
//! supports CV_8UC1, CV_8UC4 types
CV_EXPORTS void resize(const GpuMat& src, GpuMat& dst, Size dsize, double fx=0, double fy=0, int interpolation = INTER_LINEAR);
//! warps the image using affine transformation
@@ -465,16 +503,20 @@ namespace cv
//! rotate 8bit single or four channel image
//! Supports INTER_NEAREST, INTER_LINEAR, INTER_CUBIC
//! supports CV_8UC1, CV_8UC4 types
CV_EXPORTS void rotate(const GpuMat& src, GpuMat& dst, Size dsize, double angle, double xShift = 0, double yShift = 0, int interpolation = INTER_LINEAR);
//! copies 2D array to a larger destination array and pads borders with user-specifiable constant
//! supports CV_8UC1, CV_8UC4, CV_32SC1 types
CV_EXPORTS void copyMakeBorder(const GpuMat& src, GpuMat& dst, int top, int bottom, int left, int right, const Scalar& value = Scalar());
//! computes the integral image and integral for the squared image
//! sum will have CV_32S type, sqsum - CV_32F type
//! supports only CV_32FC1 source type
CV_EXPORTS void integral(GpuMat& src, GpuMat& sum, GpuMat& sqsum);
//! smooths the image using the normalized box filter
//! supports CV_8UC1, CV_8UC4 types and kernel size 3, 5, 7
CV_EXPORTS void boxFilter(const GpuMat& src, GpuMat& dst, Size ksize, Point anchor = Point(-1,-1));
//! a synonym for normalized box filter

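The new Scalar overloads above mirror the matrix-matrix variants but, per the comments, accept only CV_32FC1 inputs. A minimal usage sketch against this revision's API (assumes a CUDA-enabled build; sizes and values are illustrative):

    #include <opencv2/core/core.hpp>
    #include <opencv2/gpu/gpu.hpp>

    int main()
    {
        cv::Mat host = cv::Mat::ones(256, 256, CV_32FC1) * 4.0f;
        cv::gpu::GpuMat a(host), c;               // constructor uploads to the device
        cv::gpu::add(a, cv::Scalar(1.0), c);      // c = a + 1
        cv::gpu::multiply(c, cv::Scalar(0.5), c); // c = c * 0.5, in place
        cv::Mat result;
        c.download(result);                       // every element is now 2.5
        return 0;
    }
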
@@ -49,11 +49,16 @@ using namespace std;
#if !defined (HAVE_CUDA)
void cv::gpu::add(const GpuMat&, const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::add(const GpuMat&, const Scalar&, GpuMat&) { throw_nogpu(); }
void cv::gpu::subtract(const GpuMat&, const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::subtract(const GpuMat&, const Scalar&, GpuMat&) { throw_nogpu(); }
void cv::gpu::multiply(const GpuMat&, const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::multiply(const GpuMat&, const Scalar&, GpuMat&) { throw_nogpu(); }
void cv::gpu::divide(const GpuMat&, const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::divide(const GpuMat&, const Scalar&, GpuMat&) { throw_nogpu(); }
void cv::gpu::transpose(const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::absdiff(const GpuMat&, const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::absdiff(const GpuMat&, const Scalar&, GpuMat&) { throw_nogpu(); }
void cv::gpu::compare(const GpuMat&, const GpuMat&, GpuMat&, int) { throw_nogpu(); }
void cv::gpu::meanStdDev(const GpuMat&, Scalar&, Scalar&) { throw_nogpu(); }
double cv::gpu::norm(const GpuMat&, int) { throw_nogpu(); return 0.0; }
@@ -61,7 +66,10 @@ double cv::gpu::norm(const GpuMat&, const GpuMat&, int) { throw_nogpu(); return
void cv::gpu::flip(const GpuMat&, GpuMat&, int) { throw_nogpu(); }
Scalar cv::gpu::sum(const GpuMat&) { throw_nogpu(); return Scalar(); }
void cv::gpu::minMax(const GpuMat&, double*, double*) { throw_nogpu(); }
void cv::gpu::LUT(const GpuMat& src, const Mat& lut, GpuMat& dst) { throw_nogpu(); }
void cv::gpu::LUT(const GpuMat&, const Mat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::exp(const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::log(const GpuMat&, GpuMat&) { throw_nogpu(); }
void cv::gpu::magnitude(const GpuMat&, const GpuMat&, GpuMat&) { throw_nogpu(); }
#else /* !defined (HAVE_CUDA) */
@@ -72,15 +80,18 @@ namespace
{
typedef NppStatus (*npp_arithm_8u_t)(const Npp8u* pSrc1, int nSrc1Step, const Npp8u* pSrc2, int nSrc2Step, Npp8u* pDst, int nDstStep,
NppiSize oSizeROI, int nScaleFactor);
typedef NppStatus (*npp_arithm_32s_t)(const Npp32s* pSrc1, int nSrc1Step, const Npp32s* pSrc2, int nSrc2Step, Npp32s* pDst,
int nDstStep, NppiSize oSizeROI);
typedef NppStatus (*npp_arithm_32f_t)(const Npp32f* pSrc1, int nSrc1Step, const Npp32f* pSrc2, int nSrc2Step, Npp32f* pDst,
int nDstStep, NppiSize oSizeROI);
void nppFuncCaller(const GpuMat& src1, const GpuMat& src2, GpuMat& dst,
npp_arithm_8u_t npp_func_8uc1, npp_arithm_8u_t npp_func_8uc4, npp_arithm_32f_t npp_func_32fc1)
void nppArithmCaller(const GpuMat& src1, const GpuMat& src2, GpuMat& dst,
npp_arithm_8u_t npp_func_8uc1, npp_arithm_8u_t npp_func_8uc4,
npp_arithm_32s_t npp_func_32sc1, npp_arithm_32f_t npp_func_32fc1)
{
CV_DbgAssert(src1.size() == src2.size() && src1.type() == src2.type());
CV_Assert(src1.type() == CV_8UC1 || src1.type() == CV_8UC4 || src1.type() == CV_32FC1);
CV_Assert(src1.type() == CV_8UC1 || src1.type() == CV_8UC4 || src1.type() == CV_32SC1 || src1.type() == CV_32FC1);
dst.create( src1.size(), src1.type() );
@@ -100,6 +111,11 @@ namespace
src2.ptr<Npp8u>(), src2.step,
dst.ptr<Npp8u>(), dst.step, sz, 0) );
break;
case CV_32SC1:
nppSafeCall( npp_func_32sc1(src1.ptr<Npp32s>(), src1.step,
src2.ptr<Npp32s>(), src2.step,
dst.ptr<Npp32s>(), dst.step, sz) );
break;
case CV_32FC1:
nppSafeCall( npp_func_32fc1(src1.ptr<Npp32f>(), src1.step,
src2.ptr<Npp32f>(), src2.step,
@@ -109,26 +125,63 @@ namespace
CV_Assert(!"Unsupported source type");
}
}
typedef NppStatus (*npp_arithm_scalar_32f_t)(const Npp32f *pSrc, int nSrcStep, Npp32f nValue, Npp32f *pDst,
int nDstStep, NppiSize oSizeROI);
void nppArithmCaller(const GpuMat& src1, const Scalar& sc, GpuMat& dst,
npp_arithm_scalar_32f_t npp_func)
{
CV_Assert(src1.type() == CV_32FC1);
dst.create(src1.size(), src1.type());
NppiSize sz;
sz.width = src1.cols;
sz.height = src1.rows;
nppSafeCall( npp_func(src1.ptr<Npp32f>(), src1.step, (Npp32f)sc[0], dst.ptr<Npp32f>(), dst.step, sz) );
}
}
void cv::gpu::add(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
{
nppFuncCaller(src1, src2, dst, nppiAdd_8u_C1RSfs, nppiAdd_8u_C4RSfs, nppiAdd_32f_C1R);
nppArithmCaller(src1, src2, dst, nppiAdd_8u_C1RSfs, nppiAdd_8u_C4RSfs, nppiAdd_32s_C1R, nppiAdd_32f_C1R);
}
void cv::gpu::subtract(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
{
nppFuncCaller(src2, src1, dst, nppiSub_8u_C1RSfs, nppiSub_8u_C4RSfs, nppiSub_32f_C1R);
nppArithmCaller(src2, src1, dst, nppiSub_8u_C1RSfs, nppiSub_8u_C4RSfs, nppiSub_32s_C1R, nppiSub_32f_C1R);
}
void cv::gpu::multiply(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
{
nppFuncCaller(src1, src2, dst, nppiMul_8u_C1RSfs, nppiMul_8u_C4RSfs, nppiMul_32f_C1R);
nppArithmCaller(src1, src2, dst, nppiMul_8u_C1RSfs, nppiMul_8u_C4RSfs, nppiMul_32s_C1R, nppiMul_32f_C1R);
}
void cv::gpu::divide(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
{
nppFuncCaller(src2, src1, dst, nppiDiv_8u_C1RSfs, nppiDiv_8u_C4RSfs, nppiDiv_32f_C1R);
nppArithmCaller(src2, src1, dst, nppiDiv_8u_C1RSfs, nppiDiv_8u_C4RSfs, nppiDiv_32s_C1R, nppiDiv_32f_C1R);
}
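Note the (src2, src1) argument order for subtract and divide: NPP's Sub and Div primitives take their operands in the opposite order from OpenCV, so the swap yields the conventional dst = src1 - src2 and dst = src1 / src2. A hedged reading of the convention, worth verifying against the NPP headers in use:

    // Assumed NPP operand order for this era of the library:
    //   nppiSub_32f_C1R(pSrc1, s1, pSrc2, s2, pDst, d, roi)  =>  pDst = pSrc2 - pSrc1
    // hence nppArithmCaller(src2, src1, dst, ...)            =>  dst  = src1 - src2
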
void cv::gpu::add(const GpuMat& src, const Scalar& sc, GpuMat& dst)
{
nppArithmCaller(src, sc, dst, nppiAddC_32f_C1R);
}
void cv::gpu::subtract(const GpuMat& src, const Scalar& sc, GpuMat& dst)
{
nppArithmCaller(src, sc, dst, nppiSubC_32f_C1R);
}
void cv::gpu::multiply(const GpuMat& src, const Scalar& sc, GpuMat& dst)
{
nppArithmCaller(src, sc, dst, nppiMulC_32f_C1R);
}
void cv::gpu::divide(const GpuMat& src, const Scalar& sc, GpuMat& dst)
{
nppArithmCaller(src, sc, dst, nppiDivC_32f_C1R);
}
////////////////////////////////////////////////////////////////////////
@@ -154,7 +207,7 @@ void cv::gpu::absdiff(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
{
CV_DbgAssert(src1.size() == src2.size() && src1.type() == src2.type());
CV_Assert(src1.type() == CV_8UC1 || src1.type() == CV_32FC1);
CV_Assert(src1.type() == CV_8UC1 || src1.type() == CV_8UC4 || src1.type() == CV_32SC1 || src1.type() == CV_32FC1);
dst.create( src1.size(), src1.type() );
@@ -162,20 +215,46 @@ void cv::gpu::absdiff(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
sz.width = src1.cols;
sz.height = src1.rows;
if (src1.type() == CV_8UC1)
switch (src1.type())
{
case CV_8UC1:
nppSafeCall( nppiAbsDiff_8u_C1R(src1.ptr<Npp8u>(), src1.step,
src2.ptr<Npp8u>(), src2.step,
dst.ptr<Npp8u>(), dst.step, sz) );
}
else
{
break;
case CV_8UC4:
nppSafeCall( nppiAbsDiff_8u_C4R(src1.ptr<Npp8u>(), src1.step,
src2.ptr<Npp8u>(), src2.step,
dst.ptr<Npp8u>(), dst.step, sz) );
break;
case CV_32SC1:
nppSafeCall( nppiAbsDiff_32s_C1R(src1.ptr<Npp32s>(), src1.step,
src2.ptr<Npp32s>(), src2.step,
dst.ptr<Npp32s>(), dst.step, sz) );
break;
case CV_32FC1:
nppSafeCall( nppiAbsDiff_32f_C1R(src1.ptr<Npp32f>(), src1.step,
src2.ptr<Npp32f>(), src2.step,
dst.ptr<Npp32f>(), dst.step, sz) );
break;
default:
CV_Assert(!"Unsupported source type");
}
}
void cv::gpu::absdiff(const GpuMat& src, const Scalar& s, GpuMat& dst)
{
CV_Assert(src.type() == CV_32FC1);
dst.create( src.size(), src.type() );
NppiSize sz;
sz.width = src.cols;
sz.height = src.rows;
nppSafeCall( nppiAbsDiffC_32f_C1R(src.ptr<Npp32f>(), src.step, dst.ptr<Npp32f>(), dst.step, sz, (Npp32f)s[0]) );
}
////////////////////////////////////////////////////////////////////////
// compare
@@ -416,4 +495,57 @@ void cv::gpu::LUT(const GpuMat& src, const Mat& lut, GpuMat& dst)
}
}
////////////////////////////////////////////////////////////////////////
// exp
void cv::gpu::exp(const GpuMat& src, GpuMat& dst)
{
CV_Assert(src.type() == CV_32FC1);
dst.create(src.size(), src.type());
NppiSize sz;
sz.width = src.cols;
sz.height = src.rows;
nppSafeCall( nppiExp_32f_C1R(src.ptr<Npp32f>(), src.step, dst.ptr<Npp32f>(), dst.step, sz) );
}
////////////////////////////////////////////////////////////////////////
// log
void cv::gpu::log(const GpuMat& src, GpuMat& dst)
{
CV_Assert(src.type() == CV_32FC1);
dst.create(src.size(), src.type());
NppiSize sz;
sz.width = src.cols;
sz.height = src.rows;
nppSafeCall( nppiLn_32f_C1R(src.ptr<Npp32f>(), src.step, dst.ptr<Npp32f>(), dst.step, sz) );
}
////////////////////////////////////////////////////////////////////////
// magnitude
void cv::gpu::magnitude(const GpuMat& src1, const GpuMat& src2, GpuMat& dst)
{
CV_DbgAssert(src1.type() == src2.type() && src1.size() == src2.size());
CV_Assert(src1.type() == CV_32FC1);
GpuMat src(src1.size(), CV_32FC2);
GpuMat srcs[] = {src1, src2};
cv::gpu::merge(srcs, 2, src);
dst.create(src1.size(), src1.type());
NppiSize sz;
sz.width = src.cols;
sz.height = src.rows;
nppSafeCall( nppiMagnitude_32fc32f_C1R(src.ptr<Npp32fc>(), src.step, dst.ptr<Npp32f>(), dst.step, sz) );
}
#endif /* !defined (HAVE_CUDA) */
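
For reference, a short host-side sketch exercising the three new NPP-backed functions (CV_32FC1 only, per the asserts; values are illustrative):

    cv::Mat hx = cv::Mat::ones(64, 64, CV_32FC1) * 3.0f;
    cv::Mat hy = cv::Mat::ones(64, 64, CV_32FC1) * 4.0f;
    cv::gpu::GpuMat x(hx), y(hy), mag, lg, ex;
    cv::gpu::magnitude(x, y, mag); // sqrt(3*3 + 4*4) = 5 everywhere
    cv::gpu::log(mag, lg);         // log(5) ~= 1.609
    cv::gpu::exp(lg, ex);          // back to ~5, up to float precision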

@@ -89,7 +89,7 @@ void cv::gpu::StereoBeliefPropagation::estimateRecommendedParams(int width, int
int mm = ::max(width, height);
iters = mm / 100 + 2;
levels = (int)(log(static_cast<double>(mm)) + 1) * 4 / 5;
levels = (int)(::log(static_cast<double>(mm)) + 1) * 4 / 5;
if (levels == 0) levels++;
}

@@ -116,7 +116,7 @@ void cv::gpu::StereoConstantSpaceBP::estimateRecommendedParams(int width, int he
int mm = ::max(width, height);
iters = mm / 100 + ((mm > 1200)? - 4 : 4);
levels = (int)log(static_cast<double>(mm)) * 2 / 3;
levels = (int)::log(static_cast<double>(mm)) * 2 / 3;
if (levels == 0) levels++;
nr_plane = (int) ((float) ndisp / pow(2.0, levels + 1));
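As a worked example (illustrative numbers): for a 640 x 480 pair, mm = 640 and ::log(640.0) is about 6.46, so levels = 6 * 2 / 3 = 4 in integer arithmetic; with ndisp = 128 that gives nr_plane = (int)(128.0f / 2^5) = 4.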

@@ -592,10 +592,10 @@ double cv::gpu::threshold(const GpuMat& src, GpuMat& dst, double thresh)
void cv::gpu::resize(const GpuMat& src, GpuMat& dst, Size dsize, double fx, double fy, int interpolation)
{
static const int npp_inter[] = {NPPI_INTER_NN, NPPI_INTER_LINEAR, NPPI_INTER_CUBIC, 0, NPPI_INTER_LANCZOS};
static const int npp_inter[] = {NPPI_INTER_NN, NPPI_INTER_LINEAR/*, NPPI_INTER_CUBIC, 0, NPPI_INTER_LANCZOS*/};
CV_Assert(src.type() == CV_8UC1 || src.type() == CV_8UC4);
CV_Assert(interpolation == INTER_NEAREST || interpolation == INTER_LINEAR || interpolation == INTER_CUBIC || interpolation == INTER_LANCZOS4);
CV_Assert(interpolation == INTER_NEAREST || interpolation == INTER_LINEAR/* || interpolation == INTER_CUBIC || interpolation == INTER_LANCZOS4*/);
CV_Assert( src.size().area() > 0 );
CV_Assert( !(dsize == Size()) || (fx > 0 && fy > 0) );

@@ -151,54 +151,94 @@ void cv::gpu::GpuMat::convertTo( GpuMat& dst, int rtype, double alpha, double be
GpuMat& GpuMat::operator = (const Scalar& s)
{
matrix_operations::set_to_without_mask( *this, depth(), s.val, channels());
setTo(s);
return *this;
}
GpuMat& GpuMat::setTo(const Scalar& s, const GpuMat& mask)
{
//CV_Assert(mask.type() == CV_8U);
CV_Assert(mask.type() == CV_8UC1);
CV_DbgAssert(!this->empty());
NppiSize sz;
sz.width = cols;
sz.height = rows;
if (mask.empty())
{
switch (type())
{
case CV_8UC1:
{
NppiSize sz;
sz.width = cols;
sz.height = rows;
Npp8u nVal = (Npp8u)s[0];
nppSafeCall( nppiSet_8u_C1R(nVal, (Npp8u*)ptr<char>(), step, sz) );
nppSafeCall( nppiSet_8u_C1R(nVal, ptr<Npp8u>(), step, sz) );
break;
}
case CV_8UC4:
{
NppiSize sz;
sz.width = cols;
sz.height = rows;
Npp8u nVal[] = {(Npp8u)s[0], (Npp8u)s[1], (Npp8u)s[2], (Npp8u)s[3]};
nppSafeCall( nppiSet_8u_C4R(nVal, (Npp8u*)ptr<char>(), step, sz) );
Scalar_<Npp8u> nVal = s;
nppSafeCall( nppiSet_8u_C4R(nVal.val, ptr<Npp8u>(), step, sz) );
break;
}
case CV_16UC1:
{
Npp16u nVal = (Npp16u)s[0];
nppSafeCall( nppiSet_16u_C1R(nVal, ptr<Npp16u>(), step, sz) );
break;
}
/*case CV_16UC2:
{
Scalar_<Npp16u> nVal = s;
nppSafeCall( nppiSet_16u_C2R(nVal.val, ptr<Npp16u>(), step, sz) );
break;
}*/
case CV_16UC4:
{
Scalar_<Npp16u> nVal = s;
nppSafeCall( nppiSet_16u_C4R(nVal.val, ptr<Npp16u>(), step, sz) );
break;
}
case CV_16SC1:
{
Npp16s nVal = (Npp16s)s[0];
nppSafeCall( nppiSet_16s_C1R(nVal, ptr<Npp16s>(), step, sz) );
break;
}
/*case CV_16SC2:
{
Scalar_<Npp16s> nVal = s;
nppSafeCall( nppiSet_16s_C2R(nVal.val, ptr<Npp16s>(), step, sz) );
break;
}*/
case CV_16SC4:
{
Scalar_<Npp16s> nVal = s;
nppSafeCall( nppiSet_16s_C4R(nVal.val, ptr<Npp16s>(), step, sz) );
break;
}
case CV_32SC1:
{
NppiSize sz;
sz.width = cols;
sz.height = rows;
Npp32s nVal = (Npp32s)s[0];
nppSafeCall( nppiSet_32s_C1R(nVal, (Npp32s*)ptr<char>(), step, sz) );
nppSafeCall( nppiSet_32s_C1R(nVal, ptr<Npp32s>(), step, sz) );
break;
}
case CV_32SC4:
{
Scalar_<Npp32s> nVal = s;
nppSafeCall( nppiSet_32s_C4R(nVal.val, ptr<Npp32s>(), step, sz) );
break;
}
case CV_32FC1:
{
NppiSize sz;
sz.width = cols;
sz.height = rows;
Npp32f nVal = (Npp32f)s[0];
nppSafeCall( nppiSet_32f_C1R(nVal, (Npp32f*)ptr<char>(), step, sz) );
nppSafeCall( nppiSet_32f_C1R(nVal, ptr<Npp32f>(), step, sz) );
break;
}
case CV_32FC4:
{
Scalar_<Npp32f> nVal = s;
nppSafeCall( nppiSet_32f_C4R(nVal.val, ptr<Npp32f>(), step, sz) );
break;
}
default:
@@ -206,7 +246,73 @@ GpuMat& GpuMat::setTo(const Scalar& s, const GpuMat& mask)
}
}
else
matrix_operations::set_to_with_mask( *this, depth(), s.val, mask, channels());
{
switch (type())
{
case CV_8UC1:
{
Npp8u nVal = (Npp8u)s[0];
nppSafeCall( nppiSet_8u_C1MR(nVal, ptr<Npp8u>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_8UC4:
{
Scalar_<Npp8u> nVal = s;
nppSafeCall( nppiSet_8u_C4MR(nVal.val, ptr<Npp8u>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_16UC1:
{
Npp16u nVal = (Npp16u)s[0];
nppSafeCall( nppiSet_16u_C1MR(nVal, ptr<Npp16u>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_16UC4:
{
Scalar_<Npp16u> nVal = s;
nppSafeCall( nppiSet_16u_C4MR(nVal.val, ptr<Npp16u>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_16SC1:
{
Npp16s nVal = (Npp16s)s[0];
nppSafeCall( nppiSet_16s_C1MR(nVal, ptr<Npp16s>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_16SC4:
{
Scalar_<Npp16s> nVal = s;
nppSafeCall( nppiSet_16s_C4MR(nVal.val, ptr<Npp16s>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_32SC1:
{
Npp32s nVal = (Npp32s)s[0];
nppSafeCall( nppiSet_32s_C1MR(nVal, ptr<Npp32s>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_32SC4:
{
Scalar_<Npp32s> nVal = s;
nppSafeCall( nppiSet_32s_C4MR(nVal.val, ptr<Npp32s>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_32FC1:
{
Npp32f nVal = (Npp32f)s[0];
nppSafeCall( nppiSet_32f_C1MR(nVal, ptr<Npp32f>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
case CV_32FC4:
{
Scalar_<Npp32f> nVal = s;
nppSafeCall( nppiSet_32f_C4MR(nVal.val, ptr<Npp32f>(), step, sz, mask.ptr<Npp8u>(), mask.step) );
break;
}
default:
matrix_operations::set_to_with_mask( *this, depth(), s.val, mask, channels());
}
}
return *this;
}
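With the NPP paths above, unmasked and masked fills for the listed types now run entirely through nppiSet; other types still fall back to matrix_operations::set_to_with_mask. A minimal sketch of the masked path (mask must be CV_8UC1, per the assert at the top):

    cv::Mat hostMask = cv::Mat::zeros(32, 32, CV_8UC1);
    hostMask(cv::Rect(0, 0, 16, 32)) = cv::Scalar(1);  // mark the left half
    cv::Mat hostImg = cv::Mat::zeros(32, 32, CV_32FC1);
    cv::gpu::GpuMat img(hostImg), mask(hostMask);
    img.setTo(cv::Scalar(3.5), mask);                  // nppiSet_32f_C1MR path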

@@ -74,8 +74,8 @@ int CV_GpuArithmTest::test(int type)
cv::Size sz(200, 200);
cv::Mat mat1(sz, type), mat2(sz, type);
cv::RNG rng(*ts->get_rng());
rng.fill(mat1, cv::RNG::UNIFORM, cv::Scalar::all(10), cv::Scalar::all(100));
rng.fill(mat2, cv::RNG::UNIFORM, cv::Scalar::all(10), cv::Scalar::all(100));
rng.fill(mat1, cv::RNG::UNIFORM, cv::Scalar::all(1), cv::Scalar::all(20));
rng.fill(mat2, cv::RNG::UNIFORM, cv::Scalar::all(1), cv::Scalar::all(20));
return test(mat1, mat2);
}
@@ -114,8 +114,8 @@ void CV_GpuArithmTest::run( int )
int testResult = CvTS::OK;
try
{
const int types[] = {CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1};
const char* type_names[] = {"CV_8UC1", "CV_8UC3", "CV_8UC4", "CV_32FC1"};
const int types[] = {CV_8UC1, CV_8UC3, CV_8UC4, CV_32SC1, CV_32FC1};
const char* type_names[] = {"CV_8UC1", "CV_8UC3", "CV_8UC4", "CV_32SC1", "CV_32FC1"};
const int type_count = sizeof(types)/sizeof(types[0]);
//run tests
@@ -151,7 +151,7 @@ struct CV_GpuNppImageAddTest : public CV_GpuArithmTest
virtual int test(const Mat& mat1, const Mat& mat2)
{
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32FC1)
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32SC1 && mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
@@ -177,7 +177,7 @@ struct CV_GpuNppImageSubtractTest : public CV_GpuArithmTest
int test( const Mat& mat1, const Mat& mat2 )
{
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32FC1)
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32SC1 && mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
@@ -203,7 +203,7 @@ struct CV_GpuNppImageMultiplyTest : public CV_GpuArithmTest
int test( const Mat& mat1, const Mat& mat2 )
{
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32FC1)
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32SC1 && mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
@@ -229,7 +229,7 @@ struct CV_GpuNppImageDivideTest : public CV_GpuArithmTest
int test( const Mat& mat1, const Mat& mat2 )
{
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32FC1)
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32SC1 && mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
@@ -280,7 +280,7 @@ struct CV_GpuNppImageAbsdiffTest : public CV_GpuArithmTest
int test( const Mat& mat1, const Mat& mat2 )
{
if (mat1.type() != CV_8UC1 && mat1.type() != CV_32FC1)
if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4 && mat1.type() != CV_32SC1 && mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
@@ -532,6 +532,82 @@ struct CV_GpuNppImageLUTTest : public CV_GpuArithmTest
}
};
////////////////////////////////////////////////////////////////////////////////
// exp
struct CV_GpuNppImageExpTest : public CV_GpuArithmTest
{
CV_GpuNppImageExpTest() : CV_GpuArithmTest( "GPU-NppImageExp", "exp" ) {}
int test( const Mat& mat1, const Mat& )
{
if (mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
}
cv::Mat cpuRes;
cv::exp(mat1, cpuRes);
GpuMat gpu1(mat1);
GpuMat gpuRes;
cv::gpu::exp(gpu1, gpuRes);
return CheckNorm(cpuRes, gpuRes);
}
};
////////////////////////////////////////////////////////////////////////////////
// log
struct CV_GpuNppImageLogTest : public CV_GpuArithmTest
{
CV_GpuNppImageLogTest() : CV_GpuArithmTest( "GPU-NppImageLog", "log" ) {}
int test( const Mat& mat1, const Mat& )
{
if (mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
}
cv::Mat cpuRes;
cv::log(mat1, cpuRes);
GpuMat gpu1(mat1);
GpuMat gpuRes;
cv::gpu::log(gpu1, gpuRes);
return CheckNorm(cpuRes, gpuRes);
}
};
////////////////////////////////////////////////////////////////////////////////
// magnitude
struct CV_GpuNppImageMagnitudeTest : public CV_GpuArithmTest
{
CV_GpuNppImageMagnitudeTest() : CV_GpuArithmTest( "GPU-NppImageMagnitude", "magnitude" ) {}
int test( const Mat& mat1, const Mat& mat2 )
{
if (mat1.type() != CV_32FC1)
{
ts->printf(CvTS::LOG, "\nUnsupported type\n");
return CvTS::OK;
}
cv::Mat cpuRes;
cv::magnitude(mat1, mat2, cpuRes);
GpuMat gpu1(mat1);
GpuMat gpu2(mat2);
GpuMat gpuRes;
cv::gpu::magnitude(gpu1, gpu2, gpuRes);
return CheckNorm(cpuRes, gpuRes);
}
};
/////////////////////////////////////////////////////////////////////////////
/////////////////// tests registration /////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
@@ -553,3 +629,6 @@ CV_GpuNppImageFlipTest CV_GpuNppImageFlip_test;
CV_GpuNppImageSumTest CV_GpuNppImageSum_test;
CV_GpuNppImageMinNaxTest CV_GpuNppImageMinNax_test;
CV_GpuNppImageLUTTest CV_GpuNppImageLUT_test;
CV_GpuNppImageExpTest CV_GpuNppImageExp_test;
CV_GpuNppImageLogTest CV_GpuNppImageLog_test;
CV_GpuNppImageMagnitudeTest CV_GpuNppImageMagnitude_test;

@@ -45,19 +45,22 @@ CvTS test_system;
const char* blacklist[] =
{
"GPU-NppImageSum",
"GPU-MatOperatorAsyncCall",
//"GPU-NppErode",
//"GPU-NppDilate",
//"GPU-NppMorphologyEx",
//"GPU-NppImageDivide",
//"GPU-NppImageMeanStdDev",
//"GPU-NppImageMinNax",
//"GPU-NppImageResize",
//"GPU-NppImageWarpAffine",
//"GPU-NppImageWarpPerspective",
//"GPU-NppImageIntegral",
//"GPU-NppImageBlur",
"GPU-NppImageSum", // crash
"GPU-MatOperatorAsyncCall", // crash
//"GPU-NppErode", // npp func returns error code (CUDA_KERNEL_LAUNCH_ERROR or TEXTURE_BIND_ERROR)
//"GPU-NppDilate", // npp func returns error code (CUDA_KERNEL_LAUNCH_ERROR or TEXTURE_BIND_ERROR)
//"GPU-NppMorphologyEx", // npp func returns error code (CUDA_KERNEL_LAUNCH_ERROR or TEXTURE_BIND_ERROR)
//"GPU-NppImageDivide", // different round mode
//"GPU-NppImageMeanStdDev", // different precision
//"GPU-NppImageMinNax", // npp bug
//"GPU-NppImageResize", // different precision in interpolation
//"GPU-NppImageWarpAffine", // different precision in interpolation
//"GPU-NppImageWarpPerspective", // different precision in interpolation
//"GPU-NppImageIntegral", // different precision
//"GPU-NppImageBlur", // different precision
//"GPU-NppImageExp", // different precision
//"GPU-NppImageLog", // different precision
//"GPU-NppImageMagnitude", // different precision
0
};

@@ -68,7 +68,6 @@ void CV_GpuMatOpConvertToTest::run(int /* start_from */)
const int types[] = {CV_8U, CV_8S, CV_16U, CV_16S, CV_32S, CV_32F, CV_64F};
const int types_num = sizeof(types) / sizeof(int);
const char* types_str[] = {"CV_8U", "CV_8S", "CV_16U", "CV_16S", "CV_32S", "CV_32F", "CV_64F"};
bool passed = true;
@@ -78,17 +77,16 @@ void CV_GpuMatOpConvertToTest::run(int /* start_from */)
{
for (int j = 0; j < types_num && passed; ++j)
{
for (int c = 1; c < 2 && passed; ++c)
for (int c = 1; c < 5 && passed; ++c)
{
const int src_type = CV_MAKETYPE(types[i], c);
const int dst_type = types[j];
const double alpha = (double)rand() / RAND_MAX * 2.0;
const double beta = (double)rand() / RAND_MAX * 150.0 - 75;
cv::RNG rng(*ts->get_rng());
const double alpha = rng.uniform(0.0, 2.0);
const double beta = rng.uniform(-75.0, 75.0);
Mat cpumatsrc(img_size, src_type);
rng.fill(cpumatsrc, RNG::UNIFORM, Scalar::all(0), Scalar::all(300));
GpuMat gpumatsrc(cpumatsrc);

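The GPU call under test keeps the CPU convertTo semantics; a condensed sketch of what each iteration exercises (type and factors are illustrative):

    cv::Mat src(64, 64, CV_8UC3), gold;
    cv::randu(src, cv::Scalar::all(0), cv::Scalar::all(255));
    cv::gpu::GpuMat gsrc(src), gdst;
    src.convertTo(gold, CV_32F, 1.5, -10.0);   // CPU reference
    gsrc.convertTo(gdst, CV_32F, 1.5, -10.0);  // device result, compared against gold
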
@@ -40,15 +40,7 @@
//M*/
#include "gputest.hpp"
#include "highgui.h"
#include <string>
#include <iostream>
#include <fstream>
#include <iterator>
#include <limits>
#include <numeric>
#include <iomanip> // for cout << setw()
using namespace cv;
using namespace std;
@@ -62,9 +54,8 @@ public:
protected:
void run(int);
void print_mat(cv::Mat & mat, std::string name = "cpu mat");
void print_mat(gpu::GpuMat & mat, std::string name = "gpu mat");
bool compare_matrix(cv::Mat & cpumat, gpu::GpuMat & gpumat);
bool testSetTo(cv::Mat& cpumat, gpu::GpuMat& gpumat, const cv::Mat& cpumask = cv::Mat(), const cv::gpu::GpuMat& gpumask = cv::gpu::GpuMat());
private:
int rows;
@@ -74,51 +65,23 @@ private:
CV_GpuMatOpSetToTest::CV_GpuMatOpSetToTest(): CvTest( "GPU-MatOperatorSetTo", "setTo" )
{
rows = 256;
cols = 124;
rows = 35;
cols = 67;
s.val[0] = 127.0;
s.val[1] = 127.0;
s.val[2] = 127.0;
s.val[3] = 127.0;
//#define PRINT_MATRIX
}
void CV_GpuMatOpSetToTest::print_mat(cv::Mat & mat, std::string name )
bool CV_GpuMatOpSetToTest::testSetTo(cv::Mat& cpumat, gpu::GpuMat& gpumat, const cv::Mat& cpumask, const cv::gpu::GpuMat& gpumask)
{
cv::imshow(name, mat);
}
cpumat.setTo(s, cpumask);
gpumat.setTo(s, gpumask);
void CV_GpuMatOpSetToTest::print_mat(gpu::GpuMat & mat, std::string name)
{
cv::Mat newmat;
mat.download(newmat);
print_mat(newmat, name);
}
double ret = norm(cpumat, gpumat, NORM_INF);
bool CV_GpuMatOpSetToTest::compare_matrix(cv::Mat & cpumat, gpu::GpuMat & gpumat)
{
//int64 time = getTickCount();
cpumat.setTo(s);
//int64 time1 = getTickCount();
gpumat.setTo(s);
//int64 time2 = getTickCount();
//std::cout << "\ntime cpu: " << std::fixed << std::setprecision(12) << double((time1 - time) / (double)getTickFrequency());
//std::cout << "\ntime gpu: " << std::fixed << std::setprecision(12) << double((time2 - time1) / (double)getTickFrequency());
//std::cout << "\n";
#ifdef PRINT_MATRIX
print_mat(cpumat);
print_mat(gpumat);
cv::waitKey(0);
#endif
double ret = norm(cpumat, gpumat);
if (ret < 1.0)
if (ret < std::numeric_limits<double>::epsilon())
return true;
else
{
@@ -133,11 +96,20 @@ void CV_GpuMatOpSetToTest::run( int /* start_from */)
try
{
cv::Mat cpumask(rows, cols, CV_8UC1);
cv::RNG rng(*ts->get_rng());
rng.fill(cpumask, RNG::UNIFORM, cv::Scalar::all(0.0), cv::Scalar(1.5));
cv::gpu::GpuMat gpumask(cpumask);
for (int i = 0; i < 7; i++)
{
Mat cpumat(rows, cols, i, Scalar::all(0));
GpuMat gpumat(cpumat);
is_test_good &= compare_matrix(cpumat, gpumat);
for (int cn = 1; cn <= 4; ++cn)
{
int mat_type = CV_MAKETYPE(i, cn);
Mat cpumat(rows, cols, mat_type, Scalar::all(0));
GpuMat gpumat(cpumat);
is_test_good &= testSetTo(cpumat, gpumat, cpumask, gpumask);
}
}
}
catch(const cv::Exception& e)
