Add CV_16UC1 support for cuda::CLAHE

Due to the size limit of shared memory, the histogram is built
in global memory for the CV_16UC1 case.

The amount of memory needed to build the histogram is:

    65536 bins * 4 bytes = 256 KB

while the shared memory limit is typically 48 KB.
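
With the default 8x8 tile grid this adds up to 8 * 8 * 65536 * 4 bytes
= 16 MB of global scratch memory (the hist_ buffer below).

For reference, a minimal usage sketch of the new code path (illustrative
values; assumes OpenCV is built with the cudaimgproc module):

    #include <opencv2/cudaimgproc.hpp>

    int main()
    {
        // any 16-bit single-channel input is accepted after this change
        cv::Mat src16(480, 640, CV_16UC1, cv::Scalar(1234));
        cv::cuda::GpuMat d_src(src16), d_dst;

        cv::Ptr<cv::cuda::CLAHE> clahe = cv::cuda::createCLAHE(40.0);
        clahe->apply(d_src, d_dst); // 65536-bin histograms live in global memory

        return 0;
    }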

Added test cases for CV_16UC1 and various clip limits.
Added perf tests for CV_16UC1 for both the CPU and CUDA implementations.

There was also a bug in the CV_8UC1 case when redistributing
"residual" clipped pixels. Adding a test case with a clip
limit of 5.0 exposes this bug.
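
The fix spreads the residual over every rStep-th bin instead of packing
it into the first `residual` bins. A host-side sketch of the corrected
redistribution (illustrative helper, mirroring the kernel logic below):

    #include <vector>

    // totalClipped excess samples are spread over all histSize bins:
    // every bin gains redistBatch, and the remaining `residual` samples
    // go to every rStep-th bin rather than only the first `residual` bins.
    void redistribute(std::vector<int>& hist, int totalClipped)
    {
        const int histSize = static_cast<int>(hist.size());
        const int redistBatch = totalClipped / histSize;
        const int residual = totalClipped - redistBatch * histSize;
        const int rStep = (residual != 0) ? histSize / residual : 1;

        for (int i = 0; i < histSize; ++i)
        {
            hist[i] += redistBatch;
            if (residual && i % rStep == 0 && i / rStep < residual)
                ++hist[i];
        }
    }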
Branch: pull/13764/head
Author: Namgoo Lee
Parent: a63f66c90e
Commit: fb8e652c3f
  1. modules/core/include/opencv2/core/cuda_types.hpp (2)
  2. modules/cudaimgproc/perf/perf_histogram.cpp (8)
  3. modules/cudaimgproc/src/cuda/clahe.cu (256)
  4. modules/cudaimgproc/src/histogram.cpp (29)
  5. modules/cudaimgproc/test/test_histogram.cpp (13)
  6. modules/imgproc/perf/perf_histogram.cpp (8)

@@ -127,10 +127,12 @@ namespace cv
};
typedef PtrStepSz<unsigned char> PtrStepSzb;
typedef PtrStepSz<unsigned short> PtrStepSzus;
typedef PtrStepSz<float> PtrStepSzf;
typedef PtrStepSz<int> PtrStepSzi;
typedef PtrStep<unsigned char> PtrStepb;
typedef PtrStep<unsigned short> PtrStepus;
typedef PtrStep<float> PtrStepf;
typedef PtrStep<int> PtrStepi;

@@ -183,16 +183,18 @@ PERF_TEST_P(Sz, EqualizeHist,
//////////////////////////////////////////////////////////////////////
// CLAHE
DEF_PARAM_TEST(Sz_ClipLimit, cv::Size, double);
DEF_PARAM_TEST(Sz_ClipLimit, cv::Size, double, MatType);
PERF_TEST_P(Sz_ClipLimit, CLAHE,
Combine(CUDA_TYPICAL_MAT_SIZES,
Values(0.0, 40.0)))
Values(0.0, 40.0),
Values(MatType(CV_8UC1), MatType(CV_16UC1))))
{
const cv::Size size = GET_PARAM(0);
const double clipLimit = GET_PARAM(1);
const int type = GET_PARAM(2);
cv::Mat src(size, CV_8UC1);
cv::Mat src(size, type);
declare.in(src, WARMUP_RNG);
if (PERF_RUN_CUDA())

@@ -48,11 +48,11 @@ using namespace cv::cudev;
namespace clahe
{
__global__ void calcLutKernel(const PtrStepb src, PtrStepb lut,
const int2 tileSize, const int tilesX,
const int clipLimit, const float lutScale)
__global__ void calcLutKernel_8U(const PtrStepb src, PtrStepb lut,
const int2 tileSize, const int tilesX,
const int clipLimit, const float lutScale)
{
__shared__ int smem[512];
__shared__ int smem[256];
const int tx = blockIdx.x;
const int ty = blockIdx.y;
@@ -95,18 +95,28 @@ namespace clahe
// broadcast evaluated value
__shared__ int totalClipped;
__shared__ int redistBatch;
__shared__ int residual;
__shared__ int rStep;
if (tid == 0)
{
totalClipped = clipped;
redistBatch = totalClipped / 256;
residual = totalClipped - redistBatch * 256;
rStep = 1;
if (residual != 0)
rStep = 256 / residual;
}
__syncthreads();
// redistribute clipped samples evenly
int redistBatch = totalClipped / 256;
tHistVal += redistBatch;
int residual = totalClipped - redistBatch * 256;
if (tid < residual)
if (residual && tid % rStep == 0 && tid / rStep < residual)
++tHistVal;
}
@@ -115,12 +125,212 @@ namespace clahe
lut(ty * tilesX + tx, tid) = saturate_cast<uchar>(__float2int_rn(lutScale * lutVal));
}
void calcLut(PtrStepSzb src, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, cudaStream_t stream)
__global__ void calcLutKernel_16U(const PtrStepus src, PtrStepus lut,
const int2 tileSize, const int tilesX,
const int clipLimit, const float lutScale,
PtrStepSzi hist)
{
#define histSize 65536
#define blockSize 256
__shared__ int smem[blockSize];
const int tx = blockIdx.x;
const int ty = blockIdx.y;
const unsigned int tid = threadIdx.y * blockDim.x + threadIdx.x;
const int histRow = ty * tilesX + tx;
// build histogram
for (int i = tid; i < histSize; i += blockSize)
hist(histRow, i) = 0;
__syncthreads();
for (int i = threadIdx.y; i < tileSize.y; i += blockDim.y)
{
const ushort* srcPtr = src.ptr(ty * tileSize.y + i) + tx * tileSize.x;
for (int j = threadIdx.x; j < tileSize.x; j += blockDim.x)
{
const int data = srcPtr[j];
::atomicAdd(&hist(histRow, data), 1);
}
}
__syncthreads();
if (clipLimit > 0)
{
// clip histogram bar &&
// find number of overall clipped samples
__shared__ int partialSum[blockSize];
for (int i = tid; i < histSize; i += blockSize)
{
int histVal = hist(histRow, i);
int clipped = 0;
if (histVal > clipLimit)
{
clipped = histVal - clipLimit;
hist(histRow, i) = clipLimit;
}
// Following code block is in effect equivalent to:
//
// blockReduce<blockSize>(smem, clipped, tid, plus<int>());
//
{
for (int j = 16; j >= 1; j /= 2)
{
#if __CUDACC_VER_MAJOR__ >= 9
int val = __shfl_down_sync(0xFFFFFFFFU, clipped, j);
#else
int val = __shfl_down(clipped, j);
#endif
clipped += val;
}
if (tid % 32 == 0)
smem[tid / 32] = clipped;
__syncthreads();
if (tid < 8)
{
clipped = smem[tid];
for (int j = 4; j >= 1; j /= 2)
{
#if __CUDACC_VER_MAJOR__ >= 9
int val = __shfl_down_sync(0x000000FFU, clipped, j);
#else
int val = __shfl_down(clipped, j);
#endif
clipped += val;
}
}
}
// end of code block
if (tid == 0)
partialSum[i / blockSize] = clipped;
__syncthreads();
}
int partialSum_ = partialSum[tid];
// Following code block is in effect equivalent to:
//
// blockReduce<blockSize>(smem, partialSum_, tid, plus<int>());
//
{
for (int j = 16; j >= 1; j /= 2)
{
#if __CUDACC_VER_MAJOR__ >= 9
int val = __shfl_down_sync(0xFFFFFFFFU, partialSum_, j);
#else
int val = __shfl_down(partialSum_, j);
#endif
partialSum_ += val;
}
if (tid % 32 == 0)
smem[tid / 32] = partialSum_;
__syncthreads();
if (tid < 8)
{
partialSum_ = smem[tid];
for (int j = 4; j >= 1; j /= 2)
{
#if __CUDACC_VER_MAJOR__ >= 9
int val = __shfl_down_sync(0x000000FFU, partialSum_, j);
#else
int val = __shfl_down(partialSum_, j);
#endif
partialSum_ += val;
}
}
}
// end of code block
// broadcast evaluated value &&
// redistribute clipped samples evenly
__shared__ int totalClipped;
__shared__ int redistBatch;
__shared__ int residual;
__shared__ int rStep;
if (tid == 0)
{
totalClipped = partialSum_;
redistBatch = totalClipped / histSize;
residual = totalClipped - redistBatch * histSize;
rStep = 1;
if (residual != 0)
rStep = histSize / residual;
}
__syncthreads();
for (int i = tid; i < histSize; i += blockSize)
{
int histVal = hist(histRow, i);
int equalized = histVal + redistBatch;
if (residual && i % rStep == 0 && i / rStep < residual)
++equalized;
hist(histRow, i) = equalized;
}
}
__shared__ int partialScan[blockSize];
for (int i = tid; i < histSize; i += blockSize)
{
int equalized = hist(histRow, i);
equalized = blockScanInclusive<blockSize>(equalized, smem, tid);
if (tid == blockSize - 1)
partialScan[i / blockSize] = equalized;
hist(histRow, i) = equalized;
}
__syncthreads();
int partialScan_ = partialScan[tid];
partialScan[tid] = blockScanExclusive<blockSize>(partialScan_, smem, tid);
__syncthreads();
for (int i = tid; i < histSize; i += blockSize)
{
const int lutVal = hist(histRow, i) + partialScan[i / blockSize];
lut(histRow, i) = saturate_cast<ushort>(__float2int_rn(lutScale * lutVal));
}
#undef histSize
#undef blockSize
}
void calcLut_8U(PtrStepSzb src, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, cudaStream_t stream)
{
const dim3 block(32, 8);
const dim3 grid(tilesX, tilesY);
calcLutKernel<<<grid, block, 0, stream>>>(src, lut, tileSize, tilesX, clipLimit, lutScale);
calcLutKernel_8U<<<grid, block, 0, stream>>>(src, lut, tileSize, tilesX, clipLimit, lutScale);
CV_CUDEV_SAFE_CALL( cudaGetLastError() );
@@ -128,7 +338,21 @@ namespace clahe
CV_CUDEV_SAFE_CALL( cudaDeviceSynchronize() );
}
__global__ void transformKernel(const PtrStepSzb src, PtrStepb dst, const PtrStepb lut, const int2 tileSize, const int tilesX, const int tilesY)
void calcLut_16U(PtrStepSzus src, PtrStepus lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, PtrStepSzi hist, cudaStream_t stream)
{
const dim3 block(32, 8);
const dim3 grid(tilesX, tilesY);
calcLutKernel_16U<<<grid, block, 0, stream>>>(src, lut, tileSize, tilesX, clipLimit, lutScale, hist);
CV_CUDEV_SAFE_CALL( cudaGetLastError() );
if (stream == 0)
CV_CUDEV_SAFE_CALL( cudaDeviceSynchronize() );
}
template <typename T>
__global__ void transformKernel(const PtrStepSz<T> src, PtrStep<T> dst, const PtrStep<T> lut, const int2 tileSize, const int tilesX, const int tilesY)
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
@@ -159,22 +383,26 @@ namespace clahe
res += lut(ty2 * tilesX + tx1, srcVal) * ((1.0f - xa) * (ya));
res += lut(ty2 * tilesX + tx2, srcVal) * ((xa) * (ya));
dst(y, x) = saturate_cast<uchar>(res);
dst(y, x) = saturate_cast<T>(res);
}
void transform(PtrStepSzb src, PtrStepSzb dst, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream)
template <typename T>
void transform(PtrStepSz<T> src, PtrStepSz<T> dst, PtrStep<T> lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream)
{
const dim3 block(32, 8);
const dim3 grid(divUp(src.cols, block.x), divUp(src.rows, block.y));
CV_CUDEV_SAFE_CALL( cudaFuncSetCacheConfig(transformKernel, cudaFuncCachePreferL1) );
CV_CUDEV_SAFE_CALL( cudaFuncSetCacheConfig(transformKernel<T>, cudaFuncCachePreferL1) );
transformKernel<<<grid, block, 0, stream>>>(src, dst, lut, tileSize, tilesX, tilesY);
transformKernel<T><<<grid, block, 0, stream>>>(src, dst, lut, tileSize, tilesX, tilesY);
CV_CUDEV_SAFE_CALL( cudaGetLastError() );
if (stream == 0)
CV_CUDEV_SAFE_CALL( cudaDeviceSynchronize() );
}
template void transform<uchar>(PtrStepSz<uchar> src, PtrStepSz<uchar> dst, PtrStep<uchar> lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream);
template void transform<ushort>(PtrStepSz<ushort> src, PtrStepSz<ushort> dst, PtrStep<ushort> lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream);
}
#endif // CUDA_DISABLER
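
For readers following calcLutKernel_16U above: the 65536-bin CDF is
built with a two-level scan, an inclusive scan inside each 256-bin
chunk plus an exclusive scan of the per-chunk totals. A serial sketch
of that decomposition (illustrative only, not part of the patch):

    #include <vector>

    void chunkedScan(std::vector<int>& hist)
    {
        const int blockSize = 256;
        const int numChunks = static_cast<int>(hist.size()) / blockSize;
        std::vector<int> chunkSum(numChunks);

        // phase 1: inclusive scan inside each chunk (one block per chunk on GPU)
        for (int c = 0; c < numChunks; ++c)
        {
            int running = 0;
            for (int i = 0; i < blockSize; ++i)
            {
                running += hist[c * blockSize + i];
                hist[c * blockSize + i] = running;
            }
            chunkSum[c] = running; // analog of partialScan[]
        }

        // phase 2: exclusive scan of chunk sums gives each chunk's offset
        int offset = 0;
        for (int c = 0; c < numChunks; ++c)
        {
            for (int i = 0; i < blockSize; ++i)
                hist[c * blockSize + i] += offset;
            offset += chunkSum[c];
        }
    }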

@@ -141,8 +141,9 @@ void cv::cuda::equalizeHist(InputArray _src, OutputArray _dst, Stream& _stream)
namespace clahe
{
void calcLut(PtrStepSzb src, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, cudaStream_t stream);
void transform(PtrStepSzb src, PtrStepSzb dst, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream);
void calcLut_8U(PtrStepSzb src, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, cudaStream_t stream);
void calcLut_16U(PtrStepSzus src, PtrStepus lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, PtrStepSzi hist, cudaStream_t stream);
template <typename T> void transform(PtrStepSz<T> src, PtrStepSz<T> dst, PtrStep<T> lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream);
}
namespace
@@ -170,6 +171,7 @@ namespace
GpuMat srcExt_;
GpuMat lut_;
GpuMat hist_; // histogram on global memory for CV_16UC1 case
};
CLAHE_Impl::CLAHE_Impl(double clipLimit, int tilesX, int tilesY) :
@@ -186,14 +188,16 @@ namespace
{
GpuMat src = _src.getGpuMat();
CV_Assert( src.type() == CV_8UC1 );
const int type = src.type();
_dst.create( src.size(), src.type() );
CV_Assert( type == CV_8UC1 || type == CV_16UC1 );
_dst.create( src.size(), type );
GpuMat dst = _dst.getGpuMat();
const int histSize = 256;
const int histSize = type == CV_8UC1 ? 256 : 65536;
ensureSizeIsEnough(tilesX_ * tilesY_, histSize, CV_8UC1, lut_);
ensureSizeIsEnough(tilesX_ * tilesY_, histSize, type, lut_);
cudaStream_t stream = StreamAccessor::getStream(s);
@@ -227,9 +231,18 @@ namespace
clipLimit = std::max(clipLimit, 1);
}
clahe::calcLut(srcForLut, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), clipLimit, lutScale, stream);
if (type == CV_8UC1)
clahe::calcLut_8U(srcForLut, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), clipLimit, lutScale, stream);
else // type == CV_16UC1
{
ensureSizeIsEnough(tilesX_ * tilesY_, histSize, CV_32SC1, hist_);
clahe::calcLut_16U(srcForLut, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), clipLimit, lutScale, hist_, stream);
}
clahe::transform(src, dst, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), stream);
if (type == CV_8UC1)
clahe::transform<uchar>(src, dst, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), stream);
else // type == CV_16UC1
clahe::transform<ushort>(src, dst, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), stream);
}
void CLAHE_Impl::setClipLimit(double clipLimit)

@@ -236,17 +236,19 @@ namespace
IMPLEMENT_PARAM_CLASS(ClipLimit, double)
}
PARAM_TEST_CASE(CLAHE, cv::cuda::DeviceInfo, cv::Size, ClipLimit)
PARAM_TEST_CASE(CLAHE, cv::cuda::DeviceInfo, cv::Size, ClipLimit, MatType)
{
cv::cuda::DeviceInfo devInfo;
cv::Size size;
double clipLimit;
int type;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
size = GET_PARAM(1);
clipLimit = GET_PARAM(2);
type = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
@@ -254,7 +256,11 @@ PARAM_TEST_CASE(CLAHE, cv::cuda::DeviceInfo, cv::Size, ClipLimit)
CUDA_TEST_P(CLAHE, Accuracy)
{
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src;
if (type == CV_8UC1)
src = randomMat(size, type);
else if (type == CV_16UC1)
src = randomMat(size, type, 0, 65535);
cv::Ptr<cv::cuda::CLAHE> clahe = cv::cuda::createCLAHE(clipLimit);
cv::cuda::GpuMat dst;
@@ -270,7 +276,8 @@ CUDA_TEST_P(CLAHE, Accuracy)
INSTANTIATE_TEST_CASE_P(CUDA_ImgProc, CLAHE, testing::Combine(
ALL_DEVICES,
DIFFERENT_SIZES,
testing::Values(0.0, 40.0)));
testing::Values(0.0, 5.0, 10.0, 20.0, 40.0),
testing::Values(MatType(CV_8UC1), MatType(CV_16UC1))));
}} // namespace

@@ -141,18 +141,20 @@ PERF_TEST_P(Dim_Cmpmethod, compareHist,
SANITY_CHECK_NOTHING();
}
typedef tuple<Size, double> Sz_ClipLimit_t;
typedef tuple<Size, double, MatType> Sz_ClipLimit_t;
typedef TestBaseWithParam<Sz_ClipLimit_t> Sz_ClipLimit;
PERF_TEST_P(Sz_ClipLimit, CLAHE,
testing::Combine(testing::Values(::perf::szVGA, ::perf::sz720p, ::perf::sz1080p),
testing::Values(0.0, 40.0))
testing::Values(0.0, 40.0),
testing::Values(MatType(CV_8UC1), MatType(CV_16UC1)))
)
{
const Size size = get<0>(GetParam());
const double clipLimit = get<1>(GetParam());
const int type = get<2>(GetParam());
Mat src(size, CV_8UC1);
Mat src(size, type);
declare.in(src, WARMUP_RNG);
Ptr<CLAHE> clahe = createCLAHE(clipLimit);
