From d62e486b697186a063f74e507a0664c0b26c5f20 Mon Sep 17 00:00:00 2001 From: Vadim Pisarevsky Date: Wed, 28 Mar 2012 14:32:23 +0000 Subject: [PATCH] avoid empty clusters in k-means in a more elegant way (relates to ticket #7698) --- modules/core/src/matrix.cpp | 58 ++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/modules/core/src/matrix.cpp b/modules/core/src/matrix.cpp index ff9714ada4..6324df42ea 100644 --- a/modules/core/src/matrix.cpp +++ b/modules/core/src/matrix.cpp @@ -2434,6 +2434,7 @@ double cv::kmeans( InputArray _data, int K, attempts = std::max(attempts, 1); CV_Assert( data.dims <= 2 && type == CV_32F && K > 0 ); + CV_Assert( N >= K ); _bestLabels.create(N, 1, CV_32S, -1, true); @@ -2557,18 +2558,61 @@ double cv::kmeans( InputArray _data, int K, if( iter > 0 ) max_center_shift = 0; - + for( k = 0; k < K; k++ ) { - float* center = centers.ptr(k); if( counters[k] != 0 ) + continue; + + // if some cluster appeared to be empty then: + // 1. find the biggest cluster + // 2. find the farthest from the center point in the biggest cluster + // 3. exclude the farthest point from the biggest cluster and form a new 1-point cluster. + int max_k = 0; + for( int k1 = 1; k1 < K; k++ ) { - float scale = 1.f/counters[k]; - for( j = 0; j < dims; j++ ) - center[j] *= scale; + if( counters[max_k] < counters[k1] ) + max_k = k1; + } + + double max_dist = 0; + int farthest_i = -1; + float* new_center = centers.ptr(k); + float* old_center = centers.ptr(max_k); + + for( i = 0; i < N; i++ ) + { + if( labels[i] != max_k ) + continue; + sample = data.ptr(i); + double dist = normL2Sqr_(sample, old_center, dims); + + if( max_dist <= dist ) + { + max_dist = dist; + farthest_i = i; + } } - else - generateRandomCenter(_box, center, rng); + + counters[max_k]--; + counters[k]++; + sample = data.ptr(farthest_i); + + for( j = 0; j < dims; j++ ) + { + old_center[j] -= sample[j]; + new_center[j] += sample[j]; + } + } + + for( k = 0; k < K; k++ ) + { + float* center = centers.ptr(k); + CV_Assert( counters[k] != 0 ); + + float scale = 1.f/counters[k]; + for( j = 0; j < dims; j++ ) + center[j] *= scale; if( iter > 0 ) {