fixed ocl::setIdentity

pull/1519/head
Ilya Lavrenov 12 years ago
parent 7edcefb2be
commit 7379152afb
  1. 3
      modules/ocl/include/opencv2/ocl/ocl.hpp
  2. 70
      modules/ocl/src/arithm.cpp
  3. 47
      modules/ocl/src/opencl/arithm_setidentity.cl
  4. 60
      modules/ocl/test/test_arithm.cpp

@ -584,7 +584,8 @@ namespace cv
CV_EXPORTS void cvtColor(const oclMat &src, oclMat &dst, int code , int dcn = 0); CV_EXPORTS void cvtColor(const oclMat &src, oclMat &dst, int code , int dcn = 0);
CV_EXPORTS void setIdentity(oclMat& src, double val); //! initializes a scaled identity matrix
CV_EXPORTS void setIdentity(oclMat& src, const Scalar & val = Scalar(1));
//////////////////////////////// Filter Engine //////////////////////////////// //////////////////////////////// Filter Engine ////////////////////////////////

@ -1709,63 +1709,35 @@ void cv::ocl::pow(const oclMat &x, double p, oclMat &y)
/////////////////////////////// setIdentity ////////////////////////////////// /////////////////////////////// setIdentity //////////////////////////////////
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
void cv::ocl::setIdentity(oclMat& src, double scalar) void cv::ocl::setIdentity(oclMat& src, const Scalar & scalar)
{ {
CV_Assert(src.empty() == false && src.rows == src.cols);
CV_Assert(src.type() == CV_32SC1 || src.type() == CV_32FC1);
int src_step = src.step/src.elemSize();
Context *clCxt = Context::getContext(); Context *clCxt = Context::getContext();
size_t local_threads[] = {16, 16, 1}; if (!clCxt->supportsFeature(Context::CL_DOUBLE) && src.depth() == CV_64F)
size_t global_threads[] = {src.cols, src.rows, 1};
string kernelName = "setIdentityKernel";
if (src.type() == CV_32FC1)
kernelName += "_F1";
else if (src.type() == CV_32SC1)
kernelName += "_I1";
else
{ {
kernelName += "_D1"; CV_Error(CV_GpuNotSupported, "Selected device doesn't support double\r\n");
if (!(clCxt->supportsFeature(Context::CL_DOUBLE))) return;
{
oclMat temp;
src.convertTo(temp, CV_32FC1);
temp.copyTo(src);
}
} }
CV_Assert(src.step % src.elemSize() == 0);
int src_step1 = src.step / src.elemSize(), src_offset1 = src.offset / src.elemSize();
size_t local_threads[] = { 16, 16, 1 };
size_t global_threads[] = { src.cols, src.rows, 1 };
const char * const typeMap[] = { "uchar", "char", "ushort", "short", "int", "float", "double" };
const char * const channelMap[] = { "", "", "2", "4", "4" };
string buildOptions = format("-D T=%s%s", typeMap[src.depth()], channelMap[src.oclchannels()]);
vector<pair<size_t , const void *> > args; vector<pair<size_t , const void *> > args;
args.push_back( make_pair( sizeof(cl_mem), (void *)&src.data )); args.push_back( make_pair( sizeof(cl_mem), (void *)&src.data ));
args.push_back( make_pair( sizeof(cl_int), (void *)&src.rows)); args.push_back( make_pair( sizeof(cl_int), (void *)&src_step1 ));
args.push_back( make_pair( sizeof(cl_int), (void *)&src_offset1 ));
args.push_back( make_pair( sizeof(cl_int), (void *)&src.cols)); args.push_back( make_pair( sizeof(cl_int), (void *)&src.cols));
args.push_back( make_pair( sizeof(cl_int), (void *)&src_step )); args.push_back( make_pair( sizeof(cl_int), (void *)&src.rows));
int scalar_i = 0; oclMat sc(1, 1, src.type(), scalar);
float scalar_f = 0.0f; args.push_back( make_pair( sizeof(cl_mem), (void *)&sc.data ));
if (clCxt->supportsFeature(Context::CL_DOUBLE))
{
if (src.type() == CV_32SC1)
{
scalar_i = (int)scalar;
args.push_back(make_pair(sizeof(cl_int), (void*)&scalar_i));
}
else
args.push_back(make_pair(sizeof(cl_double), (void*)&scalar));
}
else
{
if (src.type() == CV_32SC1)
{
scalar_i = (int)scalar;
args.push_back(make_pair(sizeof(cl_int), (void*)&scalar_i));
}
else
{
scalar_f = (float)scalar;
args.push_back(make_pair(sizeof(cl_float), (void*)&scalar_f));
}
}
openCLExecuteKernel(clCxt, &arithm_setidentity, kernelName, global_threads, local_threads, args, -1, -1); openCLExecuteKernel(clCxt, &arithm_setidentity, "setIdentity", global_threads, local_threads,
args, -1, -1, buildOptions.c_str());
} }

@ -42,6 +42,7 @@
// the use of this software, even if advised of the possibility of such damage. // the use of this software, even if advised of the possibility of such damage.
// //
//M*/ //M*/
#if defined (DOUBLE_SUPPORT) #if defined (DOUBLE_SUPPORT)
#ifdef cl_khr_fp64 #ifdef cl_khr_fp64
#pragma OPENCL EXTENSION cl_khr_fp64:enable #pragma OPENCL EXTENSION cl_khr_fp64:enable
@ -50,51 +51,19 @@
#endif #endif
#endif #endif
__kernel void setIdentity(__global T * src, int src_step, int src_offset,
#if defined (DOUBLE_SUPPORT) int cols, int rows, __global const T * scalar)
#define DATA_TYPE double
#else
#define DATA_TYPE float
#endif
__kernel void setIdentityKernel_F1(__global float* src, int src_row, int src_col, int src_step, DATA_TYPE scalar)
{
int x = get_global_id(0);
int y = get_global_id(1);
if(x < src_col && y < src_row)
{
if(x == y)
src[y * src_step + x] = scalar;
else
src[y * src_step + x] = 0 * scalar;
}
}
__kernel void setIdentityKernel_D1(__global DATA_TYPE* src, int src_row, int src_col, int src_step, DATA_TYPE scalar)
{ {
int x = get_global_id(0); int x = get_global_id(0);
int y = get_global_id(1); int y = get_global_id(1);
if(x < src_col && y < src_row) if (x < cols && y < rows)
{ {
if(x == y) int src_index = mad24(y, src_step, src_offset + x);
src[y * src_step + x] = scalar;
else
src[y * src_step + x] = 0 * scalar;
}
}
__kernel void setIdentityKernel_I1(__global int* src, int src_row, int src_col, int src_step, int scalar) if (x == y)
{ src[src_index] = *scalar;
int x = get_global_id(0);
int y = get_global_id(1);
if(x < src_col && y < src_row)
{
if(x == y)
src[y * src_step + x] = scalar;
else else
src[y * src_step + x] = 0 * scalar; src[src_index] = 0;
} }
} }

@ -1423,34 +1423,52 @@ TEST_P(AddWeighted, Mat)
} }
} }
//////////////////////////////// setIdentity /////////////////////////////////////////////////
typedef ArithmTestBase SetIdentity;
TEST_P(SetIdentity, Mat)
{
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::setIdentity(dst1_roi, val);
cv::ocl::setIdentity(gdst1, val);
Near(0);
}
}
//////////////////////////////////////// Instantiation ///////////////////////////////////////// //////////////////////////////////////// Instantiation /////////////////////////////////////////
INSTANTIATE_TEST_CASE_P(Arithm, Lut, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool(), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Lut, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool(), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Exp, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Exp, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Log, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Log, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Add, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Add, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Sub, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Sub, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Mul, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, Mul, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Div, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, Div, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Absdiff, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Absdiff, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, CartToPolar, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, CartToPolar, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, PolarToCart, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, PolarToCart, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Magnitude, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Magnitude, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Transpose, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Transpose, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Flip, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Flip, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, MinMax, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, MinMax, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, MinMaxLoc, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, MinMaxLoc, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Sum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, Sum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, SqrSum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, SqrSum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, AbsSum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, AbsSum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, CountNonZero, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, CountNonZero, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Phase, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Phase, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_and, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_and, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_or, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_or, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_xor, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_xor, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_not, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_not, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Compare, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Compare, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, Pow, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, Pow, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, AddWeighted, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + INSTANTIATE_TEST_CASE_P(Arithm, AddWeighted, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
INSTANTIATE_TEST_CASE_P(Arithm, SetIdentity, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool()));
#endif // HAVE_OPENCL #endif // HAVE_OPENCL

Loading…
Cancel
Save