Add Connected Components Labeling in CUDA

pull/3153/head
Stefano Allegretti 3 years ago
parent bb3048d524
commit 80eb045bf6
  1. 10
      modules/cudaimgproc/doc/cudaimgproc.bib
  2. 44
      modules/cudaimgproc/include/opencv2/cudaimgproc.hpp
  3. 58
      modules/cudaimgproc/samples/connected_components.cpp
  4. 49
      modules/cudaimgproc/src/connectedcomponents.cpp
  5. 345
      modules/cudaimgproc/src/cuda/connectedcomponents.cu
  6. 479
      modules/cudaimgproc/test/test_connectedcomponents.cpp

@ -0,0 +1,10 @@
@article{Allegretti2019,
title={Optimized block-based algorithms to label connected components on GPUs},
author={Allegretti, Stefano and Bolelli, Federico and Grana, Costantino},
journal={IEEE Transactions on Parallel and Distributed Systems},
volume={31},
number={2},
pages={423--438},
year={2019},
publisher={IEEE}
}

@ -731,6 +731,50 @@ type.
CV_EXPORTS_W void blendLinear(InputArray img1, InputArray img2, InputArray weights1, InputArray weights2,
OutputArray result, Stream& stream = Stream::Null());
/////////////////// Connected Components Labeling /////////////////////
//! Connected Components Labeling algorithm selector accepted by cv::cuda::connectedComponents.
enum ConnectedComponentsAlgorithmsTypes {
CCL_DEFAULT = -1, //!< Default choice; currently handled exactly like #CCL_BKE (BKE @cite Allegretti2019, 8-way connectivity).
CCL_BKE = 0, //!< BKE @cite Allegretti2019 algorithm for 8-way connectivity.
};
/** @brief Computes the Connected Components Labeled image of a binary image.

The function takes as input a binary image and performs Connected Components Labeling. The output
is an image where each Connected Component is assigned a unique label (integer value).
Background pixels receive label 0; foreground components receive labels greater than 0.

ltype specifies the output label image type, an important consideration based on the total
number of labels or alternatively the total number of pixels in the source image.
ccltype specifies the connected components labeling algorithm to use, currently
BKE @cite Allegretti2019 is supported, see the #ConnectedComponentsAlgorithmsTypes
for details. Note that labels in the output are not required to be sequential.

The arguments are validated with assertions: a multi-channel image, a depth other than 8 bits,
or an unsupported connectivity, ltype or ccltype makes the call fail. The call returns only
after the GPU work has completed (the implementation synchronizes the device).

@param image The 8-bit single-channel image to be labeled.
@param labels Destination labeled image. It is (re)allocated with the size of image and type ltype.
@param connectivity Connectivity to use for the labeling procedure. 8 for 8-way connectivity is supported.
@param ltype Output image label type. Currently CV_32S is supported.
@param ccltype Connected components algorithm type (see the #ConnectedComponentsAlgorithmsTypes).
@note A sample program demonstrating Connected Components Labeling in CUDA can be found at\n
opencv_contrib_source_code/modules/cudaimgproc/samples/connected_components.cpp
*/
CV_EXPORTS_AS(connectedComponentsWithAlgorithm) void connectedComponents(InputArray image, OutputArray labels,
int connectivity, int ltype, cv::cuda::ConnectedComponentsAlgorithmsTypes ccltype);
/** @overload
Same as the five-argument version with ccltype set to #CCL_DEFAULT.
@param image The 8-bit single-channel image to be labeled.
@param labels Destination labeled image.
@param connectivity Connectivity to use for the labeling procedure. 8 for 8-way connectivity is supported.
@param ltype Output image label type. Currently CV_32S is supported.
*/
CV_EXPORTS_W void connectedComponents(InputArray image, OutputArray labels,
int connectivity = 8, int ltype = CV_32S);
//! @}
}} // namespace cv { namespace cuda {

@ -0,0 +1,58 @@
#include <iostream>
#include <opencv2/core/utility.hpp>
#include "opencv2/imgproc.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/cudaimgproc.hpp"
using namespace cv;
using namespace std;
using namespace cv::cuda;
// Converts a label image into a color image for display.
//
// Every label is mapped to a deterministic color by multiplying it with three
// different constants and reducing modulo 255; label 0 therefore maps to black.
//
// labels: input image of integer labels.
// colors: output image, (re)allocated to the size of labels.
void colorLabels(const Mat1i& labels, Mat3b& colors) {
    colors.create(labels.size());
    for (int r = 0; r < labels.rows; ++r) {
        int const* labels_row = labels.ptr<int>(r);
        Vec3b* colors_row = colors.ptr<Vec3b>(r);
        for (int c = 0; c < labels.cols; ++c) {
            // Labels can be as large as the number of pixels, so the products are
            // computed with unsigned arithmetic: a signed int product would overflow
            // (undefined behavior) on multi-megapixel images.
            const unsigned label = static_cast<unsigned>(labels_row[c]);
            colors_row[c] = Vec3b(static_cast<unsigned char>(label * 131u % 255u),
                                  static_cast<unsigned char>(label * 241u % 255u),
                                  static_cast<unsigned char>(label * 251u % 255u));
        }
    }
}
// Sample entry point: reads an image as grayscale, labels its connected
// components on the GPU and shows the labels as colors.
//
// Returns EXIT_FAILURE when the input image cannot be found or decoded,
// EXIT_SUCCESS otherwise.
int main(int argc, const char** argv)
{
    CommandLineParser parser(argc, argv, "{@image|stuff.jpg|grayscale input image whose connected components are labeled}");
    parser.about("This program finds connected components in a binary image and assigns each of them a different color.\n"
        "The connected components labeling is performed in GPU.\n");
    parser.printMessage();

    String inputImage = parser.get<string>(0);
    // required = false: a missing file yields an empty path (and thus an empty image)
    // instead of an exception, so the error message below is actually reachable.
    Mat1b img = imread(samples::findFile(inputImage, false), IMREAD_GRAYSCALE);
    if (img.empty())
    {
        cout << "Could not read input image file: " << inputImage << endl;
        return EXIT_FAILURE;
    }

    // Upload, label on the device, download the labels.
    GpuMat d_img, d_labels;
    d_img.upload(img);
    cuda::connectedComponents(d_img, d_labels, 8, CV_32S);

    Mat1i labels;
    d_labels.download(labels);

    Mat3b colors;
    colorLabels(labels, colors);
    imshow("Labels", colors);
    waitKey(0);
    return EXIT_SUCCESS;
}

@ -0,0 +1,49 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "precomp.hpp"
using namespace cv;
using namespace cv::cuda;
#if !defined (HAVE_CUDA) || defined (CUDA_DISABLER)
// Stubs compiled when CUDA support is not available. Both overloads declared in the
// public header get a definition here so that neither is left as an unresolved symbol;
// each one simply calls throw_no_cuda().
void cv::cuda::connectedComponents(InputArray, OutputArray, int, int, ConnectedComponentsAlgorithmsTypes) { throw_no_cuda(); }

void cv::cuda::connectedComponents(InputArray, OutputArray, int, int) { throw_no_cuda(); }
#else /* !defined (HAVE_CUDA) */
// Forward declaration of the device-side entry point implemented in
// src/cuda/connectedcomponents.cu. It labels img into labels; labels must
// already be allocated by the caller (see connectedComponents below).
namespace cv { namespace cuda { namespace device { namespace imgproc {
void BlockBasedKomuraEquivalence(const cv::cuda::GpuMat& img, cv::cuda::GpuMat& labels);
}}}}
// Labels the connected components of a single-channel 8-bit image on the GPU.
//
// img_         input image (GpuMat), one channel, depth CV_8U or CV_8S.
// labels_      output label image (GpuMat), (re)allocated to img_.size() with depth ltype.
// connectivity must be 8.
// ltype        must be CV_32S.
// ccltype      must be CCL_BKE or CCL_DEFAULT (both select the BKE implementation).
//
// Any other argument value fails a CV_Assert.
void cv::cuda::connectedComponents(InputArray img_, OutputArray labels_, int connectivity,
    int ltype, ConnectedComponentsAlgorithmsTypes ccltype) {
    const cv::cuda::GpuMat img = img_.getGpuMat();
    cv::cuda::GpuMat& labels = labels_.getGpuMatRef();

    CV_Assert(img.channels() == 1);
    CV_Assert(connectivity == 8);
    CV_Assert(ltype == CV_32S);
    CV_Assert(ccltype == CCL_BKE || ccltype == CCL_DEFAULT);
    const int iDepth = img.depth();
    CV_Assert(iDepth == CV_8U || iDepth == CV_8S);

    labels.create(img.size(), CV_MAT_DEPTH(ltype));

    // Nothing to label: skip the kernel launches, which would otherwise run on
    // (and do pointer arithmetic with) an unallocated labels buffer.
    if (img.empty()) {
        return;
    }

    // The assertions above leave BKE as the only possible algorithm.
    using cv::cuda::device::imgproc::BlockBasedKomuraEquivalence;
    BlockBasedKomuraEquivalence(img, labels);
}
// Convenience overload: forwards to the five-argument version with CCL_DEFAULT
// as the algorithm, so the same argument checks apply.
void cv::cuda::connectedComponents(InputArray img_, OutputArray labels_, int connectivity, int ltype) {
cv::cuda::connectedComponents(img_, labels_, connectivity, ltype, CCL_DEFAULT);
}
#endif /* !defined (HAVE_CUDA) */

@ -0,0 +1,345 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#if !defined CUDA_DISABLER
#include "opencv2/core/cuda/common.hpp"
#include "opencv2/core/cuda/emulation.hpp"
#include "opencv2/core/cuda/transform.hpp"
#include "opencv2/core/cuda/functional.hpp"
#include "opencv2/core/cuda/utility.hpp"
#include "opencv2/core/cuda.hpp"
using namespace cv::cuda;
using namespace cv::cuda::device;
namespace cv { namespace cuda { namespace device { namespace imgproc {
// Thread-block dimensions used by every kernel in this file, both for the launch
// configuration and for the index computation inside the kernels. Each thread
// processes one 2x2 block of pixels, so a thread block covers
// (2 * kblock_rows) x (2 * kblock_cols) pixels.
constexpr int kblock_rows = 16;
constexpr int kblock_cols = 16;
namespace {
// Bit positions inside the per-block "info" byte built by InitLabeling:
//   a, b, c, d : the pixel of the 2x2 block (a = top-left, b = top-right,
//                c = bottom-left, d = bottom-right) is foreground;
//   P, Q, R, S : the neighbor block (P = up-left, Q = up, R = up-right, S = left)
//                still has to be merged with this block in the Merge kernel.
enum class Info : unsigned char { a = 0, b = 1, c = 2, d = 3, P = 4, Q = 5, R = 6, S = 7 };
// Returns 1 when the flag named by pos is set in bitmap, 0 otherwise.
// Meant for unsigned integer bitmaps only.
template <typename T>
__device__ __forceinline__ unsigned char HasBit(T bitmap, Info pos) {
    const unsigned char shift = static_cast<unsigned char>(pos);
    return 1 & (bitmap >> shift);
}
// Returns 1 when bit number pos (0 = least significant) is set in bitmap, 0 otherwise.
template <typename T>
__device__ __forceinline__ unsigned char HasBit(T bitmap, unsigned char pos) {
    return 1 & (bitmap >> pos);
}
// Turns on, in place, the flag named by pos inside bitmap.
__device__ __forceinline__ void SetBit(unsigned char& bitmap, Info pos) {
    const unsigned char shift = static_cast<unsigned char>(pos);
    bitmap = bitmap | (1 << shift);
}
// Walks the parent links stored in s_buf starting from node n and returns the
// index of the root, i.e. the first node that is its own parent.
__device__ unsigned Find(const int* s_buf, unsigned n) {
    while (true) {
        if (s_buf[n] == n) {
            break;
        }
        n = s_buf[n];
    }
    return n;
}
// Same walk as Find, but after every step the starting node is re-linked to the
// node just reached, so that on return s_buf[n] (for the original n) is the root.
// Returns the root index.
__device__ unsigned FindAndCompress(int* s_buf, unsigned n) {
    const unsigned start = n;
    while (true) {
        if (s_buf[n] == n) {
            return n;
        }
        n = s_buf[n];
        s_buf[start] = n;
    }
}
// Merges the UFTrees of a and b, linking one root to the other.
// The root with the larger index is linked to the one with the smaller index.
// The link is written with atomicMin, so the function can be called by many
// threads at once on the same buffer: when the value previously stored in the
// root is not the root itself, another thread changed it in the meantime and the
// procedure is repeated starting from that value.
__device__ void Union(int* s_buf, unsigned a, unsigned b) {
bool done;
do {
// Climb to the current roots of both trees
a = Find(s_buf, a);
b = Find(s_buf, b);
if (a < b) {
// Try to hang root b under a; old is what b pointed to before the update
int old = atomicMin(s_buf + b, a);
// b was still a root (pointed to itself) => the link is in place
done = (old == b);
// Otherwise restart from the node b had been linked to in the meantime
b = old;
}
else if (b < a) {
// Symmetric case: hang root a under b
int old = atomicMin(s_buf + a, b);
done = (old == a);
a = old;
}
else {
// Same root: the two nodes already belong to the same tree
done = true;
}
} while (!done);
}
// Kernel: first pass of the labeling.
//
// Launch layout: 2D grid of kblock_cols x kblock_rows thread blocks; every thread
// owns the 2x2 pixel block whose top-left pixel is (row, col). Threads whose block
// starts outside the image do nothing.
//
// For its block, the thread
//   - reads the (up to) four pixels a, b, c, d; any non-zero value is foreground;
//   - looks for foreground pixels of the neighbor blocks P (up-left), Q (up),
//     R (up-right) and S (left) that touch the block;
//   - writes into labels, at the position of pixel a, the index of the first such
//     neighbor block found (in the order P, Q, R, S), or its own index if none;
//   - flags the remaining connected neighbors in the info byte, for the Merge kernel;
//   - stores the info byte in the label slot of pixel b if it exists, otherwise of
//     the pixel below a if it exists, otherwise in the byte pointed to by last_pixel.
//
// img:        input image, one byte per pixel.
// labels:     output buffer of 32-bit labels, same size as img.
// last_pixel: device byte used to store info for a block made of a single pixel
//             (bottom-right corner when both the row and the column count are odd).
//
// NOTE(review): the 16-bit loads below assume that img.data + img_index is 2-byte
// aligned (even base address and row pitch) - confirm for sub-matrix (ROI) inputs.
__global__ void InitLabeling(const cuda::PtrStepSzb img, cuda::PtrStepSzi labels, unsigned char* last_pixel) {
// Top-left pixel of the 2x2 block owned by this thread
unsigned row = (blockIdx.y * kblock_rows + threadIdx.y) * 2;
unsigned col = (blockIdx.x * kblock_cols + threadIdx.x) * 2;
// Linear position of that pixel in the image (bytes) and in the labels (elements)
unsigned img_index = row * img.step + col;
unsigned labels_index = row * (labels.step / labels.elem_size) + col;
if (row < labels.rows && col < labels.cols) {
// 16-bit mask over the 4x4 pixel neighborhood that surrounds the block, one bit per
// pixel in row-major order: bit (4 * r + c) is the pixel at (row - 1 + r, col - 1 + c).
// A bit is set when that pixel touches a foreground pixel of the block.
unsigned P = 0;
// Bitmask representing two kinds of information
// Bits 0, 1, 2, 3 are set if pixel a, b, c, d are foreground, respectively
// Bits 4, 5, 6, 7 are set if block P, Q, R, S need to be merged to X in Merge phase
unsigned char info = 0;
// buffer[0..3] receive pixels a, b, c, d; entries outside the image stay 0
char buffer alignas(int)[4];
*(reinterpret_cast<int*>(buffer)) = 0;
// Read pairs of consecutive values in memory at once
if (col + 1 < img.cols) {
// This does not depend on endianness
*(reinterpret_cast<int16_t*>(buffer)) = *(reinterpret_cast<int16_t*>(img.data + img_index));
if (row + 1 < img.rows) {
*(reinterpret_cast<int16_t*>(buffer + 2)) = *(reinterpret_cast<int16_t*>(img.data + img_index + img.step));
}
}
else {
// Last column of an image with an odd number of columns: only a and c exist
buffer[0] = img.data[img_index];
if (row + 1 < img.rows) {
buffer[2] = img.data[img_index + img.step];
}
}
// 0x777 is the 3x3 neighborhood centered on pixel a; shifting it by 1 / by 4
// centers it on pixel b / on pixel c
if (buffer[0]) {
P |= 0x777;
SetBit(info, Info::a);
}
if (buffer[1]) {
P |= (0x777 << 1);
SetBit(info, Info::b);
}
if (buffer[2]) {
P |= (0x777 << 4);
SetBit(info, Info::c);
}
// Pixel d does not contribute to P: none of the mask bits tested below is adjacent to it
if (buffer[3]) {
SetBit(info, Info::d);
}
// Clear the mask columns / rows that fall outside the image
if (col == 0) {
P &= 0xEEEE;
}
if (col + 1 >= img.cols) {
P &= 0x3333;
}
else if (col + 2 >= img.cols) {
P &= 0x7777;
}
if (row == 0) {
P &= 0xFFF0;
}
if (row + 1 >= img.rows) {
P &= 0x00FF;
}
else if (row + 2 >= img.rows) {
P &= 0x0FFF;
}
// P is now ready to be used to find neighbor blocks
// P value avoids range errors
// Offset (in label elements) from this block to the block chosen as its father;
// 0 means that the block is, for now, the root of its own tree
int father_offset = 0;
// P square
if (HasBit(P, 0) && img.data[img_index - img.step - 1]) {
father_offset = -(2 * (labels.step / labels.elem_size) + 2);
}
// Q square
if ((HasBit(P, 1) && img.data[img_index - img.step]) || (HasBit(P, 2) && img.data[img_index + 1 - img.step])) {
if (!father_offset) {
father_offset = -(2 * (labels.step / labels.elem_size));
}
else {
// A father was already chosen: leave this connection to the Merge kernel
SetBit(info, Info::Q);
}
}
// R square
if (HasBit(P, 3) && img.data[img_index + 2 - img.step]) {
if (!father_offset) {
father_offset = -(2 * (labels.step / labels.elem_size) - 2);
}
else {
SetBit(info, Info::R);
}
}
// S square
if ((HasBit(P, 4) && img.data[img_index - 1]) || (HasBit(P, 8) && img.data[img_index + img.step - 1])) {
if (!father_offset) {
father_offset = -2;
}
else {
SetBit(info, Info::S);
}
}
// The label slot of pixel a holds the parent link of the block
labels.data[labels_index] = labels_index + father_offset;
// Pick the byte where info is kept until the next kernels: the slot of pixel b,
// else the slot of the pixel below a, else the buffer received as argument
if (col + 1 < labels.cols) {
last_pixel = reinterpret_cast<unsigned char*>(labels.data + labels_index + 1);
}
else if (row + 1 < labels.rows) {
last_pixel = reinterpret_cast<unsigned char*>(labels.data + labels_index + labels.step / labels.elem_size);
}
*last_pixel = info;
}
}
// Kernel: joins every 2x2 block with the neighbor blocks that InitLabeling flagged
// in the info byte (bits Q, R and S).
//
// Launch layout: same as InitLabeling, one thread per 2x2 block, thread blocks of
// kblock_cols x kblock_rows threads.
//
// labels:     buffer holding one parent link per block (slot of pixel a) and the
//             info byte of the block (slot of pixel b, or of the pixel below a).
// last_pixel: byte holding the info of a block that has neither of those slots.
__global__ void Merge(cuda::PtrStepSzi labels, unsigned char* last_pixel) {
    const unsigned row = (blockIdx.y * kblock_rows + threadIdx.y) * 2;
    const unsigned col = (blockIdx.x * kblock_cols + threadIdx.x) * 2;
    const size_t stride = labels.step / labels.elem_size;   // label elements per row
    const unsigned labels_index = row * stride + col;

    if (row >= labels.rows || col >= labels.cols) {
        return;
    }

    // Locate the info byte exactly where InitLabeling stored it
    const unsigned char* info_ptr = last_pixel;
    if (col + 1 < labels.cols) {
        info_ptr = reinterpret_cast<unsigned char*>(labels.data + labels_index + 1);
    }
    else if (row + 1 < labels.rows) {
        info_ptr = reinterpret_cast<unsigned char*>(labels.data + labels_index + stride);
    }
    const unsigned char info = *info_ptr;

    // Upper block
    if (HasBit(info, Info::Q)) {
        Union(labels.data, labels_index, labels_index - 2 * stride);
    }
    // Upper-right block
    if (HasBit(info, Info::R)) {
        Union(labels.data, labels_index, labels_index - 2 * stride + 2);
    }
    // Left block
    if (HasBit(info, Info::S)) {
        Union(labels.data, labels_index, labels_index - 2);
    }
}
// Kernel: makes the parent link of every 2x2 block point to the root of its tree.
//
// Launch layout: one thread per 2x2 block, thread blocks of
// kblock_cols x kblock_rows threads; threads outside the image do nothing.
__global__ void Compression(cuda::PtrStepSzi labels) {
    const unsigned row = (blockIdx.y * kblock_rows + threadIdx.y) * 2;
    const unsigned col = (blockIdx.x * kblock_cols + threadIdx.x) * 2;
    if (row >= labels.rows || col >= labels.cols) {
        return;
    }

    // The link of the block lives in the label slot of its top-left pixel
    const unsigned labels_index = row * (labels.step / labels.elem_size) + col;
    FindAndCompress(labels.data, labels_index);
}
// Kernel: last pass, turns the per-block root index into per-pixel labels.
//
// Launch layout: one thread per 2x2 block, thread blocks of
// kblock_cols x kblock_rows threads; threads outside the image do nothing.
//
// The label of a block is the value stored in the slot of its top-left pixel
// plus one, so that 0 stays free for the background. Every pixel of the block
// receives that label if its foreground bit is set in the info byte, 0 otherwise.
//
// img:    input image, read only for a block reduced to a single pixel, whose
//         info byte has no slot inside labels.
// labels: in: parent links and info bytes; out: final labels.
__global__ void FinalLabeling(const cuda::PtrStepSzb img, cuda::PtrStepSzi labels) {
    unsigned row = (blockIdx.y * kblock_rows + threadIdx.y) * 2;
    unsigned col = (blockIdx.x * kblock_cols + threadIdx.x) * 2;
    unsigned labels_index = row * (labels.step / labels.elem_size) + col;

    if (row < labels.rows && col < labels.cols) {
        int label;
        unsigned char info;
        unsigned long long buffer;

        if (col + 1 < labels.cols) {
            // Read the slots of pixels a and b at once: low word = link, high word = info
            buffer = *reinterpret_cast<unsigned long long*>(labels.data + labels_index);
            label = (buffer & (0xFFFFFFFF)) + 1;
            info = (buffer >> 32) & 0xFFFFFFFF;
        }
        else {
            label = labels[labels_index] + 1;
            if (row + 1 < labels.rows) {
                // Info was stored in the slot of the pixel below a
                info = labels[labels_index + labels.step / labels.elem_size];
            }
            else {
                // Single-pixel block: rebuild info from the input image. Only bit "a"
                // matters here; it must be set for ANY non-zero pixel value, not just
                // for values whose lowest bit is 1, to agree with InitLabeling.
                info = img[row * img.step + col] ? 1 : 0;
            }
        }

        if (col + 1 < labels.cols) {
            // Write pixels a and b (and c and d, if the second row exists) in pairs
            *reinterpret_cast<unsigned long long*>(labels.data + labels_index) =
                (static_cast<unsigned long long>(HasBit(info, Info::b) * label) << 32) | (HasBit(info, Info::a) * label);

            if (row + 1 < labels.rows) {
                *reinterpret_cast<unsigned long long*>(labels.data + labels_index + labels.step / labels.elem_size) =
                    (static_cast<unsigned long long>(HasBit(info, Info::d) * label) << 32) | (HasBit(info, Info::c) * label);
            }
        }
        else {
            // Last column of an image with an odd number of columns: only a and c exist
            labels[labels_index] = HasBit(info, Info::a) * label;

            if (row + 1 < labels.rows) {
                labels[labels_index + (labels.step / labels.elem_size)] = HasBit(info, Info::c) * label;
            }
        }
    }
}
}
// Host entry point of the BKE labeling.
//
// Runs, in order, InitLabeling, Compression, Merge, Compression and FinalLabeling
// on the default stream and waits for their completion before returning.
//
// img:    input image, one byte per pixel.
// labels: output buffer of 32-bit labels; must already be allocated with the
//         size of img.
//
// Every CUDA call is checked with cudaSafeCall.
void BlockBasedKomuraEquivalence(const cv::cuda::GpuMat& img, cv::cuda::GpuMat& labels) {
    // The kernels keep one info byte per 2x2 block inside an otherwise unused slot of
    // labels. A block made of a single pixel has no such slot and uses last_pixel:
    //  - single-row / single-column images with an odd number of pixels need a
    //    separate device byte; it is owned by a GpuMat so that it is released even
    //    when one of the checks below throws;
    //  - otherwise the slot of pixel (rows - 2, cols - 2) is used;
    //  - images for which that pixel does not exist never dereference the pointer.
    cv::cuda::GpuMat scratch;
    unsigned char* last_pixel = nullptr;

    if ((img.rows == 1 || img.cols == 1) && !((img.rows + img.cols) % 2)) {
        scratch.create(1, 1, CV_8UC1);
        last_pixel = scratch.data;
    }
    else if (labels.rows >= 2 && labels.cols >= 2) {
        last_pixel = labels.data + ((labels.rows - 2) * labels.step) + (labels.cols - 2) * labels.elemSize();
    }

    // One thread per 2x2 block of pixels, rounded up to whole thread blocks
    const dim3 grid_size((((img.cols + 1) / 2) - 1) / kblock_cols + 1, (((img.rows + 1) / 2) - 1) / kblock_rows + 1, 1);
    const dim3 block_size(kblock_cols, kblock_rows, 1);

    InitLabeling<<<grid_size, block_size>>>(img, labels, last_pixel);
    cudaSafeCall(cudaGetLastError());

    Compression<<<grid_size, block_size>>>(labels);
    cudaSafeCall(cudaGetLastError());

    Merge<<<grid_size, block_size>>>(labels, last_pixel);
    cudaSafeCall(cudaGetLastError());

    Compression<<<grid_size, block_size>>>(labels);
    cudaSafeCall(cudaGetLastError());

    FinalLabeling<<<grid_size, block_size>>>(img, labels);
    cudaSafeCall(cudaGetLastError());

    cudaSafeCall(cudaDeviceSynchronize());
}
}}}}
#endif /* CUDA_DISABLER */

@ -0,0 +1,479 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "test_precomp.hpp"
#ifdef HAVE_CUDA
namespace opencv_test {
namespace {
// Renumbers the labels so that they follow a row-major order: the first positive
// label met while scanning the image row by row becomes 1, the next new one 2,
// and so on. Values that are not positive are left untouched.
template <typename LabelT>
void normalize_labels_impl(Mat& labels) {
    std::map<LabelT, LabelT> remap;
    LabelT next_label = 0;

    for (int y = 0; y < labels.rows; ++y) {
        LabelT* px = labels.ptr<LabelT>(y);
        LabelT* const row_end = px + labels.cols;
        for (; px != row_end; ++px) {
            if (*px <= 0) {
                continue;
            }
            typename std::map<LabelT, LabelT>::iterator it = remap.find(*px);
            if (it == remap.end()) {
                it = remap.insert(std::make_pair(*px, ++next_label)).first;
            }
            *px = it->second;
        }
    }
}
// Renumbers the labels of a single-channel CV_16U, CV_16S or CV_32S image in
// row-major order, dispatching on the depth of the image. Any other type fails
// an assertion.
void normalize_labels(Mat& labels) {
    const int type = labels.type();
    const int depth = CV_MAT_DEPTH(type);
    const int chans = CV_MAT_CN(type);

    CV_Assert(chans == 1);
    CV_Assert(depth == CV_16U || depth == CV_16S || depth == CV_32S);

    if (depth == CV_16U) {
        normalize_labels_impl<ushort>(labels);
    }
    else if (depth == CV_16S) {
        normalize_labels_impl<short>(labels);
    }
    else if (depth == CV_32S) {
        normalize_labels_impl<int>(labels);
    }
    else {
        CV_Assert(0);
    }
}
////////////////////////////////////////////////////////
// ConnectedComponents
// Test fixture parameterized on (device, connectivity, label type, algorithm).
// SetUp copies the parameters into members and selects the device under test.
PARAM_TEST_CASE(ConnectedComponents, cv::cuda::DeviceInfo, int, int, cv::cuda::ConnectedComponentsAlgorithmsTypes)
{
cv::cuda::DeviceInfo devInfo;
// Connectivity passed to cv::cuda::connectedComponents
int connectivity;
// Requested type of the label image
int ltype;
// Labeling algorithm passed to cv::cuda::connectedComponents
cv::cuda::ConnectedComponentsAlgorithmsTypes algo;
virtual void SetUp()
{
devInfo = GET_PARAM(0);
connectivity = GET_PARAM(1);
ltype = GET_PARAM(2);
algo = GET_PARAM(3);
cv::cuda::setDevice(devInfo.deviceID());
}
};
CUDA_TEST_P(ConnectedComponents, Chessboard_Even)
{
    // 16x16 chessboard (even number of rows and cols): pixel (r, c) is foreground
    // when r + c is even, starting with a foreground pixel in the top-left corner.
    // With 8-way connectivity all the foreground pixels touch diagonally and form a
    // single component; otherwise every foreground pixel is a component of its own,
    // numbered in row-major order (the maximum number of labels for 4-way connectivity).
    const int rows = 16;
    const int cols = 16;

    cv::Mat1b input(rows, cols);
    cv::Mat1i correct_output_int(rows, cols);
    int next_label = 0;
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const bool foreground = ((r + c) % 2 == 0);
            input(r, c) = foreground ? 1 : 0;
            if (connectivity == 8) {
                correct_output_int(r, c) = foreground ? 1 : 0;
            }
            else {
                correct_output_int(r, c) = foreground ? ++next_label : 0;
            }
        }
    }

    cv::Mat correct_output;
    correct_output_int.convertTo(correct_output, CV_MAT_DEPTH(ltype));

    cv::cuda::GpuMat d_input;
    cv::cuda::GpuMat d_labels;
    d_input.upload(input);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_input, d_labels, connectivity, ltype, algo));

    cv::Mat labels;
    d_labels.download(labels);
    normalize_labels(labels);

    cv::Mat diff = labels != correct_output;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
CUDA_TEST_P(ConnectedComponents, Chessboard_Odd)
{
    // 15x15 chessboard (odd number of rows and cols): pixel (r, c) is foreground
    // when r + c is even, so all four corners are foreground.
    // With 8-way connectivity the foreground is a single component; otherwise every
    // foreground pixel is a component of its own, numbered in row-major order
    // (the maximum number of labels for 4-way connectivity).
    const int rows = 15;
    const int cols = 15;

    cv::Mat1b input(rows, cols);
    cv::Mat1i correct_output_int(rows, cols);
    int next_label = 0;
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const bool foreground = ((r + c) % 2 == 0);
            input(r, c) = foreground ? 1 : 0;
            if (connectivity == 8) {
                correct_output_int(r, c) = foreground ? 1 : 0;
            }
            else {
                correct_output_int(r, c) = foreground ? ++next_label : 0;
            }
        }
    }

    cv::Mat correct_output;
    correct_output_int.convertTo(correct_output, CV_MAT_DEPTH(ltype));

    cv::cuda::GpuMat d_input;
    cv::cuda::GpuMat d_labels;
    d_input.upload(input);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_input, d_labels, connectivity, ltype, algo));

    cv::Mat labels;
    d_labels.download(labels);
    normalize_labels(labels);

    cv::Mat diff = labels != correct_output;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
CUDA_TEST_P(ConnectedComponents, Maxlabels_8conn_Even)
{
    // 16x16 image with isolated foreground pixels at every even row and even column.
    // No two foreground pixels touch, not even diagonally, so each one is a separate
    // component: the expected labels are 1..64 in row-major order.
    const int rows = 16;
    const int cols = 16;

    cv::Mat1b input(rows, cols);
    cv::Mat1i correct_output_int(rows, cols);
    int next_label = 0;
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const bool foreground = (r % 2 == 0) && (c % 2 == 0);
            input(r, c) = foreground ? 1 : 0;
            correct_output_int(r, c) = foreground ? ++next_label : 0;
        }
    }

    cv::Mat correct_output;
    correct_output_int.convertTo(correct_output, CV_MAT_DEPTH(ltype));

    cv::cuda::GpuMat d_input;
    cv::cuda::GpuMat d_labels;
    d_input.upload(input);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_input, d_labels, connectivity, ltype, algo));

    cv::Mat labels;
    d_labels.download(labels);
    normalize_labels(labels);

    cv::Mat diff = labels != correct_output;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
CUDA_TEST_P(ConnectedComponents, Maxlabels_8conn_Odd)
{
    // 15x15 image with isolated foreground pixels at every even row and even column,
    // including the last row and the last column. No two foreground pixels touch,
    // not even diagonally, so each one is a separate component: the expected labels
    // are 1..64 in row-major order.
    const int rows = 15;
    const int cols = 15;

    cv::Mat1b input(rows, cols);
    cv::Mat1i correct_output_int(rows, cols);
    int next_label = 0;
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const bool foreground = (r % 2 == 0) && (c % 2 == 0);
            input(r, c) = foreground ? 1 : 0;
            correct_output_int(r, c) = foreground ? ++next_label : 0;
        }
    }

    cv::Mat correct_output;
    correct_output_int.convertTo(correct_output, CV_MAT_DEPTH(ltype));

    cv::cuda::GpuMat d_input;
    cv::cuda::GpuMat d_labels;
    d_input.upload(input);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_input, d_labels, connectivity, ltype, algo));

    cv::Mat labels;
    d_labels.download(labels);
    normalize_labels(labels);

    cv::Mat diff = labels != correct_output;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
CUDA_TEST_P(ConnectedComponents, Single_Row)
{
    // Image made of a single row of 15 pixels with foreground at the even columns:
    // eight isolated pixels, expected to be labeled 1..8 from left to right.
    const int cols = 15;

    cv::Mat1b input(1, cols);
    cv::Mat1i correct_output_int(1, cols);
    for (int c = 0; c < cols; ++c) {
        const bool foreground = (c % 2 == 0);
        input(0, c) = foreground ? 1 : 0;
        correct_output_int(0, c) = foreground ? c / 2 + 1 : 0;
    }

    cv::Mat correct_output;
    correct_output_int.convertTo(correct_output, CV_MAT_DEPTH(ltype));

    cv::cuda::GpuMat d_input;
    cv::cuda::GpuMat d_labels;
    d_input.upload(input);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_input, d_labels, connectivity, ltype, algo));

    cv::Mat labels;
    d_labels.download(labels);
    normalize_labels(labels);

    cv::Mat diff = labels != correct_output;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
CUDA_TEST_P(ConnectedComponents, Single_Column)
{
    // Image made of a single column of 15 pixels with foreground at the even rows:
    // eight isolated pixels, expected to be labeled 1..8 from top to bottom.
    const int rows = 15;

    cv::Mat1b input(rows, 1);
    cv::Mat1i correct_output_int(rows, 1);
    for (int r = 0; r < rows; ++r) {
        const bool foreground = (r % 2 == 0);
        input(r, 0) = foreground ? 1 : 0;
        correct_output_int(r, 0) = foreground ? r / 2 + 1 : 0;
    }

    cv::Mat correct_output;
    correct_output_int.convertTo(correct_output, CV_MAT_DEPTH(ltype));

    cv::cuda::GpuMat d_input;
    cv::cuda::GpuMat d_labels;
    d_input.upload(input);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_input, d_labels, connectivity, ltype, algo));

    cv::Mat labels;
    d_labels.download(labels);
    normalize_labels(labels);

    cv::Mat diff = labels != correct_output;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
// Labels a test image loaded from the test data directory and compares the
// result, after label normalization, with the expected label image stored
// next to it.
CUDA_TEST_P(ConnectedComponents, Concentric_Circles)
{
    string img_path = cvtest::TS::ptr()->get_data_path() + "connectedcomponents/concentric_circles.png";
    string exp_path = cvtest::TS::ptr()->get_data_path() + "connectedcomponents/ccomp_exp.png";

    // Missing test data must stop the test here: going on with an empty image would
    // only produce an unrelated failure further down.
    Mat img = imread(img_path, IMREAD_GRAYSCALE);
    ASSERT_FALSE(img.empty()) << "Can't load test image: " << img_path;

    Mat exp = imread(exp_path, IMREAD_GRAYSCALE);
    ASSERT_FALSE(exp.empty()) << "Can't load expected labels: " << exp_path;

    Mat labels;
    exp.convertTo(exp, ltype);

    GpuMat d_img;
    GpuMat d_labels;
    d_img.upload(img);

    EXPECT_NO_THROW(cv::cuda::connectedComponents(d_img, d_labels, connectivity, ltype, algo));

    d_labels.download(labels);

    normalize_labels(labels);

    Mat diff = labels != exp;
    EXPECT_EQ(cv::countNonZero(diff), 0);
}
// Runs every ConnectedComponents test on all the devices, with the only supported
// combination of connectivity (8) and label type (CV_32S), once for each value of
// the algorithm enumeration.
INSTANTIATE_TEST_CASE_P(CUDA_ImgProc, ConnectedComponents, testing::Combine(
ALL_DEVICES,
testing::Values(8),
testing::Values(CV_32S),
testing::Values(cv::cuda::CCL_DEFAULT, cv::cuda::CCL_BKE)
));
}
} // namespace
#endif // HAVE_CUDA
Loading…
Cancel
Save