CUDA median filtering using histograms

pull/6035/head
Oded Green 9 years ago
parent d2e169929c
commit 1a0282df21
  1. 12
      modules/cudafilters/include/opencv2/cudafilters.hpp
  2. 40
      modules/cudafilters/perf/perf_filters.cpp
  3. 348
      modules/cudafilters/src/cuda/median_filter.cu
  4. 72
      modules/cudafilters/src/filtering.cpp
  5. 61
      modules/cudafilters/test/test_filters.cpp
  6. 1
      modules/cudaimgproc/include/opencv2/cudaimgproc.hpp

@ -314,6 +314,18 @@ CV_EXPORTS Ptr<Filter> createColumnSumFilter(int srcType, int dstType, int ksize
//! @}
///////////////////////////// Median Filtering //////////////////////////////
/** @brief Performs median filtering for each point of the source image.
@param srcType Type of the source image. Only CV_8UC1 images are supported for now.
@param windowSize Size of the kernel used for the filtering. Uses a (windowSize x windowSize) filter.
@param partition Specifies the parallel granularity of the workload. This parameter should be used by GPU experts when optimizing performance.
Outputs an image that has been filtered using median-filtering formulation.
*/
CV_EXPORTS Ptr<Filter> createMedianFilter(int srcType, int windowSize, int partition=128);
}} // namespace cv { namespace cuda {
#endif /* __OPENCV_CUDAFILTERS_HPP__ */

@ -375,3 +375,43 @@ PERF_TEST_P(Sz_Type_Op, MorphologyEx, Combine(CUDA_TYPICAL_MAT_SIZES, Values(CV_
CPU_SANITY_CHECK(dst);
}
}
//////////////////////////////////////////////////////////////////////
// MedianFilter
//////////////////////////////////////////////////////////////////////
// Median
// Parameter pack for the median benchmark: (image size, window size).
DEF_PARAM_TEST(Sz_KernelSz, cv::Size, int);

// Benchmarks the CUDA median filter against cv::medianBlur on 8-bit
// single-channel images for odd window sizes 3..15.
PERF_TEST_P(Sz_KernelSz, Median, Combine(CUDA_TYPICAL_MAT_SIZES, Values(3, 5, 7, 9, 11, 13, 15)))
{
    declare.time(20.0);

    const cv::Size size = GET_PARAM(0);
    const int kernel = GET_PARAM(1);
    // The image type is fixed rather than being a test parameter.
    const int type = CV_8UC1;

    cv::Mat src(size, type);
    declare.in(src, WARMUP_RNG);

    if (!PERF_RUN_CUDA())
    {
        // CPU reference path.
        cv::Mat dst;

        TEST_CYCLE() cv::medianBlur(src,dst,kernel);

        SANITY_CHECK_NOTHING();
    }
    else
    {
        // CUDA path: the upload and filter construction happen outside the timed cycle.
        const cv::cuda::GpuMat d_src(src);
        cv::cuda::GpuMat dst;
        cv::Ptr<cv::cuda::Filter> median = cv::cuda::createMedianFilter(d_src.type(), kernel);

        TEST_CYCLE() median->apply(d_src, dst);

        SANITY_CHECK_NOTHING();
    }
}

@ -0,0 +1,348 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#if !defined CUDA_DISABLER
#include "precomp.hpp"
using namespace cv;
using namespace cv::cuda;
#include "opencv2/core/cuda/common.hpp"
#include "opencv2/core/cuda/vec_traits.hpp"
#include "opencv2/core/cuda/vec_math.hpp"
#include "opencv2/core/cuda/saturate_cast.hpp"
#include "opencv2/core/cuda/border_interpolate.hpp"
namespace cv { namespace cuda { namespace device
{
// // namespace imgproc
// {
// Updates an 8-bin histogram in place: H[b] += hist_colAdd[b] - hist_colSub[b].
// Only threads with threadIdx.x < 8 do any work, one bin per thread.
__device__ void histogramAddAndSub8(int* H, const int * hist_colAdd,const int * hist_colSub){
    const int bin = threadIdx.x;
    if (bin >= 8)
        return;
    H[bin] += hist_colAdd[bin] - hist_colSub[bin];
}
// Adds histCount consecutive 8-bin histograms, stored back to back in
// hist_col (8 ints apart), into H. Only threads with threadIdx.x < 8 do any
// work; each one accumulates its bin in a local and writes H once.
__device__ void histogramMultipleAdd8(int* H, const int * hist_col,int histCount){
    const int bin = threadIdx.x;
    if (bin >= 8)
        return;
    int sum = H[bin];
    for (int n = 0; n < histCount; ++n)
        sum += hist_col[n * 8 + bin];
    H[bin] = sum;
}
// Zeroes the 8 bins of H. Only threads with threadIdx.x < 8 write, one bin each.
__device__ void histogramClear8(int* H){
    const int bin = threadIdx.x;
    if (bin >= 8)
        return;
    H[bin] = 0;
}
// Adds the 8-bin histogram hist_col into H, bin by bin.
// Only threads with threadIdx.x < 8 do any work, one bin per thread.
__device__ void histogramAdd8(int* H, const int * hist_col){
    const int bin = threadIdx.x;
    if (bin >= 8)
        return;
    H[bin] += hist_col[bin];
}
// Subtracts the 8-bin histogram hist_col from H, bin by bin.
// Only threads with threadIdx.x < 8 do any work, one bin per thread.
__device__ void histogramSub8(int* H, const int * hist_col){
    const int bin = threadIdx.x;
    if (bin >= 8)
        return;
    H[bin] -= hist_col[bin];
}
// Adds the 32-bin histogram hist_col into H, bin by bin.
// Only threads with threadIdx.x < 32 do any work, one bin per thread.
__device__ void histogramAdd32(int* H, const int * hist_col){
    const int bin = threadIdx.x;
    if (bin >= 32)
        return;
    H[bin] += hist_col[bin];
}
// Updates a 32-bin histogram in place: H[b] += hist_colAdd[b] - hist_colSub[b].
// Only threads with threadIdx.x < 32 do any work, one bin per thread.
__device__ void histogramAddAndSub32(int* H, const int * hist_colAdd,const int * hist_colSub){
    const int bin = threadIdx.x;
    if (bin >= 32)
        return;
    H[bin] += hist_colAdd[bin] - hist_colSub[bin];
}
// Zeroes the 32 bins of H. Only threads with threadIdx.x < 32 write, one bin each.
__device__ void histogramClear32(int* H){
    const int bin = threadIdx.x;
    if (bin >= 32)
        return;
    H[bin] = 0;
}
// Resets the 8 entries of luc to 0. Only threads with threadIdx.x < 8 write,
// one entry each.
__device__ void lucClear8(int* luc){
    const int slot = threadIdx.x;
    if (slot >= 8)
        return;
    luc[slot] = 0;
}
// Locates, in the 8-bin histogram H, the bin that contains the element at
// zero-based rank medPos.
//
// H          - input histogram, 8 bins (not modified).
// Hscan      - scratch buffer of at least 8 ints; on return it holds the
//              inclusive prefix sum of H.
// medPos     - rank being searched for.
// retval     - out: index of the first bin whose inclusive prefix sum exceeds
//              medPos (0 when that is bin 0 or when no bin qualifies).
// countAtMed - out: number of elements in all bins before *retval
//              (0 when *retval is 0).
//
// Must be called by every thread of the block: it contains __syncthreads().
// retval/countAtMed are written by several threads, so they are expected to
// point at storage shared by the block; callers must place a barrier before
// reading them.
__device__ void histogramMedianPar8LookupOnly(int* H,int* Hscan, const int medPos,int* retval, int* countAtMed){
    const int tx = threadIdx.x;
    const bool active = tx < 8;

    // Default answer; overwritten below by the single thread that finds the bin.
    *retval=*countAtMed=0;

    if (active)
        Hscan[tx] = H[tx];
    __syncthreads();

    // Inclusive prefix sum over 8 bins (offsets 1, 2, 4). Each step reads the
    // neighbour into a register, synchronizes, and only then writes, so the
    // result does not depend on threads executing in lockstep.
    for (int offset = 1; offset < 8; offset <<= 1){
        int addend = 0;
        if (active && tx >= offset)
            addend = Hscan[tx - offset];
        __syncthreads();
        if (active)
            Hscan[tx] += addend;
        __syncthreads();
    }

    // The prefix sums are non-decreasing, so at most one thread sees
    // Hscan[tx] <= medPos < Hscan[tx+1].
    if (tx < 7){
        if (Hscan[tx+1] > medPos && Hscan[tx] < medPos){
            *retval = tx + 1;
            *countAtMed = Hscan[tx];
        }
        else if (Hscan[tx] == medPos){
            if (Hscan[tx+1] > medPos){
                *retval = tx + 1;
                *countAtMed = Hscan[tx];
            }
        }
    }
}
// Locates, in the 32-bin histogram H, the bin that contains the element at
// zero-based rank medPos.
//
// H          - input histogram, 32 bins (not modified).
// Hscan      - scratch buffer of at least 32 ints; on return it holds the
//              inclusive prefix sum of H.
// medPos     - rank being searched for.
// retval     - out: index of the first bin whose inclusive prefix sum exceeds
//              medPos (0 when that is bin 0 or when no bin qualifies).
// countAtMed - out: number of elements in all bins before *retval
//              (0 when *retval is 0).
//
// Must be called by every thread of the block: it contains __syncthreads().
// Needs at least 32 threads in the block to cover all bins. retval/countAtMed
// are written by several threads, so they are expected to point at storage
// shared by the block; callers must place a barrier before reading them.
__device__ void histogramMedianPar32LookupOnly(int* H,int* Hscan, const int medPos,int* retval, int* countAtMed){
    const int tx = threadIdx.x;
    const bool active = tx < 32;

    // Default answer; overwritten below by the single thread that finds the bin.
    *retval=*countAtMed=0;

    if (active)
        Hscan[tx] = H[tx];
    __syncthreads();

    // Inclusive prefix sum over 32 bins (offsets 1, 2, 4, 8, 16). Each step
    // reads the neighbour into a register, synchronizes, and only then writes,
    // so the result does not depend on threads executing in lockstep.
    for (int offset = 1; offset < 32; offset <<= 1){
        int addend = 0;
        if (active && tx >= offset)
            addend = Hscan[tx - offset];
        __syncthreads();
        if (active)
            Hscan[tx] += addend;
        __syncthreads();
    }

    // The prefix sums are non-decreasing, so at most one thread sees
    // Hscan[tx] <= medPos < Hscan[tx+1].
    if (tx < 31){
        if (Hscan[tx+1] > medPos && Hscan[tx] < medPos){
            *retval = tx + 1;
            *countAtMed = Hscan[tx];
        }
        else if (Hscan[tx] == medPos){
            if (Hscan[tx+1] > medPos){
                *retval = tx + 1;
                *countAtMed = Hscan[tx];
            }
        }
    }
}
// Histogram-based median filter for 8-bit single-channel images.
//
// src            - input image (bytes).
// dest           - output image; written only for columns j in [r, cols-r).
// histPar        - scratch: per block, cols histograms of 256 bins each
//                  (block b uses the cols*256 ints starting at cols*256*b).
// coarseHistGrid - scratch: per block, cols histograms of 8 bins each
//                  (block b uses the cols*8 ints starting at cols*8*b); the
//                  coarse bin of a pixel is value>>5.
// r              - window radius; the window is (2r+1) x (2r+1).
// medPos_        - zero-based rank of the median inside the window.
//
// Launch layout: 1D grid, each block filters a contiguous band of rows.
// The fine-histogram helpers use threadIdx.x < 32 to cover 32 bins, so the
// block needs at least 32 threads; medianFiltering_gpu launches exactly 32.
// Both scratch buffers are accumulated into, so they must be zero on entry
// (the host code clears them before the launch).
__global__ void cuMedianFilterMultiBlock(PtrStepSzb src, PtrStepSzb dest, PtrStepSzi histPar, PtrStepSzi coarseHistGrid,int r, int medPos_)
{
    __shared__ int HCoarse[8];        // coarse histogram of the current window
    __shared__ int HCoarseScan[32];   // scratch for the prefix sums (coarse and fine)
    __shared__ int HFine[8][32];      // fine histogram of the window, one per coarse bin
    __shared__ int luc[8];            // per coarse bin: first column not yet folded into HFine
    __shared__ int firstBin,countAtMed, retval;

    int rows = src.rows, cols=src.cols;
    int extraRowThread=rows%gridDim.x;
    int doExtraRow=blockIdx.x<extraRowThread;
    int startRow=0, stopRow=0;
    int rowsPerBlock= rows/gridDim.x+doExtraRow;

    // The following code partitions the work to the blocks. Some blocks will do one row more
    // than other blocks. This code is responsible for doing that balancing
    if(doExtraRow){
        startRow=rowsPerBlock*blockIdx.x;
        stopRow=::min(rows, startRow+rowsPerBlock);
    }
    else{
        startRow=(rowsPerBlock+1)*extraRowThread+(rowsPerBlock)*(blockIdx.x-extraRowThread);
        stopRow=::min(rows, startRow+rowsPerBlock);
    }

    int* hist= histPar.data+cols*256*blockIdx.x;
    int* histCoarse=coarseHistGrid.data +cols*8*blockIdx.x;

    if (blockIdx.x==(gridDim.x-1))
        stopRow=rows;
    __syncthreads();

    int initNeeded=0, initVal, initStartRow, initStopRow;
    if(blockIdx.x==0){
        initNeeded=1; initVal=r+2; initStartRow=1; initStopRow=r;
    }
    else if (startRow<(r+2)){
        initNeeded=1; initVal=r+2-startRow; initStartRow=1; initStopRow=r+startRow;
    }
    else{
        initNeeded=0; initVal=0; initStartRow=startRow-(r+1); initStopRow=r+startRow;
    }
    __syncthreads();

    // In the original algorithm an initialization phase was required as part of the window was outside the
    // image. In this parallel version, the initialization is required for all thread blocks for which part
    // of the median filter window is outside the image.
    // For all threads in the block the same code will be executed.
    if (initNeeded){
        for (int j=threadIdx.x; j<(cols); j+=blockDim.x){
            hist[j*256+src.ptr(0)[j]]=initVal;
            histCoarse[j*8+(src.ptr(0)[j]>>5)]=initVal;
        }
    }
    __syncthreads();

    // For all remaining rows in the median filter, add the values to the histogram
    for (int j=threadIdx.x; j<cols; j+=blockDim.x){
        for(int i=initStartRow; i<initStopRow; i++){
            int pos=::min(i,rows-1);
            hist[j*256+src.ptr(pos)[j]]++;
            histCoarse[j*8+(src.ptr(pos)[j]>>5)]++;
        }
    }
    __syncthreads();

    // Going through all the rows that the block is responsible for.
    int inc=blockDim.x*256;
    int incCoarse=blockDim.x*8;
    for(int i=startRow; i< stopRow; i++){
        // For every new row that is started the global histogram for the entire window is restarted.
        histogramClear8(HCoarse);
        lucClear8(luc);

        // Computing some necessary indices
        int possub=::max(0,i-r-1),posadd=::min(rows-1,i+r);
        int histPos=threadIdx.x*256;
        int histCoarsePos=threadIdx.x*8;

        // Going through all the elements of a specific row. For each histogram, a value is taken out and
        // one value is added.
        for (int j=threadIdx.x; j<cols; j+=blockDim.x){
            hist[histPos+ src.ptr(possub)[j] ]--;
            hist[histPos+ src.ptr(posadd)[j] ]++;
            histCoarse[histCoarsePos+ (src.ptr(possub)[j]>>5) ]--;
            histCoarse[histCoarsePos+ (src.ptr(posadd)[j]>>5) ]++;
            histPos+=inc;
            histCoarsePos+=incCoarse;
        }
        __syncthreads();

        // Coarse histogram of the first window of the row: columns 0..2r.
        histogramMultipleAdd8(HCoarse,histCoarse, 2*r+1);

        int cols_m_1=cols-1;
        for(int j=r;j<cols-r;j++){
            int possub=::max(j-r,0);
            int posadd=::min(j+1+r,cols_m_1);
            int medPos=medPos_;
            __syncthreads();

            // Coarse search: which group of 32 grey levels holds the median.
            histogramMedianPar8LookupOnly(HCoarse,HCoarseScan,medPos, &firstBin,&countAtMed);
            __syncthreads();

            // Bring the fine histogram of that group up to date. The column cursor is
            // kept in a register and stored back once: advancing the shared luc entry
            // from every thread at the same time would be a data race.
            int loopIndex=luc[firstBin];
            if ( loopIndex <= (j-r))
            {
                // The fine histogram is too stale to update incrementally: rebuild it.
                histogramClear32(HFine[firstBin]);
                for ( loopIndex = j-r; loopIndex < ::min(j+r+1,cols); loopIndex++ ){
                    histogramAdd32(HFine[firstBin], hist+(loopIndex*256+(firstBin<<5) ) );
                }
            }
            else{
                // Slide the fine histogram: add the entering column, drop the leaving one.
                for ( ; loopIndex < (j+r+1);loopIndex++ ) {
                    histogramAddAndSub32(HFine[firstBin],
                        hist+(::min(loopIndex,cols_m_1)*256+(firstBin<<5) ),
                        hist+(::max(loopIndex-2*r-1,0)*256+(firstBin<<5) ) );
                    __syncthreads();
                }
            }
            __syncthreads();

            // Every thread holds the same cursor value; one writer is enough.
            if (threadIdx.x==0)
                luc[firstBin]=loopIndex;

            int leftOver=medPos-countAtMed;
            // All threads must have read countAtMed before the fine lookup resets it.
            __syncthreads();

            // Fine search inside the selected coarse bin.
            if(leftOver>=0){
                histogramMedianPar32LookupOnly(HFine[firstBin],HCoarseScan,leftOver,&retval,&countAtMed);
            }
            else retval=0;
            __syncthreads();

            if (threadIdx.x==0){
                dest.ptr(i)[j]=(firstBin<<5) + retval;
            }

            // Slide the coarse histogram one column to the right.
            histogramAddAndSub8(HCoarse, histCoarse+(int)(posadd<<3),histCoarse+(int)(possub<<3));
            __syncthreads();
        }
        __syncthreads();
    }
}
// Host-side launcher for cuMedianFilterMultiBlock.
//
// src, dst      - 8-bit single-channel input / output images.
// devHist       - scratch buffer for the per-column 256-bin histograms.
// devCoarseHist - scratch buffer for the per-column 8-bin histograms.
// kernel        - window radius (the window is (2*kernel+1) squared).
// partitions    - number of thread blocks; each one filters a band of rows.
// stream        - CUDA stream to launch on; with the null stream the call
//                 blocks until the kernel has finished.
void medianFiltering_gpu(const PtrStepSzb src, PtrStepSzb dst, PtrStepSzi devHist, PtrStepSzi devCoarseHist,int kernel, int partitions,cudaStream_t stream){
    // Zero-based rank of the median among (2k+1)^2 values: ((2k+1)^2 - 1) / 2.
    const int medPos = 2*kernel*kernel + 2*kernel;

    const dim3 grid(partitions);
    // The kernel's fine-histogram helpers use one thread per bin for 32 bins.
    const dim3 block(32);

    cuMedianFilterMultiBlock<<<grid, block, 0, stream>>>(src, dst, devHist, devCoarseHist, kernel, medPos);
    // Launch-configuration errors are only reported through the last-error state.
    cudaSafeCall( cudaGetLastError() );

    if (!stream)
        cudaSafeCall( cudaDeviceSynchronize() );
}
}}}
#endif

@ -69,6 +69,8 @@ Ptr<Filter> cv::cuda::createBoxMinFilter(int, Size, Point, int, Scalar) { throw_
Ptr<Filter> cv::cuda::createRowSumFilter(int, int, int, int, int, Scalar) { throw_no_cuda(); return Ptr<Filter>(); }
Ptr<Filter> cv::cuda::createColumnSumFilter(int, int, int, int, int, Scalar) { throw_no_cuda(); return Ptr<Filter>(); }
// Stub used when the library is built without CUDA; parameters are left
// unnamed, like the stubs above, so they do not trigger unused-parameter warnings.
Ptr<Filter> cv::cuda::createMedianFilter(int, int, int) { throw_no_cuda(); return Ptr<Filter>(); }
#else
namespace
@ -995,4 +997,74 @@ Ptr<Filter> cv::cuda::createColumnSumFilter(int srcType, int dstType, int ksize,
return makePtr<NppColumnSumFilter>(srcType, dstType, ksize, anchor, borderMode, borderVal);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Median Filter
// Forward declaration of the launcher implemented in the CUDA source file
// (src/cuda/median_filter.cu). `kernel` is passed by MedianFilter::apply as
// half of the window size.
namespace cv { namespace cuda { namespace device
{
void medianFiltering_gpu(const PtrStepSzb src, PtrStepSzb dst, PtrStepSzi devHist,
PtrStepSzi devCoarseHist,int kernel, int partitions, cudaStream_t stream);
}}}
namespace
{
// Filter implementation that runs the histogram-based CUDA median filter.
class MedianFilter : public Filter
{
public:
    // srcType must be CV_8UC1; _windowSize is the side of the square window;
    // _partitions is the number of thread blocks requested for the kernel.
    MedianFilter(int srcType, int _windowSize, int _partitions=128);

    // Filters src into dst on the given stream.
    void apply(InputArray src, OutputArray dst, Stream& stream = Stream::Null());

private:
    int windowSize;   // side of the filter window, as passed to the constructor
    int partitions;   // requested number of thread blocks
};
// Stores the filter parameters and validates them: only CV_8UC1 sources,
// a window of at least 3, and at least one partition are accepted.
// Note: apply() uses windowSize/2 as the radius, so an even window size
// behaves like the next odd one.
MedianFilter::MedianFilter(int srcType, int _windowSize, int _partitions)
    : windowSize(_windowSize)
    , partitions(_partitions)
{
    CV_Assert( srcType == CV_8UC1 );
    CV_Assert(windowSize>=3);
    CV_Assert(_partitions>=1);
}
// Applies the median filter to _src and stores the result in _dst.
//
// _src    - input image; must be CV_8UC1 and larger than half the window
//           size in both dimensions.
// _dst    - output image, (re)created with the size and type of the input.
// _stream - stream on which the scratch buffers are cleared and the kernel
//           is launched.
void MedianFilter::apply(InputArray _src, OutputArray _dst, Stream& _stream)
{
    using namespace cv::cuda::device;

    GpuMat src = _src.getGpuMat();
    // The kernel reads the image as bytes, so the actual input has to match
    // the type the filter was created for.
    CV_Assert( src.type() == CV_8UC1 );

    _dst.create(src.rows, src.cols, src.type());
    GpuMat dst = _dst.getGpuMat();

    // Kernel needs to be half window size
    const int kernel = windowSize/2;
    CV_Assert(kernel < src.rows);
    CV_Assert(kernel < src.cols);

    // Reduce the block count for short images, but only for this call: the
    // member keeps the value requested by the user so that a later, larger
    // image still gets the full partitioning. The constructor guarantees
    // kernel >= 1, hence src.rows >= 2 here and src.rows/2 >= 1.
    int numPartitions = partitions;
    if (numPartitions > src.rows)
        numPartitions = src.rows/2;

    // Note - these are hardcoded in the actual GPU kernel. Do not change these values.
    const int histSize = 256, histCoarseSize = 8;

    BufferPool pool(_stream);
    GpuMat devHist = pool.getBuffer(1, src.cols*histSize*numPartitions, CV_32SC1);
    GpuMat devCoarseHist = pool.getBuffer(1, src.cols*histCoarseSize*numPartitions, CV_32SC1);

    // The kernel accumulates into the histograms, so they must start at zero.
    devHist.setTo(0, _stream);
    devCoarseHist.setTo(0, _stream);

    medianFiltering_gpu(src, dst, devHist, devCoarseHist, kernel, numPartitions, StreamAccessor::getStream(_stream));
}
}
// Factory for the CUDA median filter; the arguments are forwarded unchanged
// to the MedianFilter constructor, which validates them.
Ptr<Filter> cv::cuda::createMedianFilter(int srcType, int _windowSize, int _partitions)
{
    Ptr<Filter> filter = makePtr<MedianFilter>(srcType, _windowSize, _partitions);
    return filter;
}
#endif

@ -53,6 +53,7 @@ namespace
IMPLEMENT_PARAM_CLASS(Deriv_X, int)
IMPLEMENT_PARAM_CLASS(Deriv_Y, int)
IMPLEMENT_PARAM_CLASS(Iterations, int)
IMPLEMENT_PARAM_CLASS(KernelSize, int)
cv::Mat getInnerROI(cv::InputArray m_, cv::Size ksize)
{
@ -647,4 +648,64 @@ INSTANTIATE_TEST_CASE_P(CUDA_Filters, MorphEx, testing::Combine(
testing::Values(Iterations(1), Iterations(2), Iterations(3)),
WHOLE_SUBMAT));
/////////////////////////////////////////////////////////////////////////////////////////////////
// Median
// Fixture for the median filter tests, parameterized by
// (device, image size, depth, window size, ROI flag).
PARAM_TEST_CASE(Median, cv::cuda::DeviceInfo, cv::Size, MatDepth, KernelSize, UseRoi)
{
    cv::cuda::DeviceInfo devInfo;
    cv::Size size;
    int type;
    int kernel;
    bool useRoi;

    // Unpacks the test parameters and selects the device under test.
    virtual void SetUp()
    {
        devInfo = GET_PARAM(0);
        size    = GET_PARAM(1);
        type    = GET_PARAM(2);
        kernel  = GET_PARAM(3);
        useRoi  = GET_PARAM(4);

        cv::cuda::setDevice(devInfo.deviceID());
    }
};
// Compares the CUDA median filter with cv::medianBlur, allowing a difference
// of 1 per pixel. Columns near the left and right borders are left out of the
// comparison.
CUDA_TEST_P(Median, Accuracy)
{
    cv::Mat src = randomMat(size, type);

    cv::Ptr<cv::cuda::Filter> median = cv::cuda::createMedianFilter(src.type(), kernel);

    // GPU result.
    cv::cuda::GpuMat dst = createMat(size, type, useRoi);
    median->apply(loadMat(src, useRoi), dst);

    // CPU reference.
    cv::Mat dst_gold;
    cv::medianBlur(src,dst_gold,kernel);

    // Region compared: all rows, columns [kernel+1, cols-kernel).
    const int roiX = kernel+1;
    const int roiWidth = src.cols-(2*kernel+1);
    cv::Rect rect(roiX, 0, roiWidth, src.rows);

    cv::Mat dst_gold_no_border = dst_gold(rect);
    cv::cuda::GpuMat dst_no_border = cv::cuda::GpuMat(dst, rect);

    EXPECT_MAT_NEAR(dst_gold_no_border, dst_no_border, 1);
}
// Runs the median tests on every device, for 8-bit images and odd window
// sizes 3..15, with and without ROI.
INSTANTIATE_TEST_CASE_P(CUDA_Filters, Median, testing::Combine(
    ALL_DEVICES,
    DIFFERENT_SIZES,
    testing::Values(MatDepth(CV_8U)),
    testing::Values(KernelSize(3), KernelSize(5), KernelSize(7), KernelSize(9),
                    KernelSize(11), KernelSize(13), KernelSize(15)),
    WHOLE_SUBMAT)
);
#endif // HAVE_CUDA

@ -588,6 +588,7 @@ CV_EXPORTS Ptr<CornersDetector> createGoodFeaturesToTrackDetector(int srcType, i
//! @} cudaimgproc_feature
///////////////////////////// Mean Shift //////////////////////////////
/** @brief Performs mean-shift filtering for each point of the source image.

Loading…
Cancel
Save