added gpu::graphcut for float sources (CUDA 5.0)

pull/5/merge
Vladislav Vinogradov 12 years ago
parent 4f99f69a29
commit b43cec3301
  1. 72
      modules/gpu/src/graphcuts.cpp

@ -78,17 +78,25 @@ namespace
void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& bottom, GpuMat& labels, GpuMat& buf, Stream& s) void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& bottom, GpuMat& labels, GpuMat& buf, Stream& s)
{ {
#if (CUDA_VERSION < 5000)
CV_Assert(terminals.type() == CV_32S);
#else
CV_Assert(terminals.type() == CV_32S || terminals.type() == CV_32F);
#endif
Size src_size = terminals.size(); Size src_size = terminals.size();
CV_Assert(terminals.type() == CV_32S);
CV_Assert(leftTransp.size() == Size(src_size.height, src_size.width)); CV_Assert(leftTransp.size() == Size(src_size.height, src_size.width));
CV_Assert(leftTransp.type() == CV_32S); CV_Assert(leftTransp.type() == terminals.type());
CV_Assert(rightTransp.size() == Size(src_size.height, src_size.width)); CV_Assert(rightTransp.size() == Size(src_size.height, src_size.width));
CV_Assert(rightTransp.type() == CV_32S); CV_Assert(rightTransp.type() == terminals.type());
CV_Assert(top.size() == src_size); CV_Assert(top.size() == src_size);
CV_Assert(top.type() == CV_32S); CV_Assert(top.type() == terminals.type());
CV_Assert(bottom.size() == src_size); CV_Assert(bottom.size() == src_size);
CV_Assert(bottom.type() == CV_32S); CV_Assert(bottom.type() == terminals.type());
labels.create(src_size, CV_8U); labels.create(src_size, CV_8U);
@ -107,8 +115,21 @@ void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTrans
NppiGraphcutStateHandler state(sznpp, buf.ptr<Npp8u>(), nppiGraphcutInitAlloc); NppiGraphcutStateHandler state(sznpp, buf.ptr<Npp8u>(), nppiGraphcutInitAlloc);
#if (CUDA_VERSION < 5000)
nppSafeCall( nppiGraphcut_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), top.ptr<Npp32s>(), bottom.ptr<Npp32s>(), nppSafeCall( nppiGraphcut_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), top.ptr<Npp32s>(), bottom.ptr<Npp32s>(),
static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) ); static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );
#else
if (terminals.type() == CV_32S)
{
nppSafeCall( nppiGraphcut_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), top.ptr<Npp32s>(), bottom.ptr<Npp32s>(),
static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );
}
else
{
nppSafeCall( nppiGraphcut_32f8u(terminals.ptr<Npp32f>(), leftTransp.ptr<Npp32f>(), rightTransp.ptr<Npp32f>(), top.ptr<Npp32f>(), bottom.ptr<Npp32f>(),
static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );
}
#endif
if (stream == 0) if (stream == 0)
cudaSafeCall( cudaDeviceSynchronize() ); cudaSafeCall( cudaDeviceSynchronize() );
@ -117,33 +138,37 @@ void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTrans
void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& topLeft, GpuMat& topRight, void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& topLeft, GpuMat& topRight,
GpuMat& bottom, GpuMat& bottomLeft, GpuMat& bottomRight, GpuMat& labels, GpuMat& buf, Stream& s) GpuMat& bottom, GpuMat& bottomLeft, GpuMat& bottomRight, GpuMat& labels, GpuMat& buf, Stream& s)
{ {
Size src_size = terminals.size(); #if (CUDA_VERSION < 5000)
CV_Assert(terminals.type() == CV_32S); CV_Assert(terminals.type() == CV_32S);
#else
CV_Assert(terminals.type() == CV_32S || terminals.type() == CV_32F);
#endif
Size src_size = terminals.size();
CV_Assert(leftTransp.size() == Size(src_size.height, src_size.width)); CV_Assert(leftTransp.size() == Size(src_size.height, src_size.width));
CV_Assert(leftTransp.type() == CV_32S); CV_Assert(leftTransp.type() == terminals.type());
CV_Assert(rightTransp.size() == Size(src_size.height, src_size.width)); CV_Assert(rightTransp.size() == Size(src_size.height, src_size.width));
CV_Assert(rightTransp.type() == CV_32S); CV_Assert(rightTransp.type() == terminals.type());
CV_Assert(top.size() == src_size); CV_Assert(top.size() == src_size);
CV_Assert(top.type() == CV_32S); CV_Assert(top.type() == terminals.type());
CV_Assert(topLeft.size() == src_size); CV_Assert(topLeft.size() == src_size);
CV_Assert(topLeft.type() == CV_32S); CV_Assert(topLeft.type() == terminals.type());
CV_Assert(topRight.size() == src_size); CV_Assert(topRight.size() == src_size);
CV_Assert(topRight.type() == CV_32S); CV_Assert(topRight.type() == terminals.type());
CV_Assert(bottom.size() == src_size); CV_Assert(bottom.size() == src_size);
CV_Assert(bottom.type() == CV_32S); CV_Assert(bottom.type() == terminals.type());
CV_Assert(bottomLeft.size() == src_size); CV_Assert(bottomLeft.size() == src_size);
CV_Assert(bottomLeft.type() == CV_32S); CV_Assert(bottomLeft.type() == terminals.type());
CV_Assert(bottomRight.size() == src_size); CV_Assert(bottomRight.size() == src_size);
CV_Assert(bottomRight.type() == CV_32S); CV_Assert(bottomRight.type() == terminals.type());
labels.create(src_size, CV_8U); labels.create(src_size, CV_8U);
@ -162,10 +187,27 @@ void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTrans
NppiGraphcutStateHandler state(sznpp, buf.ptr<Npp8u>(), nppiGraphcut8InitAlloc); NppiGraphcutStateHandler state(sznpp, buf.ptr<Npp8u>(), nppiGraphcut8InitAlloc);
#if (CUDA_VERSION < 5000)
nppSafeCall( nppiGraphcut8_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), nppSafeCall( nppiGraphcut8_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(),
top.ptr<Npp32s>(), topLeft.ptr<Npp32s>(), topRight.ptr<Npp32s>(), top.ptr<Npp32s>(), topLeft.ptr<Npp32s>(), topRight.ptr<Npp32s>(),
bottom.ptr<Npp32s>(), bottomLeft.ptr<Npp32s>(), bottomRight.ptr<Npp32s>(), bottom.ptr<Npp32s>(), bottomLeft.ptr<Npp32s>(), bottomRight.ptr<Npp32s>(),
static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) ); static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );
#else
if (terminals.type() == CV_32S)
{
nppSafeCall( nppiGraphcut8_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(),
top.ptr<Npp32s>(), topLeft.ptr<Npp32s>(), topRight.ptr<Npp32s>(),
bottom.ptr<Npp32s>(), bottomLeft.ptr<Npp32s>(), bottomRight.ptr<Npp32s>(),
static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );
}
else
{
nppSafeCall( nppiGraphcut8_32f8u(terminals.ptr<Npp32f>(), leftTransp.ptr<Npp32f>(), rightTransp.ptr<Npp32f>(),
top.ptr<Npp32f>(), topLeft.ptr<Npp32f>(), topRight.ptr<Npp32f>(),
bottom.ptr<Npp32f>(), bottomLeft.ptr<Npp32f>(), bottomRight.ptr<Npp32f>(),
static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );
}
#endif
if (stream == 0) if (stream == 0)
cudaSafeCall( cudaDeviceSynchronize() ); cudaSafeCall( cudaDeviceSynchronize() );

Loading…
Cancel
Save