switched to new device layer in split/merge

pull/1540/head
Vladislav Vinogradov 12 years ago
parent 3ab7f4b26a
commit 6dbb32a05d
  1. 111
      modules/cudaarithm/src/core.cpp
  2. 545
      modules/cudaarithm/src/cuda/split_merge.cu
  3. 185
      modules/cudev/include/opencv2/cudev/grid/split_merge.hpp

@ -63,117 +63,6 @@ void cv::cuda::copyMakeBorder(InputArray, OutputArray, int, int, int, int, int,
#else /* !defined (HAVE_CUDA) */
////////////////////////////////////////////////////////////////////////
// merge/split
namespace cv { namespace cuda { namespace device
{
namespace split_merge
{
void merge(const PtrStepSzb* src, PtrStepSzb& dst, int total_channels, size_t elem_size, const cudaStream_t& stream);
void split(const PtrStepSzb& src, PtrStepSzb* dst, int num_channels, size_t elem_size1, const cudaStream_t& stream);
}
}}}
namespace
{
void merge_caller(const GpuMat* src, size_t n, OutputArray _dst, Stream& stream)
{
CV_Assert( src != 0 );
CV_Assert( n > 0 && n <= 4 );
const int depth = src[0].depth();
const Size size = src[0].size();
for (size_t i = 0; i < n; ++i)
{
CV_Assert( src[i].size() == size );
CV_Assert( src[i].depth() == depth );
CV_Assert( src[i].channels() == 1 );
}
if (depth == CV_64F)
{
if (!deviceSupports(NATIVE_DOUBLE))
CV_Error(cv::Error::StsUnsupportedFormat, "The device doesn't support double");
}
if (n == 1)
{
src[0].copyTo(_dst, stream);
}
else
{
_dst.create(size, CV_MAKE_TYPE(depth, (int)n));
GpuMat dst = _dst.getGpuMat();
PtrStepSzb src_as_devmem[4];
for(size_t i = 0; i < n; ++i)
src_as_devmem[i] = src[i];
PtrStepSzb dst_as_devmem(dst);
cv::cuda::device::split_merge::merge(src_as_devmem, dst_as_devmem, (int)n, CV_ELEM_SIZE(depth), StreamAccessor::getStream(stream));
}
}
void split_caller(const GpuMat& src, GpuMat* dst, Stream& stream)
{
CV_Assert( dst != 0 );
const int depth = src.depth();
const int num_channels = src.channels();
CV_Assert( num_channels <= 4 );
if (depth == CV_64F)
{
if (!deviceSupports(NATIVE_DOUBLE))
CV_Error(cv::Error::StsUnsupportedFormat, "The device doesn't support double");
}
if (num_channels == 1)
{
src.copyTo(dst[0], stream);
return;
}
for (int i = 0; i < num_channels; ++i)
dst[i].create(src.size(), depth);
PtrStepSzb dst_as_devmem[4];
for (int i = 0; i < num_channels; ++i)
dst_as_devmem[i] = dst[i];
PtrStepSzb src_as_devmem(src);
cv::cuda::device::split_merge::split(src_as_devmem, dst_as_devmem, num_channels, src.elemSize1(), StreamAccessor::getStream(stream));
}
}
void cv::cuda::merge(const GpuMat* src, size_t n, OutputArray dst, Stream& stream)
{
merge_caller(src, n, dst, stream);
}
void cv::cuda::merge(const std::vector<GpuMat>& src, OutputArray dst, Stream& stream)
{
merge_caller(&src[0], src.size(), dst, stream);
}
void cv::cuda::split(InputArray _src, GpuMat* dst, Stream& stream)
{
GpuMat src = _src.getGpuMat();
split_caller(src, dst, stream);
}
void cv::cuda::split(InputArray _src, std::vector<GpuMat>& dst, Stream& stream)
{
GpuMat src = _src.getGpuMat();
dst.resize(src.channels());
if(src.channels() > 0)
split_caller(src, &dst[0], stream);
}
////////////////////////////////////////////////////////////////////////
// transpose

@ -40,472 +40,209 @@
//
//M*/
#if !defined CUDA_DISABLER
#include "opencv2/opencv_modules.hpp"
#include "opencv2/core/cuda/common.hpp"
#ifndef HAVE_OPENCV_CUDEV
namespace cv { namespace cuda { namespace device
{
namespace split_merge
{
template <typename T, size_t elem_size = sizeof(T)>
struct TypeTraits
{
typedef T type;
typedef T type2;
typedef T type3;
typedef T type4;
};
#error "opencv_cudev is required"
template <typename T>
struct TypeTraits<T, 1>
{
typedef char type;
typedef char2 type2;
typedef char3 type3;
typedef char4 type4;
};
#else
template <typename T>
struct TypeTraits<T, 2>
{
typedef short type;
typedef short2 type2;
typedef short3 type3;
typedef short4 type4;
};
template <typename T>
struct TypeTraits<T, 4>
{
typedef int type;
typedef int2 type2;
typedef int3 type3;
typedef int4 type4;
};
#include "opencv2/cudaarithm.hpp"
#include "opencv2/cudev.hpp"
template <typename T>
struct TypeTraits<T, 8>
{
typedef double type;
typedef double2 type2;
//typedef double3 type3;
//typedef double4 type3;
};
using namespace cv::cudev;
typedef void (*MergeFunction)(const PtrStepSzb* src, PtrStepSzb& dst, const cudaStream_t& stream);
typedef void (*SplitFunction)(const PtrStepSzb& src, PtrStepSzb* dst, const cudaStream_t& stream);
////////////////////////////////////////////////////////////////////////
/// merge
//------------------------------------------------------------
// Merge
namespace
{
template <int cn, typename T> struct MergeFunc;
template <typename T>
__global__ void mergeC2_(const uchar* src0, size_t src0_step,
const uchar* src1, size_t src1_step,
int rows, int cols, uchar* dst, size_t dst_step)
template <typename T> struct MergeFunc<2, T>
{
static void call(const GpuMat* src, GpuMat& dst, Stream& stream)
{
typedef typename TypeTraits<T>::type2 dst_type;
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const T* src0_y = (const T*)(src0 + y * src0_step);
const T* src1_y = (const T*)(src1 + y * src1_step);
dst_type* dst_y = (dst_type*)(dst + y * dst_step);
if (x < cols && y < rows)
{
dst_type dst_elem;
dst_elem.x = src0_y[x];
dst_elem.y = src1_y[x];
dst_y[x] = dst_elem;
}
gridMerge(zipPtr(globPtr<T>(src[0]), globPtr<T>(src[1])),
globPtr<typename MakeVec<T, 2>::type>(dst),
stream);
}
};
template <typename T>
__global__ void mergeC3_(const uchar* src0, size_t src0_step,
const uchar* src1, size_t src1_step,
const uchar* src2, size_t src2_step,
int rows, int cols, uchar* dst, size_t dst_step)
{
typedef typename TypeTraits<T>::type3 dst_type;
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const T* src0_y = (const T*)(src0 + y * src0_step);
const T* src1_y = (const T*)(src1 + y * src1_step);
const T* src2_y = (const T*)(src2 + y * src2_step);
dst_type* dst_y = (dst_type*)(dst + y * dst_step);
if (x < cols && y < rows)
{
dst_type dst_elem;
dst_elem.x = src0_y[x];
dst_elem.y = src1_y[x];
dst_elem.z = src2_y[x];
dst_y[x] = dst_elem;
}
}
template <>
__global__ void mergeC3_<double>(const uchar* src0, size_t src0_step,
const uchar* src1, size_t src1_step,
const uchar* src2, size_t src2_step,
int rows, int cols, uchar* dst, size_t dst_step)
template <typename T> struct MergeFunc<3, T>
{
static void call(const GpuMat* src, GpuMat& dst, Stream& stream)
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const double* src0_y = (const double*)(src0 + y * src0_step);
const double* src1_y = (const double*)(src1 + y * src1_step);
const double* src2_y = (const double*)(src2 + y * src2_step);
double* dst_y = (double*)(dst + y * dst_step);
if (x < cols && y < rows)
{
dst_y[3 * x] = src0_y[x];
dst_y[3 * x + 1] = src1_y[x];
dst_y[3 * x + 2] = src2_y[x];
}
gridMerge(zipPtr(globPtr<T>(src[0]), globPtr<T>(src[1]), globPtr<T>(src[2])),
globPtr<typename MakeVec<T, 3>::type>(dst),
stream);
}
};
template <typename T>
__global__ void mergeC4_(const uchar* src0, size_t src0_step,
const uchar* src1, size_t src1_step,
const uchar* src2, size_t src2_step,
const uchar* src3, size_t src3_step,
int rows, int cols, uchar* dst, size_t dst_step)
template <typename T> struct MergeFunc<4, T>
{
static void call(const GpuMat* src, GpuMat& dst, Stream& stream)
{
typedef typename TypeTraits<T>::type4 dst_type;
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const T* src0_y = (const T*)(src0 + y * src0_step);
const T* src1_y = (const T*)(src1 + y * src1_step);
const T* src2_y = (const T*)(src2 + y * src2_step);
const T* src3_y = (const T*)(src3 + y * src3_step);
dst_type* dst_y = (dst_type*)(dst + y * dst_step);
if (x < cols && y < rows)
{
dst_type dst_elem;
dst_elem.x = src0_y[x];
dst_elem.y = src1_y[x];
dst_elem.z = src2_y[x];
dst_elem.w = src3_y[x];
dst_y[x] = dst_elem;
}
gridMerge(zipPtr(globPtr<T>(src[0]), globPtr<T>(src[1]), globPtr<T>(src[2]), globPtr<T>(src[3])),
globPtr<typename MakeVec<T, 4>::type>(dst),
stream);
}
};
void mergeImpl(const GpuMat* src, size_t n, cv::OutputArray _dst, Stream& stream)
{
CV_DbgAssert( src != 0 );
CV_DbgAssert( n > 0 && n <= 4 );
template <>
__global__ void mergeC4_<double>(const uchar* src0, size_t src0_step,
const uchar* src1, size_t src1_step,
const uchar* src2, size_t src2_step,
const uchar* src3, size_t src3_step,
int rows, int cols, uchar* dst, size_t dst_step)
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const double* src0_y = (const double*)(src0 + y * src0_step);
const double* src1_y = (const double*)(src1 + y * src1_step);
const double* src2_y = (const double*)(src2 + y * src2_step);
const double* src3_y = (const double*)(src3 + y * src3_step);
double2* dst_y = (double2*)(dst + y * dst_step);
if (x < cols && y < rows)
{
dst_y[2 * x] = make_double2(src0_y[x], src1_y[x]);
dst_y[2 * x + 1] = make_double2(src2_y[x], src3_y[x]);
}
}
const int depth = src[0].depth();
const cv::Size size = src[0].size();
template <typename T>
static void mergeC2_(const PtrStepSzb* src, PtrStepSzb& dst, const cudaStream_t& stream)
#ifdef _DEBUG
for (size_t i = 0; i < n; ++i)
{
dim3 block(32, 8);
dim3 grid(divUp(dst.cols, block.x), divUp(dst.rows, block.y));
mergeC2_<T><<<grid, block, 0, stream>>>(
src[0].data, src[0].step,
src[1].data, src[1].step,
dst.rows, dst.cols, dst.data, dst.step);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall(cudaDeviceSynchronize());
CV_Assert( src[i].size() == size );
CV_Assert( src[i].depth() == depth );
CV_Assert( src[i].channels() == 1 );
}
#endif
template <typename T>
static void mergeC3_(const PtrStepSzb* src, PtrStepSzb& dst, const cudaStream_t& stream)
if (n == 1)
{
dim3 block(32, 8);
dim3 grid(divUp(dst.cols, block.x), divUp(dst.rows, block.y));
mergeC3_<T><<<grid, block, 0, stream>>>(
src[0].data, src[0].step,
src[1].data, src[1].step,
src[2].data, src[2].step,
dst.rows, dst.cols, dst.data, dst.step);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall(cudaDeviceSynchronize());
src[0].copyTo(_dst, stream);
}
template <typename T>
static void mergeC4_(const PtrStepSzb* src, PtrStepSzb& dst, const cudaStream_t& stream)
else
{
dim3 block(32, 8);
dim3 grid(divUp(dst.cols, block.x), divUp(dst.rows, block.y));
mergeC4_<T><<<grid, block, 0, stream>>>(
src[0].data, src[0].step,
src[1].data, src[1].step,
src[2].data, src[2].step,
src[3].data, src[3].step,
dst.rows, dst.cols, dst.data, dst.step);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall(cudaDeviceSynchronize());
}
void merge(const PtrStepSzb* src, PtrStepSzb& dst,
int total_channels, size_t elem_size,
const cudaStream_t& stream)
{
static MergeFunction merge_func_tbl[] =
typedef void (*func_t)(const GpuMat* src, GpuMat& dst, Stream& stream);
static const func_t funcs[3][5] =
{
mergeC2_<char>, mergeC2_<short>, mergeC2_<int>, 0, mergeC2_<double>,
mergeC3_<char>, mergeC3_<short>, mergeC3_<int>, 0, mergeC3_<double>,
mergeC4_<char>, mergeC4_<short>, mergeC4_<int>, 0, mergeC4_<double>,
{MergeFunc<2, uchar>::call, MergeFunc<2, ushort>::call, MergeFunc<2, int>::call, 0, MergeFunc<2, double>::call},
{MergeFunc<3, uchar>::call, MergeFunc<3, ushort>::call, MergeFunc<3, int>::call, 0, MergeFunc<3, double>::call},
{MergeFunc<4, uchar>::call, MergeFunc<4, ushort>::call, MergeFunc<4, int>::call, 0, MergeFunc<4, double>::call}
};
size_t merge_func_id = (total_channels - 2) * 5 + (elem_size >> 1);
MergeFunction merge_func = merge_func_tbl[merge_func_id];
if (merge_func == 0)
CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported channel count or data type");
merge_func(src, dst, stream);
}
const int channels = static_cast<int>(n);
//------------------------------------------------------------
// Split
_dst.create(size, CV_MAKE_TYPE(depth, channels));
GpuMat dst = _dst.getGpuMat();
const func_t func = funcs[channels - 2][CV_ELEM_SIZE(depth) / 2];
template <typename T>
__global__ void splitC2_(const uchar* src, size_t src_step,
int rows, int cols,
uchar* dst0, size_t dst0_step,
uchar* dst1, size_t dst1_step)
{
typedef typename TypeTraits<T>::type2 src_type;
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const src_type* src_y = (const src_type*)(src + y * src_step);
T* dst0_y = (T*)(dst0 + y * dst0_step);
T* dst1_y = (T*)(dst1 + y * dst1_step);
if (func == 0)
CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported channel count or data type");
if (x < cols && y < rows)
{
src_type src_elem = src_y[x];
dst0_y[x] = src_elem.x;
dst1_y[x] = src_elem.y;
}
func(src, dst, stream);
}
}
}
void cv::cuda::merge(const GpuMat* src, size_t n, OutputArray dst, Stream& stream)
{
mergeImpl(src, n, dst, stream);
}
template <typename T>
__global__ void splitC3_(const uchar* src, size_t src_step,
int rows, int cols,
uchar* dst0, size_t dst0_step,
uchar* dst1, size_t dst1_step,
uchar* dst2, size_t dst2_step)
{
typedef typename TypeTraits<T>::type3 src_type;
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const src_type* src_y = (const src_type*)(src + y * src_step);
T* dst0_y = (T*)(dst0 + y * dst0_step);
T* dst1_y = (T*)(dst1 + y * dst1_step);
T* dst2_y = (T*)(dst2 + y * dst2_step);
void cv::cuda::merge(const std::vector<GpuMat>& src, OutputArray dst, Stream& stream)
{
mergeImpl(&src[0], src.size(), dst, stream);
}
if (x < cols && y < rows)
{
src_type src_elem = src_y[x];
dst0_y[x] = src_elem.x;
dst1_y[x] = src_elem.y;
dst2_y[x] = src_elem.z;
}
}
////////////////////////////////////////////////////////////////////////
/// split
namespace
{
template <int cn, typename T> struct SplitFunc;
template <>
__global__ void splitC3_<double>(
const uchar* src, size_t src_step, int rows, int cols,
uchar* dst0, size_t dst0_step,
uchar* dst1, size_t dst1_step,
uchar* dst2, size_t dst2_step)
template <typename T> struct SplitFunc<2, T>
{
static void call(const GpuMat& src, GpuMat* dst, Stream& stream)
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const double* src_y = (const double*)(src + y * src_step);
double* dst0_y = (double*)(dst0 + y * dst0_step);
double* dst1_y = (double*)(dst1 + y * dst1_step);
double* dst2_y = (double*)(dst2 + y * dst2_step);
if (x < cols && y < rows)
GlobPtrSz<T> dstarr[2] =
{
dst0_y[x] = src_y[3 * x];
dst1_y[x] = src_y[3 * x + 1];
dst2_y[x] = src_y[3 * x + 2];
}
}
globPtr<T>(dst[0]), globPtr<T>(dst[1])
};
gridSplit(globPtr<typename MakeVec<T, 2>::type>(src), dstarr, stream);
}
};
template <typename T>
__global__ void splitC4_(const uchar* src, size_t src_step, int rows, int cols,
uchar* dst0, size_t dst0_step,
uchar* dst1, size_t dst1_step,
uchar* dst2, size_t dst2_step,
uchar* dst3, size_t dst3_step)
template <typename T> struct SplitFunc<3, T>
{
static void call(const GpuMat& src, GpuMat* dst, Stream& stream)
{
typedef typename TypeTraits<T>::type4 src_type;
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const src_type* src_y = (const src_type*)(src + y * src_step);
T* dst0_y = (T*)(dst0 + y * dst0_step);
T* dst1_y = (T*)(dst1 + y * dst1_step);
T* dst2_y = (T*)(dst2 + y * dst2_step);
T* dst3_y = (T*)(dst3 + y * dst3_step);
if (x < cols && y < rows)
GlobPtrSz<T> dstarr[3] =
{
src_type src_elem = src_y[x];
dst0_y[x] = src_elem.x;
dst1_y[x] = src_elem.y;
dst2_y[x] = src_elem.z;
dst3_y[x] = src_elem.w;
}
}
globPtr<T>(dst[0]), globPtr<T>(dst[1]), globPtr<T>(dst[2])
};
gridSplit(globPtr<typename MakeVec<T, 3>::type>(src), dstarr, stream);
}
};
template <>
__global__ void splitC4_<double>(
const uchar* src, size_t src_step, int rows, int cols,
uchar* dst0, size_t dst0_step,
uchar* dst1, size_t dst1_step,
uchar* dst2, size_t dst2_step,
uchar* dst3, size_t dst3_step)
template <typename T> struct SplitFunc<4, T>
{
static void call(const GpuMat& src, GpuMat* dst, Stream& stream)
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const double2* src_y = (const double2*)(src + y * src_step);
double* dst0_y = (double*)(dst0 + y * dst0_step);
double* dst1_y = (double*)(dst1 + y * dst1_step);
double* dst2_y = (double*)(dst2 + y * dst2_step);
double* dst3_y = (double*)(dst3 + y * dst3_step);
if (x < cols && y < rows)
GlobPtrSz<T> dstarr[4] =
{
double2 src_elem1 = src_y[2 * x];
double2 src_elem2 = src_y[2 * x + 1];
dst0_y[x] = src_elem1.x;
dst1_y[x] = src_elem1.y;
dst2_y[x] = src_elem2.x;
dst3_y[x] = src_elem2.y;
}
globPtr<T>(dst[0]), globPtr<T>(dst[1]), globPtr<T>(dst[2]), globPtr<T>(dst[3])
};
gridSplit(globPtr<typename MakeVec<T, 4>::type>(src), dstarr, stream);
}
};
template <typename T>
static void splitC2_(const PtrStepSzb& src, PtrStepSzb* dst, const cudaStream_t& stream)
void splitImpl(const GpuMat& src, GpuMat* dst, Stream& stream)
{
typedef void (*func_t)(const GpuMat& src, GpuMat* dst, Stream& stream);
static const func_t funcs[3][5] =
{
dim3 block(32, 8);
dim3 grid(divUp(src.cols, block.x), divUp(src.rows, block.y));
splitC2_<T><<<grid, block, 0, stream>>>(
src.data, src.step, src.rows, src.cols,
dst[0].data, dst[0].step,
dst[1].data, dst[1].step);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall(cudaDeviceSynchronize());
}
{SplitFunc<2, uchar>::call, SplitFunc<2, ushort>::call, SplitFunc<2, int>::call, 0, SplitFunc<2, double>::call},
{SplitFunc<3, uchar>::call, SplitFunc<3, ushort>::call, SplitFunc<3, int>::call, 0, SplitFunc<3, double>::call},
{SplitFunc<4, uchar>::call, SplitFunc<4, ushort>::call, SplitFunc<4, int>::call, 0, SplitFunc<4, double>::call}
};
CV_DbgAssert( dst != 0 );
template <typename T>
static void splitC3_(const PtrStepSzb& src, PtrStepSzb* dst, const cudaStream_t& stream)
{
dim3 block(32, 8);
dim3 grid(divUp(src.cols, block.x), divUp(src.rows, block.y));
splitC3_<T><<<grid, block, 0, stream>>>(
src.data, src.step, src.rows, src.cols,
dst[0].data, dst[0].step,
dst[1].data, dst[1].step,
dst[2].data, dst[2].step);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall(cudaDeviceSynchronize());
}
const int depth = src.depth();
const int channels = src.channels();
CV_DbgAssert( channels <= 4 );
if (channels == 0)
return;
template <typename T>
static void splitC4_(const PtrStepSzb& src, PtrStepSzb* dst, const cudaStream_t& stream)
if (channels == 1)
{
dim3 block(32, 8);
dim3 grid(divUp(src.cols, block.x), divUp(src.rows, block.y));
splitC4_<T><<<grid, block, 0, stream>>>(
src.data, src.step, src.rows, src.cols,
dst[0].data, dst[0].step,
dst[1].data, dst[1].step,
dst[2].data, dst[2].step,
dst[3].data, dst[3].step);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall(cudaDeviceSynchronize());
src.copyTo(dst[0], stream);
return;
}
for (int i = 0; i < channels; ++i)
dst[i].create(src.size(), depth);
void split(const PtrStepSzb& src, PtrStepSzb* dst, int num_channels, size_t elem_size1, const cudaStream_t& stream)
{
static SplitFunction split_func_tbl[] =
{
splitC2_<char>, splitC2_<short>, splitC2_<int>, 0, splitC2_<double>,
splitC3_<char>, splitC3_<short>, splitC3_<int>, 0, splitC3_<double>,
splitC4_<char>, splitC4_<short>, splitC4_<int>, 0, splitC4_<double>,
};
const func_t func = funcs[channels - 2][CV_ELEM_SIZE(depth) / 2];
size_t split_func_id = (num_channels - 2) * 5 + (elem_size1 >> 1);
SplitFunction split_func = split_func_tbl[split_func_id];
if (func == 0)
CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported channel count or data type");
if (split_func == 0)
CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported channel count or data type");
func(src, dst, stream);
}
}
split_func(src, dst, stream);
}
} // namespace split_merge
}}} // namespace cv { namespace cuda { namespace cudev
void cv::cuda::split(InputArray _src, GpuMat* dst, Stream& stream)
{
GpuMat src = _src.getGpuMat();
splitImpl(src, dst, stream);
}
void cv::cuda::split(InputArray _src, std::vector<GpuMat>& dst, Stream& stream)
{
GpuMat src = _src.getGpuMat();
dst.resize(src.channels());
if (src.channels() > 0)
splitImpl(src, &dst[0], stream);
}
#endif /* CUDA_DISABLER */
#endif

@ -51,6 +51,7 @@
#include "../util/vec_traits.hpp"
#include "../ptr2d/traits.hpp"
#include "../ptr2d/gpumat.hpp"
#include "../ptr2d/glob.hpp"
#include "../ptr2d/mask.hpp"
#include "detail/split_merge.hpp"
@ -75,6 +76,24 @@ __host__ void gridMerge_(const SrcPtrTuple& src, GpuMat_<DstType>& dst, const Ma
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtrTuple, typename DstType, class MaskPtr>
__host__ void gridMerge_(const SrcPtrTuple& src, const GlobPtrSz<DstType>& dst, const MaskPtr& mask, Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<DstType>::cn == tuple_size<SrcPtrTuple>::value, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst) == rows && getCols(dst) == cols );
CV_Assert( getRows(mask) == rows && getCols(mask) == cols );
grid_split_merge_detail::MergeImpl<VecTraits<DstType>::cn, Policy>::merge(shrinkPtr(src),
shrinkPtr(dst),
shrinkPtr(mask),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtrTuple, typename DstType>
__host__ void gridMerge_(const SrcPtrTuple& src, GpuMat_<DstType>& dst, Stream& stream = Stream::Null())
{
@ -92,6 +111,23 @@ __host__ void gridMerge_(const SrcPtrTuple& src, GpuMat_<DstType>& dst, Stream&
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtrTuple, typename DstType>
__host__ void gridMerge_(const SrcPtrTuple& src, const GlobPtrSz<DstType>& dst, Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<DstType>::cn == tuple_size<SrcPtrTuple>::value, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst) == rows && getCols(dst) == cols );
grid_split_merge_detail::MergeImpl<VecTraits<DstType>::cn, Policy>::merge(shrinkPtr(src),
shrinkPtr(dst),
WithOutMask(),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit_(const SrcPtr& src, const tuple< GpuMat_<DstType>&, GpuMat_<DstType>& >& dst, const MaskPtr& mask, Stream& stream = Stream::Null())
{
@ -132,6 +168,25 @@ __host__ void gridSplit_(const SrcPtr& src, GpuMat_<DstType> (&dst)[2], const Ma
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit_(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[2], const MaskPtr& mask, Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<typename PtrTraits<SrcPtr>::value_type>::cn == 2, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst[0]) == rows && getCols(dst[0]) == cols );
CV_Assert( getRows(dst[1]) == rows && getCols(dst[1]) == cols );
CV_Assert( getRows(mask) == rows && getCols(mask) == cols );
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]),
shrinkPtr(mask),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType>
__host__ void gridSplit_(const SrcPtr& src, const tuple< GpuMat_<DstType>&, GpuMat_<DstType>& >& dst, Stream& stream = Stream::Null())
{
@ -168,6 +223,24 @@ __host__ void gridSplit_(const SrcPtr& src, GpuMat_<DstType> (&dst)[2], Stream&
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType>
__host__ void gridSplit_(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[2], Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<typename PtrTraits<SrcPtr>::value_type>::cn == 2, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst[0]) == rows && getCols(dst[0]) == cols );
CV_Assert( getRows(dst[1]) == rows && getCols(dst[1]) == cols );
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]),
WithOutMask(),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit_(const SrcPtr& src, const tuple< GpuMat_<DstType>&, GpuMat_<DstType>&, GpuMat_<DstType>& >& dst, const MaskPtr& mask, Stream& stream = Stream::Null())
{
@ -210,6 +283,26 @@ __host__ void gridSplit_(const SrcPtr& src, GpuMat_<DstType> (&dst)[3], const Ma
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit_(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[3], const MaskPtr& mask, Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<typename PtrTraits<SrcPtr>::value_type>::cn == 3, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst[0]) == rows && getCols(dst[0]) == cols );
CV_Assert( getRows(dst[1]) == rows && getCols(dst[1]) == cols );
CV_Assert( getRows(dst[2]) == rows && getCols(dst[2]) == cols );
CV_Assert( getRows(mask) == rows && getCols(mask) == cols );
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]),
shrinkPtr(mask),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType>
__host__ void gridSplit_(const SrcPtr& src, const tuple< GpuMat_<DstType>&, GpuMat_<DstType>&, GpuMat_<DstType>& >& dst, Stream& stream = Stream::Null())
{
@ -248,6 +341,25 @@ __host__ void gridSplit_(const SrcPtr& src, GpuMat_<DstType> (&dst)[3], Stream&
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType>
__host__ void gridSplit_(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[3], Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<typename PtrTraits<SrcPtr>::value_type>::cn == 3, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst[0]) == rows && getCols(dst[0]) == cols );
CV_Assert( getRows(dst[1]) == rows && getCols(dst[1]) == cols );
CV_Assert( getRows(dst[2]) == rows && getCols(dst[2]) == cols );
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]),
WithOutMask(),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit_(const SrcPtr& src, const tuple< GpuMat_<DstType>&, GpuMat_<DstType>&, GpuMat_<DstType>&, GpuMat_<DstType>& >& dst, const MaskPtr& mask, Stream& stream = Stream::Null())
{
@ -283,10 +395,31 @@ __host__ void gridSplit_(const SrcPtr& src, GpuMat_<DstType> (&dst)[4], const Ma
dst[0].create(rows, cols);
dst[1].create(rows, cols);
dst[2].create(rows, cols);
dst[4].create(rows, cols);
dst[3].create(rows, cols);
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]), shrinkPtr(dst[4]),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]), shrinkPtr(dst[3]),
shrinkPtr(mask),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit_(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[4], const MaskPtr& mask, Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<typename PtrTraits<SrcPtr>::value_type>::cn == 4, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst[0]) == rows && getCols(dst[0]) == cols );
CV_Assert( getRows(dst[1]) == rows && getCols(dst[1]) == cols );
CV_Assert( getRows(dst[2]) == rows && getCols(dst[2]) == cols );
CV_Assert( getRows(dst[3]) == rows && getCols(dst[3]) == cols );
CV_Assert( getRows(mask) == rows && getCols(mask) == cols );
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]), shrinkPtr(dst[3]),
shrinkPtr(mask),
rows, cols,
StreamAccessor::getStream(stream));
@ -323,10 +456,30 @@ __host__ void gridSplit_(const SrcPtr& src, GpuMat_<DstType> (&dst)[4], Stream&
dst[0].create(rows, cols);
dst[1].create(rows, cols);
dst[2].create(rows, cols);
dst[4].create(rows, cols);
dst[3].create(rows, cols);
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]), shrinkPtr(dst[3]),
WithOutMask(),
rows, cols,
StreamAccessor::getStream(stream));
}
template <class Policy, class SrcPtr, typename DstType>
__host__ void gridSplit_(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[4], Stream& stream = Stream::Null())
{
CV_StaticAssert( VecTraits<typename PtrTraits<SrcPtr>::value_type>::cn == 4, "" );
const int rows = getRows(src);
const int cols = getCols(src);
CV_Assert( getRows(dst[0]) == rows && getCols(dst[0]) == cols );
CV_Assert( getRows(dst[1]) == rows && getCols(dst[1]) == cols );
CV_Assert( getRows(dst[2]) == rows && getCols(dst[2]) == cols );
CV_Assert( getRows(dst[3]) == rows && getCols(dst[3]) == cols );
grid_split_merge_detail::split<Policy>(shrinkPtr(src),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]), shrinkPtr(dst[4]),
shrinkPtr(dst[0]), shrinkPtr(dst[1]), shrinkPtr(dst[2]), shrinkPtr(dst[3]),
WithOutMask(),
rows, cols,
StreamAccessor::getStream(stream));
@ -348,12 +501,24 @@ __host__ void gridMerge(const SrcPtrTuple& src, GpuMat_<DstType>& dst, const Mas
gridMerge_<DefaultSplitMergePolicy>(src, dst, mask, stream);
}
template <class SrcPtrTuple, typename DstType, class MaskPtr>
__host__ void gridMerge(const SrcPtrTuple& src, const GlobPtrSz<DstType>& dst, const MaskPtr& mask, Stream& stream = Stream::Null())
{
gridMerge_<DefaultSplitMergePolicy>(src, dst, mask, stream);
}
template <class SrcPtrTuple, typename DstType>
__host__ void gridMerge(const SrcPtrTuple& src, GpuMat_<DstType>& dst, Stream& stream = Stream::Null())
{
gridMerge_<DefaultSplitMergePolicy>(src, dst, stream);
}
template <class SrcPtrTuple, typename DstType>
__host__ void gridMerge(const SrcPtrTuple& src, const GlobPtrSz<DstType>& dst, Stream& stream = Stream::Null())
{
gridMerge_<DefaultSplitMergePolicy>(src, dst, stream);
}
template <class SrcPtr, typename DstType, class MaskPtr>
__host__ void gridSplit(const SrcPtr& src, const tuple< GpuMat_<DstType>&, GpuMat_<DstType>& >& dst, const MaskPtr& mask, Stream& stream = Stream::Null())
{
@ -396,12 +561,24 @@ __host__ void gridSplit(const SrcPtr& src, GpuMat_<DstType> (&dst)[COUNT], const
gridSplit_<DefaultSplitMergePolicy>(src, dst, mask, stream);
}
template <class SrcPtr, typename DstType, int COUNT, class MaskPtr>
__host__ void gridSplit(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[COUNT], const MaskPtr& mask, Stream& stream = Stream::Null())
{
gridSplit_<DefaultSplitMergePolicy>(src, dst, mask, stream);
}
template <class SrcPtr, typename DstType, int COUNT>
__host__ void gridSplit(const SrcPtr& src, GpuMat_<DstType> (&dst)[COUNT], Stream& stream = Stream::Null())
{
gridSplit_<DefaultSplitMergePolicy>(src, dst, stream);
}
template <class SrcPtr, typename DstType, int COUNT>
__host__ void gridSplit(const SrcPtr& src, GlobPtrSz<DstType> (&dst)[COUNT], Stream& stream = Stream::Null())
{
gridSplit_<DefaultSplitMergePolicy>(src, dst, stream);
}
}}
#endif

Loading…
Cancel
Save