|
|
|
@ -69,128 +69,115 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
|
|
|
|
/* Assert that data and labels have int type */ |
|
|
|
|
/* Assert that weights have float type */ |
|
|
|
|
|
|
|
|
|
Mat_<int> d = Mat_<int>::zeros(1, data.cols); |
|
|
|
|
const Mat_<int>& l = labels; |
|
|
|
|
const Mat_<float>& w = weights; |
|
|
|
|
|
|
|
|
|
/* Prepare labels for each feature rearranged according to sorted order */ |
|
|
|
|
Mat sorted_labels(data.rows, data.cols, labels.type()); |
|
|
|
|
Mat sorted_weights(data.rows, data.cols, weights.type()); |
|
|
|
|
Mat indices; |
|
|
|
|
sortIdx(data, indices, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
|
|
|
|
for( int row = 0; row < indices.rows; ++row ) |
|
|
|
|
{ |
|
|
|
|
for( int col = 0; col < indices.cols; ++col ) |
|
|
|
|
{ |
|
|
|
|
sorted_labels.at<int>(row, col) = |
|
|
|
|
labels.at<int>(0, indices.at<int>(row, col)); |
|
|
|
|
sorted_weights.at<float>(row, col) = |
|
|
|
|
weights.at<float>(0, indices.at<int>(row, col)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
Mat_<int> indices(1, l.cols); |
|
|
|
|
|
|
|
|
|
/* Sort feature values */ |
|
|
|
|
Mat sorted_data(data.rows, data.cols, data.type()); |
|
|
|
|
sort(data, sorted_data, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
|
|
|
|
Mat_<int> sorted_d(1, data.cols); |
|
|
|
|
Mat_<int> sorted_l(1, l.cols); |
|
|
|
|
Mat_<float> sorted_w(1, w.cols); |
|
|
|
|
|
|
|
|
|
/* Split positive and negative weights */ |
|
|
|
|
Mat pos_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
|
|
|
|
sorted_weights.type()); |
|
|
|
|
Mat neg_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
|
|
|
|
sorted_weights.type()); |
|
|
|
|
|
|
|
|
|
Mat_<float> pos_c_w = Mat_<float>::zeros(1, w.cols); |
|
|
|
|
Mat_<float> neg_c_w = Mat_<float>::zeros(1, w.cols); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float min_err = FLT_MAX; |
|
|
|
|
int min_row = -1; |
|
|
|
|
int min_thr = -1; |
|
|
|
|
int min_pol = -1; |
|
|
|
|
float min_pos = 1; |
|
|
|
|
float min_neg = -1; |
|
|
|
|
float eps = 1.0f / (4 * l.cols); |
|
|
|
|
|
|
|
|
|
/* For every feature */ |
|
|
|
|
for( int row = 0; row < data.rows; ++row ) |
|
|
|
|
{ |
|
|
|
|
for( int col = 0; col < data.cols; ++col ) |
|
|
|
|
{ |
|
|
|
|
if( sorted_labels.at<int>(row, col) == +1 ) |
|
|
|
|
{ |
|
|
|
|
pos_weights.at<float>(row, col) = |
|
|
|
|
sorted_weights.at<float>(row, col); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
neg_weights.at<float>(row, col) = |
|
|
|
|
sorted_weights.at<float>(row, col); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
d(0, col) = data.at<int>(row, col); |
|
|
|
|
|
|
|
|
|
/* Compute cumulative sums for fast stump error computation */ |
|
|
|
|
Mat pos_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
|
|
|
|
sorted_weights.type()); |
|
|
|
|
Mat neg_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
|
|
|
|
sorted_weights.type()); |
|
|
|
|
cumsum(pos_weights, pos_cum_weights); |
|
|
|
|
cumsum(neg_weights, neg_cum_weights); |
|
|
|
|
sortIdx(d, indices, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
|
|
|
|
|
|
|
|
|
/* Compute total weights of positive and negative samples */ |
|
|
|
|
float pos_total_weight = pos_cum_weights.at<float>(0, weights.cols - 1); |
|
|
|
|
float neg_total_weight = neg_cum_weights.at<float>(0, weights.cols - 1); |
|
|
|
|
for( int col = 0; col < indices.cols; ++col ) |
|
|
|
|
{ |
|
|
|
|
int ind = indices(0, col); |
|
|
|
|
sorted_d(0, col) = d(0, ind); |
|
|
|
|
sorted_l(0, col) = l(0, ind); |
|
|
|
|
sorted_w(0, col) = w(0, ind); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Mat_<float> pos_w = Mat_<float>::zeros(1, w.cols); |
|
|
|
|
Mat_<float> neg_w = Mat_<float>::zeros(1, w.cols); |
|
|
|
|
for( int col = 0; col < d.cols; ++col ) |
|
|
|
|
{ |
|
|
|
|
float weight = sorted_w(0, col); |
|
|
|
|
if( sorted_l(0, col) == +1) |
|
|
|
|
pos_w(0, col) = weight; |
|
|
|
|
else |
|
|
|
|
neg_w(0, col) = weight; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
float eps = 1.0f / (4 * labels.cols); |
|
|
|
|
cumsum(pos_w, pos_c_w); |
|
|
|
|
cumsum(neg_w, neg_c_w); |
|
|
|
|
|
|
|
|
|
/* Compute minimal error */ |
|
|
|
|
float min_err = FLT_MAX; |
|
|
|
|
int min_row = -1; |
|
|
|
|
int min_col = -1; |
|
|
|
|
int min_polarity = 0; |
|
|
|
|
float min_pos_value = 1, min_neg_value = -1; |
|
|
|
|
float pos_total_w = pos_c_w(0, w.cols - 1); |
|
|
|
|
float neg_total_w = neg_c_w(0, w.cols - 1); |
|
|
|
|
|
|
|
|
|
for( int row = 0; row < sorted_weights.rows; ++row ) |
|
|
|
|
{ |
|
|
|
|
for( int col = 0; col < sorted_weights.cols - 1; ++col ) |
|
|
|
|
for( int col = 0; col < w.cols - 1; ++col ) |
|
|
|
|
{ |
|
|
|
|
float err, h_pos, h_neg; |
|
|
|
|
float pos_wrong, pos_right; |
|
|
|
|
float neg_wrong, neg_right; |
|
|
|
|
|
|
|
|
|
// Direct polarity
|
|
|
|
|
/* Direct polarity */ |
|
|
|
|
|
|
|
|
|
float pos_wrong = pos_cum_weights.at<float>(row, col); |
|
|
|
|
float pos_right = pos_total_weight - pos_wrong; |
|
|
|
|
pos_wrong = pos_c_w(0, col); |
|
|
|
|
pos_right = pos_total_w - pos_wrong; |
|
|
|
|
|
|
|
|
|
float neg_right = neg_cum_weights.at<float>(row, col); |
|
|
|
|
float neg_wrong = neg_total_weight - neg_right; |
|
|
|
|
|
|
|
|
|
h_pos = (float)(.5 * log((pos_right + eps) / (pos_wrong + eps))); |
|
|
|
|
h_neg = (float)(.5 * log((neg_wrong + eps) / (neg_right + eps))); |
|
|
|
|
neg_right = neg_c_w(0, col); |
|
|
|
|
neg_wrong = neg_total_w - neg_right; |
|
|
|
|
|
|
|
|
|
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right); |
|
|
|
|
|
|
|
|
|
h_pos = .5f * log((pos_right + eps) / (pos_wrong + eps)); |
|
|
|
|
h_neg = .5f * log((neg_wrong + eps) / (neg_right + eps)); |
|
|
|
|
|
|
|
|
|
if( err < min_err ) |
|
|
|
|
{ |
|
|
|
|
min_err = err; |
|
|
|
|
min_row = row; |
|
|
|
|
min_col = col; |
|
|
|
|
min_polarity = +1; |
|
|
|
|
min_pos_value = h_pos; |
|
|
|
|
min_neg_value = h_neg; |
|
|
|
|
min_thr = (sorted_d(0, col) + sorted_d(0, col + 1)) / 2; |
|
|
|
|
min_pol = +1; |
|
|
|
|
min_pos = h_pos; |
|
|
|
|
min_neg = h_neg; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Opposite polarity
|
|
|
|
|
/* Opposite polarity */ |
|
|
|
|
|
|
|
|
|
swap(pos_right, pos_wrong); |
|
|
|
|
swap(neg_right, neg_wrong); |
|
|
|
|
|
|
|
|
|
h_pos = -h_pos; |
|
|
|
|
h_neg = -h_neg; |
|
|
|
|
|
|
|
|
|
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if( err < min_err ) |
|
|
|
|
{ |
|
|
|
|
min_err = err; |
|
|
|
|
min_row = row; |
|
|
|
|
min_col = col; |
|
|
|
|
min_polarity = -1; |
|
|
|
|
min_pos_value = h_pos; |
|
|
|
|
min_neg_value = h_neg; |
|
|
|
|
min_thr = (sorted_d(0, col) + sorted_d(0, col + 1)) / 2; |
|
|
|
|
min_pol = -1; |
|
|
|
|
min_pos = -h_pos; |
|
|
|
|
min_neg = -h_neg; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/* Compute threshold, store found values in fields */ |
|
|
|
|
threshold_ = ( sorted_data.at<int>(min_row, min_col) + |
|
|
|
|
sorted_data.at<int>(min_row, min_col + 1) ) / 2; |
|
|
|
|
polarity_ = min_polarity; |
|
|
|
|
pos_value_ = min_pos_value; |
|
|
|
|
neg_value_ = min_neg_value; |
|
|
|
|
threshold_ = min_thr; |
|
|
|
|
polarity_ = min_pol; |
|
|
|
|
pos_value_ = min_pos; |
|
|
|
|
neg_value_ = min_neg; |
|
|
|
|
|
|
|
|
|
return min_row; |
|
|
|
|
} |
|
|
|
|