You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
176 lines
5.3 KiB
176 lines
5.3 KiB
// This file is part of OpenCV project. |
|
// It is subject to the license terms in the LICENSE file found in the top-level directory |
|
// of this distribution and at http://opencv.org/license.html. |
|
|
|
#include "precomp.hpp" |
|
#include "kuhn_munkres.hpp" |
|
|
|
#include <algorithm> |
|
#include <limits> |
|
#include <vector> |
|
|
|
namespace cv { |
|
namespace detail { |
|
inline namespace tracking { |
|
|
|
KuhnMunkres::KuhnMunkres() : n_() {} |
|
|
|
std::vector<size_t> KuhnMunkres::Solve(const cv::Mat& dissimilarity_matrix) { |
|
CV_Assert(dissimilarity_matrix.type() == CV_32F); |
|
double min_val; |
|
cv::minMaxLoc(dissimilarity_matrix, &min_val); |
|
CV_Assert(min_val >= 0); |
|
|
|
n_ = std::max(dissimilarity_matrix.rows, dissimilarity_matrix.cols); |
|
dm_ = cv::Mat(n_, n_, CV_32F, cv::Scalar(0)); |
|
marked_ = cv::Mat(n_, n_, CV_8S, cv::Scalar(0)); |
|
points_ = std::vector<cv::Point>(n_ * 2); |
|
|
|
dissimilarity_matrix.copyTo(dm_( |
|
cv::Rect(0, 0, dissimilarity_matrix.cols, dissimilarity_matrix.rows))); |
|
|
|
is_row_visited_ = std::vector<int>(n_, 0); |
|
is_col_visited_ = std::vector<int>(n_, 0); |
|
|
|
Run(); |
|
|
|
std::vector<size_t> results(static_cast<size_t>(marked_.rows), static_cast<size_t>(-1)); |
|
for (int i = 0; i < marked_.rows; i++) { |
|
const auto ptr = marked_.ptr<char>(i); |
|
for (int j = 0; j < marked_.cols; j++) { |
|
if (ptr[j] == kStar) { |
|
results[i] = j; |
|
} |
|
} |
|
} |
|
return results; |
|
} |
|
|
|
void KuhnMunkres::TrySimpleCase() { |
|
auto is_row_visited = std::vector<int>(n_, 0); |
|
auto is_col_visited = std::vector<int>(n_, 0); |
|
|
|
for (int row = 0; row < n_; row++) { |
|
auto ptr = dm_.ptr<float>(row); |
|
auto marked_ptr = marked_.ptr<char>(row); |
|
auto min_val = *std::min_element(ptr, ptr + n_); |
|
for (int col = 0; col < n_; col++) { |
|
ptr[col] -= min_val; |
|
if (ptr[col] == 0 && !is_col_visited[col] && !is_row_visited[row]) { |
|
marked_ptr[col] = kStar; |
|
is_col_visited[col] = 1; |
|
is_row_visited[row] = 1; |
|
} |
|
} |
|
} |
|
} |
|
|
|
bool KuhnMunkres::CheckIfOptimumIsFound() { |
|
int count = 0; |
|
for (int i = 0; i < n_; i++) { |
|
const auto marked_ptr = marked_.ptr<char>(i); |
|
for (int j = 0; j < n_; j++) { |
|
if (marked_ptr[j] == kStar) { |
|
is_col_visited_[j] = 1; |
|
count++; |
|
} |
|
} |
|
} |
|
|
|
return count >= n_; |
|
} |
|
|
|
cv::Point KuhnMunkres::FindUncoveredMinValPos() { |
|
auto min_val = std::numeric_limits<float>::max(); |
|
cv::Point min_val_pos(-1, -1); |
|
for (int i = 0; i < n_; i++) { |
|
if (!is_row_visited_[i]) { |
|
auto dm_ptr = dm_.ptr<float>(i); |
|
for (int j = 0; j < n_; j++) { |
|
if (!is_col_visited_[j] && dm_ptr[j] < min_val) { |
|
min_val = dm_ptr[j]; |
|
min_val_pos = cv::Point(j, i); |
|
} |
|
} |
|
} |
|
} |
|
return min_val_pos; |
|
} |
|
|
|
void KuhnMunkres::UpdateDissimilarityMatrix(float val) { |
|
for (int i = 0; i < n_; i++) { |
|
auto dm_ptr = dm_.ptr<float>(i); |
|
for (int j = 0; j < n_; j++) { |
|
if (is_row_visited_[i]) dm_ptr[j] += val; |
|
if (!is_col_visited_[j]) dm_ptr[j] -= val; |
|
} |
|
} |
|
} |
|
|
|
int KuhnMunkres::FindInRow(int row, int what) { |
|
for (int j = 0; j < n_; j++) { |
|
if (marked_.at<char>(row, j) == what) { |
|
return j; |
|
} |
|
} |
|
return -1; |
|
} |
|
|
|
int KuhnMunkres::FindInCol(int col, int what) { |
|
for (int i = 0; i < n_; i++) { |
|
if (marked_.at<char>(i, col) == what) { |
|
return i; |
|
} |
|
} |
|
return -1; |
|
} |
|
|
|
void KuhnMunkres::Run() { |
|
TrySimpleCase(); |
|
while (!CheckIfOptimumIsFound()) { |
|
while (true) { |
|
auto point = FindUncoveredMinValPos(); |
|
auto min_val = dm_.at<float>(point.y, point.x); |
|
if (min_val > 0) { |
|
UpdateDissimilarityMatrix(min_val); |
|
} else { |
|
marked_.at<char>(point.y, point.x) = kPrime; |
|
int col = FindInRow(point.y, kStar); |
|
if (col >= 0) { |
|
is_row_visited_[point.y] = 1; |
|
is_col_visited_[col] = 0; |
|
} else { |
|
int count = 0; |
|
points_[count] = point; |
|
|
|
while (true) { |
|
int row = FindInCol(points_[count].x, kStar); |
|
if (row >= 0) { |
|
count++; |
|
points_[count] = cv::Point(points_[count - 1].x, row); |
|
int col1 = FindInRow(points_[count].y, kPrime); |
|
count++; |
|
points_[count] = cv::Point(col1, points_[count - 1].y); |
|
} else { |
|
break; |
|
} |
|
} |
|
|
|
for (int i = 0; i < count + 1; i++) { |
|
auto& mark = marked_.at<char>(points_[i].y, points_[i].x); |
|
mark = mark == kStar ? 0 : kStar; |
|
} |
|
|
|
is_row_visited_ = std::vector<int>(n_, 0); |
|
is_col_visited_ = std::vector<int>(n_, 0); |
|
|
|
marked_.setTo(0, marked_ == kPrime); |
|
break; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
}}} // namespace
|