|
|
|
@ -1,3 +1,8 @@ |
|
|
|
|
#include <cmath> |
|
|
|
|
|
|
|
|
|
#include <algorithm> |
|
|
|
|
using std::swap; |
|
|
|
|
|
|
|
|
|
#include "waldboost.hpp" |
|
|
|
|
|
|
|
|
|
using cv::Mat; |
|
|
|
@ -6,6 +11,7 @@ using cv::sort; |
|
|
|
|
using cv::sortIdx; |
|
|
|
|
using cv::adas::Stump; |
|
|
|
|
using cv::adas::WaldBoost; |
|
|
|
|
using cv::Ptr; |
|
|
|
|
|
|
|
|
|
/* Cumulative sum by rows */ |
|
|
|
|
static void cumsum(const Mat_<float>& src, Mat_<float> dst) |
|
|
|
@ -81,28 +87,37 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
|
|
|
|
cumsum(neg_weights, neg_cum_weights); |
|
|
|
|
|
|
|
|
|
/* Compute total weights of positive and negative samples */ |
|
|
|
|
int pos_total_weight = 0, neg_total_weight = 0; |
|
|
|
|
for( int col = 0; col < labels.cols; ++col ) |
|
|
|
|
{ |
|
|
|
|
if( labels.at<int>(0, col) == +1) |
|
|
|
|
pos_total_weight += weights.at<float>(0, col); |
|
|
|
|
else |
|
|
|
|
neg_total_weight += weights.at<float>(0, col); |
|
|
|
|
} |
|
|
|
|
float pos_total_weight = pos_cum_weights.at<float>(0, weights.cols - 1); |
|
|
|
|
float neg_total_weight = neg_cum_weights.at<float>(0, weights.cols - 1); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float eps = 1. / 4 * labels.cols; |
|
|
|
|
|
|
|
|
|
/* Compute minimal error */ |
|
|
|
|
float min_err = FLT_MAX; |
|
|
|
|
int min_row = -1; |
|
|
|
|
int min_col = -1; |
|
|
|
|
int min_polarity = 0; |
|
|
|
|
float min_pos_value = 1, min_neg_value = -1; |
|
|
|
|
|
|
|
|
|
for( int row = 0; row < sorted_weights.rows; ++row ) |
|
|
|
|
{ |
|
|
|
|
for( int col = 0; col < sorted_weights.cols - 1; ++col ) |
|
|
|
|
{ |
|
|
|
|
float err; |
|
|
|
|
float err, h_pos, h_neg; |
|
|
|
|
|
|
|
|
|
// Direct polarity
|
|
|
|
|
|
|
|
|
|
err = pos_cum_weights.at<float>(row, col) + |
|
|
|
|
(neg_total_weight - neg_cum_weights.at<float>(row, col)); |
|
|
|
|
float pos_wrong = pos_cum_weights.at<float>(row, col); |
|
|
|
|
float pos_right = pos_total_weight - pos_wrong; |
|
|
|
|
|
|
|
|
|
float neg_right = neg_cum_weights.at<float>(row, col); |
|
|
|
|
float neg_wrong = neg_total_weight - neg_right; |
|
|
|
|
|
|
|
|
|
h_pos = .5 * log((pos_right + eps) / (pos_wrong + eps)); |
|
|
|
|
h_neg = .5 * log((neg_wrong + eps) / (neg_right + eps)); |
|
|
|
|
|
|
|
|
|
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right); |
|
|
|
|
|
|
|
|
|
if( err < min_err ) |
|
|
|
|
{ |
|
|
|
@ -110,11 +125,19 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
|
|
|
|
min_row = row; |
|
|
|
|
min_col = col; |
|
|
|
|
min_polarity = +1; |
|
|
|
|
min_pos_value = h_pos; |
|
|
|
|
min_neg_value = h_neg; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Opposite polarity
|
|
|
|
|
swap(pos_right, pos_wrong); |
|
|
|
|
swap(neg_right, neg_wrong); |
|
|
|
|
|
|
|
|
|
h_pos = -h_pos; |
|
|
|
|
h_neg = -h_neg; |
|
|
|
|
|
|
|
|
|
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right); |
|
|
|
|
|
|
|
|
|
err = (pos_total_weight - pos_cum_weights.at<float>(row, col)) + |
|
|
|
|
neg_cum_weights.at<float>(row, col); |
|
|
|
|
|
|
|
|
|
if( err < min_err ) |
|
|
|
|
{ |
|
|
|
@ -122,6 +145,8 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
|
|
|
|
min_row = row; |
|
|
|
|
min_col = col; |
|
|
|
|
min_polarity = -1; |
|
|
|
|
min_pos_value = h_pos; |
|
|
|
|
min_neg_value = h_neg; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -130,18 +155,13 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
|
|
|
|
threshold_ = ( sorted_data.at<int>(min_row, min_col) + |
|
|
|
|
sorted_data.at<int>(min_row, min_col + 1) ) / 2; |
|
|
|
|
polarity_ = min_polarity; |
|
|
|
|
pos_value_ = min_pos_value; |
|
|
|
|
neg_value_ = min_neg_value; |
|
|
|
|
|
|
|
|
|
return min_row; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static inline int sign(int value) |
|
|
|
|
{ |
|
|
|
|
if (value > 0) |
|
|
|
|
return +1; |
|
|
|
|
return -1; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
int Stump::predict(int value) |
|
|
|
|
float Stump::predict(int value) |
|
|
|
|
{ |
|
|
|
|
return polarity_ * sign(value - threshold_); |
|
|
|
|
return polarity_ * (value - threshold_) > 0 ? pos_value_ : neg_value_; |
|
|
|
|
} |
|
|
|
|