Merge pull request #3458 from thorikawa:kmeans-index-parallel

pull/3465/merge
Vadim Pisarevsky 10 years ago
commit 7106d3e8c2
  1. 108
      modules/flann/include/opencv2/flann/kmeans_index.h

@ -271,6 +271,71 @@ public:
return FLANN_INDEX_KMEANS;
}
class KMeansDistanceComputer : public cv::ParallelLoopBody
{
public:
KMeansDistanceComputer(Distance _distance, const Matrix<ElementType>& _dataset,
const int _branching, const int* _indices, const Matrix<double>& _dcenters, const size_t _veclen,
int* _count, int* _belongs_to, std::vector<DistanceType>& _radiuses, bool& _converged, cv::Mutex& _mtx)
: distance(_distance)
, dataset(_dataset)
, branching(_branching)
, indices(_indices)
, dcenters(_dcenters)
, veclen(_veclen)
, count(_count)
, belongs_to(_belongs_to)
, radiuses(_radiuses)
, converged(_converged)
, mtx(_mtx)
{
}
void operator()(const cv::Range& range) const
{
const int begin = range.start;
const int end = range.end;
for( int i = begin; i<end; ++i)
{
DistanceType sq_dist = distance(dataset[indices[i]], dcenters[0], veclen);
int new_centroid = 0;
for (int j=1; j<branching; ++j) {
DistanceType new_sq_dist = distance(dataset[indices[i]], dcenters[j], veclen);
if (sq_dist>new_sq_dist) {
new_centroid = j;
sq_dist = new_sq_dist;
}
}
if (sq_dist > radiuses[new_centroid]) {
radiuses[new_centroid] = sq_dist;
}
if (new_centroid != belongs_to[i]) {
count[belongs_to[i]]--;
count[new_centroid]++;
belongs_to[i] = new_centroid;
mtx.lock();
converged = false;
mtx.unlock();
}
}
}
private:
Distance distance;
const Matrix<ElementType>& dataset;
const int branching;
const int* indices;
const Matrix<double>& dcenters;
const size_t veclen;
int* count;
int* belongs_to;
std::vector<DistanceType>& radiuses;
bool& converged;
cv::Mutex& mtx;
KMeansDistanceComputer& operator=( const KMeansDistanceComputer & ) { return *this; }
};
/**
* Index constructor
*
@ -658,7 +723,8 @@ private:
return;
}
int* centers_idx = new int[branching];
cv::AutoBuffer<int> centers_idx_buf(branching);
int* centers_idx = (int*)centers_idx_buf;
int centers_length;
(this->*chooseCenters)(branching, indices, indices_length, centers_idx, centers_length);
@ -666,29 +732,30 @@ private:
node->indices = indices;
std::sort(node->indices,node->indices+indices_length);
node->childs = NULL;
delete [] centers_idx;
return;
}
Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
cv::AutoBuffer<double> dcenters_buf(branching*veclen_);
Matrix<double> dcenters((double*)dcenters_buf,branching,veclen_);
for (int i=0; i<centers_length; ++i) {
ElementType* vec = dataset_[centers_idx[i]];
for (size_t k=0; k<veclen_; ++k) {
dcenters[i][k] = double(vec[k]);
}
}
delete[] centers_idx;
std::vector<DistanceType> radiuses(branching);
int* count = new int[branching];
cv::AutoBuffer<int> count_buf(branching);
int* count = (int*)count_buf;
for (int i=0; i<branching; ++i) {
radiuses[i] = 0;
count[i] = 0;
}
// assign points to clusters
int* belongs_to = new int[indices_length];
cv::AutoBuffer<int> belongs_to_buf(indices_length);
int* belongs_to = (int*)belongs_to_buf;
for (int i=0; i<indices_length; ++i) {
DistanceType sq_dist = distance_(dataset_[indices[i]], dcenters[0], veclen_);
@ -732,27 +799,9 @@ private:
}
// reassign points to clusters
for (int i=0; i<indices_length; ++i) {
DistanceType sq_dist = distance_(dataset_[indices[i]], dcenters[0], veclen_);
int new_centroid = 0;
for (int j=1; j<branching; ++j) {
DistanceType new_sq_dist = distance_(dataset_[indices[i]], dcenters[j], veclen_);
if (sq_dist>new_sq_dist) {
new_centroid = j;
sq_dist = new_sq_dist;
}
}
if (sq_dist>radiuses[new_centroid]) {
radiuses[new_centroid] = sq_dist;
}
if (new_centroid != belongs_to[i]) {
count[belongs_to[i]]--;
count[new_centroid]++;
belongs_to[i] = new_centroid;
converged = false;
}
}
cv::Mutex mtx;
KMeansDistanceComputer invoker(distance_, dataset_, branching, indices, dcenters, veclen_, count, belongs_to, radiuses, converged, mtx);
parallel_for_(cv::Range(0, (int)indices_length), invoker);
for (int i=0; i<branching; ++i) {
// if one cluster converges to an empty cluster,
@ -823,11 +872,6 @@ private:
computeClustering(node->childs[c],indices+start, end-start, branching, level+1);
start=end;
}
delete[] dcenters.data;
delete[] centers;
delete[] count;
delete[] belongs_to;
}

Loading…
Cancel
Save