added stream parameter to all cudaimgproc routines

pull/3566/head
Vladislav Vinogradov 10 years ago
parent 220d937d9a
commit f50a061225
  1. 36
      modules/cudaimgproc/include/opencv2/cudaimgproc.hpp
  2. 50
      modules/cudaimgproc/src/canny.cpp
  3. 50
      modules/cudaimgproc/src/cuda/canny.cu
  4. 7
      modules/cudaimgproc/src/gftt.cpp
  5. 6
      modules/cudaimgproc/src/histogram.cpp
  6. 7
      modules/cudaimgproc/src/hough_circles.cpp
  7. 21
      modules/cudaimgproc/src/hough_lines.cpp
  8. 7
      modules/cudaimgproc/src/hough_segments.cpp
  9. 9
      modules/cudaimgproc/src/mssegmentation.cpp

@ -240,8 +240,9 @@ CV_EXPORTS Ptr<cuda::CLAHE> createCLAHE(double clipLimit = 40.0, Size tileGridSi
@param nLevels Number of computed levels. nLevels must be at least 2.
@param lowerLevel Lower boundary value of the lowest level.
@param upperLevel Upper boundary value of the greatest level.
@param stream Stream for the asynchronous version.
*/
CV_EXPORTS void evenLevels(OutputArray levels, int nLevels, int lowerLevel, int upperLevel);
CV_EXPORTS void evenLevels(OutputArray levels, int nLevels, int lowerLevel, int upperLevel, Stream& stream = Stream::Null());
/** @brief Calculates a histogram with evenly distributed bins.
@ -281,15 +282,17 @@ public:
/** @brief Finds edges in an image using the @cite Canny86 algorithm.
@param image Single-channel 8-bit input image.
@param edges Output edge map. It has the same size and type as image .
@param edges Output edge map. It has the same size and type as image.
@param stream Stream for the asynchronous version.
*/
virtual void detect(InputArray image, OutputArray edges) = 0;
virtual void detect(InputArray image, OutputArray edges, Stream& stream = Stream::Null()) = 0;
/** @overload
@param dx First derivative of image in the vertical direction. Supports only the CV_32S type.
@param dy First derivative of image in the horizontal direction. Supports only the CV_32S type.
@param edges Output edge map. It has the same size and type as image .
@param edges Output edge map. It has the same size and type as image.
@param stream Stream for the asynchronous version.
*/
virtual void detect(InputArray dx, InputArray dy, OutputArray edges) = 0;
virtual void detect(InputArray dx, InputArray dy, OutputArray edges, Stream& stream = Stream::Null()) = 0;
virtual void setLowThreshold(double low_thresh) = 0;
virtual double getLowThreshold() const = 0;
@ -336,18 +339,20 @@ public:
\f$(\rho, \theta)\f$ . \f$\rho\f$ is the distance from the coordinate origin \f$(0,0)\f$ (top-left corner of
the image). \f$\theta\f$ is the line rotation angle in radians (
\f$0 \sim \textrm{vertical line}, \pi/2 \sim \textrm{horizontal line}\f$ ).
@param stream Stream for the asynchronous version.
@sa HoughLines
*/
virtual void detect(InputArray src, OutputArray lines) = 0;
virtual void detect(InputArray src, OutputArray lines, Stream& stream = Stream::Null()) = 0;
/** @brief Downloads results from cuda::HoughLinesDetector::detect to host memory.
@param d_lines Result of cuda::HoughLinesDetector::detect .
@param h_lines Output host array.
@param h_votes Optional output array for line's votes.
@param stream Stream for the asynchronous version.
*/
virtual void downloadResults(InputArray d_lines, OutputArray h_lines, OutputArray h_votes = noArray()) = 0;
virtual void downloadResults(InputArray d_lines, OutputArray h_lines, OutputArray h_votes = noArray(), Stream& stream = Stream::Null()) = 0;
virtual void setRho(float rho) = 0;
virtual float getRho() const = 0;
@ -391,10 +396,11 @@ public:
@param lines Output vector of lines. Each line is represented by a 4-element vector
\f$(x_1, y_1, x_2, y_2)\f$ , where \f$(x_1,y_1)\f$ and \f$(x_2, y_2)\f$ are the end points of each detected
line segment.
@param stream Stream for the asynchronous version.
@sa HoughLinesP
*/
virtual void detect(InputArray src, OutputArray lines) = 0;
virtual void detect(InputArray src, OutputArray lines, Stream& stream = Stream::Null()) = 0;
virtual void setRho(float rho) = 0;
virtual float getRho() const = 0;
@ -435,10 +441,11 @@ public:
@param src 8-bit, single-channel grayscale input image.
@param circles Output vector of found circles. Each circle is encoded as a 3-element
floating-point vector \f$(x, y, radius)\f$ .
@param stream Stream for the asynchronous version.
@sa HoughCircles
*/
virtual void detect(InputArray src, OutputArray circles) = 0;
virtual void detect(InputArray src, OutputArray circles, Stream& stream = Stream::Null()) = 0;
virtual void setDp(float dp) = 0;
virtual float getDp() const = 0;
@ -553,8 +560,9 @@ public:
positions).
@param mask Optional region of interest. If the mask is not empty (it needs to have the type
CV_8UC1 and the same size as image ), it specifies the region in which the corners are detected.
@param stream Stream for the asynchronous version.
*/
virtual void detect(InputArray image, OutputArray corners, InputArray mask = noArray()) = 0;
virtual void detect(InputArray image, OutputArray corners, InputArray mask = noArray(), Stream& stream = Stream::Null()) = 0;
};
/** @brief Creates implementation for cuda::CornersDetector .
@ -590,7 +598,7 @@ as src .
@param sp Spatial window radius.
@param sr Color window radius.
@param criteria Termination criteria. See TermCriteria.
@param stream
@param stream Stream for the asynchronous version.
It maps each point of the source image into another point. As a result, you have a new color and new
position of each point.
@ -610,7 +618,7 @@ src size. The type is CV_16SC2 .
@param sp Spatial window radius.
@param sr Color window radius.
@param criteria Termination criteria. See TermCriteria.
@param stream
@param stream Stream for the asynchronous version.
@sa cuda::meanShiftFiltering
*/
@ -626,9 +634,11 @@ CV_EXPORTS void meanShiftProc(InputArray src, OutputArray dstr, OutputArray dsts
@param sr Color window radius.
@param minsize Minimum segment size. Smaller segments are merged.
@param criteria Termination criteria. See TermCriteria.
@param stream Stream for the asynchronous version.
*/
CV_EXPORTS void meanShiftSegmentation(InputArray src, OutputArray dst, int sp, int sr, int minsize,
TermCriteria criteria = TermCriteria(TermCriteria::MAX_ITER + TermCriteria::EPS, 5, 1));
TermCriteria criteria = TermCriteria(TermCriteria::MAX_ITER + TermCriteria::EPS, 5, 1),
Stream& stream = Stream::Null());
/////////////////////////// Match Template ////////////////////////////

@ -53,16 +53,16 @@ Ptr<CannyEdgeDetector> cv::cuda::createCannyEdgeDetector(double, double, int, bo
namespace canny
{
void calcMagnitude(PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad);
void calcMagnitude(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad);
void calcMagnitude(PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad, cudaStream_t stream);
void calcMagnitude(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad, cudaStream_t stream);
void calcMap(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, PtrStepSzi map, float low_thresh, float high_thresh);
void calcMap(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, PtrStepSzi map, float low_thresh, float high_thresh, cudaStream_t stream);
void edgesHysteresisLocal(PtrStepSzi map, short2* st1);
void edgesHysteresisLocal(PtrStepSzi map, short2* st1, cudaStream_t stream);
void edgesHysteresisGlobal(PtrStepSzi map, short2* st1, short2* st2);
void edgesHysteresisGlobal(PtrStepSzi map, short2* st1, short2* st2, cudaStream_t stream);
void getEdges(PtrStepSzi map, PtrStepSzb dst);
void getEdges(PtrStepSzi map, PtrStepSzb dst, cudaStream_t stream);
}
namespace
@ -76,8 +76,8 @@ namespace
old_apperture_size_ = -1;
}
void detect(InputArray image, OutputArray edges);
void detect(InputArray dx, InputArray dy, OutputArray edges);
void detect(InputArray image, OutputArray edges, Stream& stream);
void detect(InputArray dx, InputArray dy, OutputArray edges, Stream& stream);
void setLowThreshold(double low_thresh) { low_thresh_ = low_thresh; }
double getLowThreshold() const { return low_thresh_; }
@ -111,7 +111,7 @@ namespace
private:
void createBuf(Size image_size);
void CannyCaller(GpuMat& edges);
void CannyCaller(GpuMat& edges, Stream& stream);
double low_thresh_;
double high_thresh_;
@ -128,7 +128,7 @@ namespace
int old_apperture_size_;
};
void CannyImpl::detect(InputArray _image, OutputArray _edges)
void CannyImpl::detect(InputArray _image, OutputArray _edges, Stream& stream)
{
GpuMat image = _image.getGpuMat();
@ -150,24 +150,24 @@ namespace
image.locateROI(wholeSize, ofs);
GpuMat srcWhole(wholeSize, image.type(), image.datastart, image.step);
canny::calcMagnitude(srcWhole, ofs.x, ofs.y, dx_, dy_, mag_, L2gradient_);
canny::calcMagnitude(srcWhole, ofs.x, ofs.y, dx_, dy_, mag_, L2gradient_, StreamAccessor::getStream(stream));
}
else
{
#ifndef HAVE_OPENCV_CUDAFILTERS
throw_no_cuda();
#else
filterDX_->apply(image, dx_);
filterDY_->apply(image, dy_);
filterDX_->apply(image, dx_, stream);
filterDY_->apply(image, dy_, stream);
canny::calcMagnitude(dx_, dy_, mag_, L2gradient_);
canny::calcMagnitude(dx_, dy_, mag_, L2gradient_, StreamAccessor::getStream(stream));
#endif
}
CannyCaller(edges);
CannyCaller(edges, stream);
}
void CannyImpl::detect(InputArray _dx, InputArray _dy, OutputArray _edges)
void CannyImpl::detect(InputArray _dx, InputArray _dy, OutputArray _edges, Stream& stream)
{
GpuMat dx = _dx.getGpuMat();
GpuMat dy = _dy.getGpuMat();
@ -176,8 +176,8 @@ namespace
CV_Assert( dy.type() == dx.type() && dy.size() == dx.size() );
CV_Assert( deviceSupports(SHARED_ATOMICS) );
dx.copyTo(dx_);
dy.copyTo(dy_);
dx.copyTo(dx_, stream);
dy.copyTo(dy_, stream);
if (low_thresh_ > high_thresh_)
std::swap(low_thresh_, high_thresh_);
@ -187,9 +187,9 @@ namespace
_edges.create(dx.size(), CV_8UC1);
GpuMat edges = _edges.getGpuMat();
canny::calcMagnitude(dx_, dy_, mag_, L2gradient_);
canny::calcMagnitude(dx_, dy_, mag_, L2gradient_, StreamAccessor::getStream(stream));
CannyCaller(edges);
CannyCaller(edges, stream);
}
void CannyImpl::createBuf(Size image_size)
@ -215,16 +215,16 @@ namespace
ensureSizeIsEnough(1, image_size.area(), CV_16SC2, st2_);
}
void CannyImpl::CannyCaller(GpuMat& edges)
void CannyImpl::CannyCaller(GpuMat& edges, Stream& stream)
{
map_.setTo(Scalar::all(0));
canny::calcMap(dx_, dy_, mag_, map_, static_cast<float>(low_thresh_), static_cast<float>(high_thresh_));
canny::calcMap(dx_, dy_, mag_, map_, static_cast<float>(low_thresh_), static_cast<float>(high_thresh_), StreamAccessor::getStream(stream));
canny::edgesHysteresisLocal(map_, st1_.ptr<short2>());
canny::edgesHysteresisLocal(map_, st1_.ptr<short2>(), StreamAccessor::getStream(stream));
canny::edgesHysteresisGlobal(map_, st1_.ptr<short2>(), st2_.ptr<short2>());
canny::edgesHysteresisGlobal(map_, st1_.ptr<short2>(), st2_.ptr<short2>(), StreamAccessor::getStream(stream));
canny::getEdges(map_, edges);
canny::getEdges(map_, edges, StreamAccessor::getStream(stream));
}
}

@ -120,7 +120,7 @@ namespace canny
mag(y, x) = norm(dxVal, dyVal);
}
void calcMagnitude(PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad)
void calcMagnitude(PtrStepSzb srcWhole, int xoff, int yoff, PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad, cudaStream_t stream)
{
const dim3 block(16, 16);
const dim3 grid(divUp(mag.cols, block.x), divUp(mag.rows, block.y));
@ -131,30 +131,31 @@ namespace canny
if (L2Grad)
{
L2 norm;
calcMagnitudeKernel<<<grid, block>>>(src, dx, dy, mag, norm);
calcMagnitudeKernel<<<grid, block, 0, stream>>>(src, dx, dy, mag, norm);
}
else
{
L1 norm;
calcMagnitudeKernel<<<grid, block>>>(src, dx, dy, mag, norm);
calcMagnitudeKernel<<<grid, block, 0, stream>>>(src, dx, dy, mag, norm);
}
cudaSafeCall( cudaGetLastError() );
cudaSafeCall(cudaThreadSynchronize());
if (stream == NULL)
cudaSafeCall( cudaDeviceSynchronize() );
}
void calcMagnitude(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad)
void calcMagnitude(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, bool L2Grad, cudaStream_t stream)
{
if (L2Grad)
{
L2 norm;
transform(dx, dy, mag, norm, WithOutMask(), 0);
transform(dx, dy, mag, norm, WithOutMask(), stream);
}
else
{
L1 norm;
transform(dx, dy, mag, norm, WithOutMask(), 0);
transform(dx, dy, mag, norm, WithOutMask(), stream);
}
}
}
@ -217,17 +218,18 @@ namespace canny
map(y, x) = edge_type;
}
void calcMap(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, PtrStepSzi map, float low_thresh, float high_thresh)
void calcMap(PtrStepSzi dx, PtrStepSzi dy, PtrStepSzf mag, PtrStepSzi map, float low_thresh, float high_thresh, cudaStream_t stream)
{
const dim3 block(16, 16);
const dim3 grid(divUp(dx.cols, block.x), divUp(dx.rows, block.y));
bindTexture(&tex_mag, mag);
calcMapKernel<<<grid, block>>>(dx, dy, map, low_thresh, high_thresh);
calcMapKernel<<<grid, block, 0, stream>>>(dx, dy, map, low_thresh, high_thresh);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
if (stream == NULL)
cudaSafeCall( cudaDeviceSynchronize() );
}
}
@ -328,20 +330,21 @@ namespace canny
}
}
void edgesHysteresisLocal(PtrStepSzi map, short2* st1)
void edgesHysteresisLocal(PtrStepSzi map, short2* st1, cudaStream_t stream)
{
void* counter_ptr;
cudaSafeCall( cudaGetSymbolAddress(&counter_ptr, counter) );
cudaSafeCall( cudaMemset(counter_ptr, 0, sizeof(int)) );
cudaSafeCall( cudaMemsetAsync(counter_ptr, 0, sizeof(int), stream) );
const dim3 block(16, 16);
const dim3 grid(divUp(map.cols, block.x), divUp(map.rows, block.y));
edgesHysteresisLocalKernel<<<grid, block>>>(map, st1);
edgesHysteresisLocalKernel<<<grid, block, 0, stream>>>(map, st1);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
if (stream == NULL)
cudaSafeCall( cudaDeviceSynchronize() );
}
}
@ -441,27 +444,30 @@ namespace canny
}
}
void edgesHysteresisGlobal(PtrStepSzi map, short2* st1, short2* st2)
void edgesHysteresisGlobal(PtrStepSzi map, short2* st1, short2* st2, cudaStream_t stream)
{
void* counter_ptr;
cudaSafeCall( cudaGetSymbolAddress(&counter_ptr, canny::counter) );
int count;
cudaSafeCall( cudaMemcpy(&count, counter_ptr, sizeof(int), cudaMemcpyDeviceToHost) );
cudaSafeCall( cudaMemcpyAsync(&count, counter_ptr, sizeof(int), cudaMemcpyDeviceToHost, stream) );
cudaSafeCall( cudaStreamSynchronize(stream) );
while (count > 0)
{
cudaSafeCall( cudaMemset(counter_ptr, 0, sizeof(int)) );
cudaSafeCall( cudaMemsetAsync(counter_ptr, 0, sizeof(int), stream) );
const dim3 block(128);
const dim3 grid(::min(count, 65535u), divUp(count, 65535), 1);
edgesHysteresisGlobalKernel<<<grid, block>>>(map, st1, st2, count);
edgesHysteresisGlobalKernel<<<grid, block, 0, stream>>>(map, st1, st2, count);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
if (stream == NULL)
cudaSafeCall( cudaDeviceSynchronize() );
cudaSafeCall( cudaMemcpy(&count, counter_ptr, sizeof(int), cudaMemcpyDeviceToHost) );
cudaSafeCall( cudaMemcpyAsync(&count, counter_ptr, sizeof(int), cudaMemcpyDeviceToHost, stream) );
cudaSafeCall( cudaStreamSynchronize(stream) );
count = min(count, map.cols * map.rows);
@ -499,9 +505,9 @@ namespace cv { namespace cuda { namespace device
namespace canny
{
void getEdges(PtrStepSzi map, PtrStepSzb dst)
void getEdges(PtrStepSzi map, PtrStepSzb dst, cudaStream_t stream)
{
transform(map, dst, GetEdges(), WithOutMask(), 0);
transform(map, dst, GetEdges(), WithOutMask(), stream);
}
}

@ -68,7 +68,7 @@ namespace
GoodFeaturesToTrackDetector(int srcType, int maxCorners, double qualityLevel, double minDistance,
int blockSize, bool useHarrisDetector, double harrisK);
void detect(InputArray image, OutputArray corners, InputArray mask = noArray());
void detect(InputArray image, OutputArray corners, InputArray mask, Stream& stream);
private:
int maxCorners_;
@ -96,8 +96,11 @@ namespace
cuda::createMinEigenValCorner(srcType, blockSize, 3);
}
void GoodFeaturesToTrackDetector::detect(InputArray _image, OutputArray _corners, InputArray _mask)
void GoodFeaturesToTrackDetector::detect(InputArray _image, OutputArray _corners, InputArray _mask, Stream& stream)
{
// TODO : implement async version
(void) stream;
using namespace cv::cuda::device::gfft;
GpuMat image = _image.getGpuMat();

@ -53,7 +53,7 @@ void cv::cuda::equalizeHist(InputArray, OutputArray, Stream&) { throw_no_cuda();
cv::Ptr<cv::cuda::CLAHE> cv::cuda::createCLAHE(double, cv::Size) { throw_no_cuda(); return cv::Ptr<cv::cuda::CLAHE>(); }
void cv::cuda::evenLevels(OutputArray, int, int, int) { throw_no_cuda(); }
void cv::cuda::evenLevels(OutputArray, int, int, int, Stream&) { throw_no_cuda(); }
void cv::cuda::histEven(InputArray, OutputArray, InputOutputArray, int, int, int, Stream&) { throw_no_cuda(); }
void cv::cuda::histEven(InputArray, GpuMat*, InputOutputArray, int*, int*, int*, Stream&) { throw_no_cuda(); }
@ -460,7 +460,7 @@ namespace
};
}
void cv::cuda::evenLevels(OutputArray _levels, int nLevels, int lowerLevel, int upperLevel)
void cv::cuda::evenLevels(OutputArray _levels, int nLevels, int lowerLevel, int upperLevel, Stream& stream)
{
const int kind = _levels.kind();
@ -475,7 +475,7 @@ void cv::cuda::evenLevels(OutputArray _levels, int nLevels, int lowerLevel, int
nppSafeCall( nppiEvenLevelsHost_32s(host_levels.ptr<Npp32s>(), nLevels, lowerLevel, upperLevel) );
if (kind == _InputArray::CUDA_GPU_MAT)
_levels.getGpuMatRef().upload(host_levels);
_levels.getGpuMatRef().upload(host_levels, stream);
}
namespace hist

@ -74,7 +74,7 @@ namespace
public:
HoughCirclesDetectorImpl(float dp, float minDist, int cannyThreshold, int votesThreshold, int minRadius, int maxRadius, int maxCircles);
void detect(InputArray src, OutputArray circles);
void detect(InputArray src, OutputArray circles, Stream& stream);
void setDp(float dp) { dp_ = dp; }
float getDp() const { return dp_; }
@ -154,8 +154,11 @@ namespace
filterDy_ = cuda::createSobelFilter(CV_8UC1, CV_32S, 0, 1);
}
void HoughCirclesDetectorImpl::detect(InputArray _src, OutputArray circles)
void HoughCirclesDetectorImpl::detect(InputArray _src, OutputArray circles, Stream& stream)
{
// TODO : implement async version
(void) stream;
using namespace cv::cuda::device::hough;
using namespace cv::cuda::device::hough_circles;

@ -75,8 +75,8 @@ namespace
{
}
void detect(InputArray src, OutputArray lines);
void downloadResults(InputArray d_lines, OutputArray h_lines, OutputArray h_votes = noArray());
void detect(InputArray src, OutputArray lines, Stream& stream);
void downloadResults(InputArray d_lines, OutputArray h_lines, OutputArray h_votes, Stream& stream);
void setRho(float rho) { rho_ = rho; }
float getRho() const { return rho_; }
@ -125,8 +125,11 @@ namespace
GpuMat result_;
};
void HoughLinesDetectorImpl::detect(InputArray _src, OutputArray lines)
void HoughLinesDetectorImpl::detect(InputArray _src, OutputArray lines, Stream& stream)
{
// TODO : implement async version
(void) stream;
using namespace cv::cuda::device::hough;
using namespace cv::cuda::device::hough_lines;
@ -170,7 +173,7 @@ namespace
result_.copyTo(lines);
}
void HoughLinesDetectorImpl::downloadResults(InputArray _d_lines, OutputArray h_lines, OutputArray h_votes)
void HoughLinesDetectorImpl::downloadResults(InputArray _d_lines, OutputArray h_lines, OutputArray h_votes, Stream& stream)
{
GpuMat d_lines = _d_lines.getGpuMat();
@ -184,12 +187,18 @@ namespace
CV_Assert( d_lines.rows == 2 && d_lines.type() == CV_32FC2 );
d_lines.row(0).download(h_lines);
if (stream)
d_lines.row(0).download(h_lines, stream);
else
d_lines.row(0).download(h_lines);
if (h_votes.needed())
{
GpuMat d_votes(1, d_lines.cols, CV_32SC1, d_lines.ptr<int>(1));
d_votes.download(h_votes);
if (stream)
d_votes.download(h_votes, stream);
else
d_votes.download(h_votes);
}
}
}

@ -79,7 +79,7 @@ namespace
{
}
void detect(InputArray src, OutputArray lines);
void detect(InputArray src, OutputArray lines, Stream& stream);
void setRho(float rho) { rho_ = rho; }
float getRho() const { return rho_; }
@ -128,8 +128,11 @@ namespace
GpuMat result_;
};
void HoughSegmentDetectorImpl::detect(InputArray _src, OutputArray lines)
void HoughSegmentDetectorImpl::detect(InputArray _src, OutputArray lines, Stream& stream)
{
// TODO : implement async version
(void) stream;
using namespace cv::cuda::device::hough;
using namespace cv::cuda::device::hough_lines;
using namespace cv::cuda::device::hough_segments;

@ -43,7 +43,7 @@
#if !defined HAVE_CUDA || defined(CUDA_DISABLER)
void cv::cuda::meanShiftSegmentation(InputArray, OutputArray, int, int, int, TermCriteria) { throw_no_cuda(); }
void cv::cuda::meanShiftSegmentation(InputArray, OutputArray, int, int, int, TermCriteria, Stream&) { throw_no_cuda(); }
#else
@ -222,7 +222,7 @@ inline int dist2(const cv::Vec2s& lhs, const cv::Vec2s& rhs)
} // anonymous namespace
void cv::cuda::meanShiftSegmentation(InputArray _src, OutputArray _dst, int sp, int sr, int minsize, TermCriteria criteria)
void cv::cuda::meanShiftSegmentation(InputArray _src, OutputArray _dst, int sp, int sr, int minsize, TermCriteria criteria, Stream& stream)
{
GpuMat src = _src.getGpuMat();
@ -235,7 +235,10 @@ void cv::cuda::meanShiftSegmentation(InputArray _src, OutputArray _dst, int sp,
// Perform mean shift procedure and obtain region and spatial maps
GpuMat d_rmap, d_spmap;
cuda::meanShiftProc(src, d_rmap, d_spmap, sp, sr, criteria);
cuda::meanShiftProc(src, d_rmap, d_spmap, sp, sr, criteria, stream);
stream.waitForCompletion();
Mat rmap(d_rmap);
Mat spmap(d_spmap);

Loading…
Cancel
Save