parent
3e0c58c581
commit
56ea364cb2
4 changed files with 226 additions and 204 deletions
@ -0,0 +1,162 @@ |
|||||||
|
#include "stump.hpp" |
||||||
|
|
||||||
|
namespace cv |
||||||
|
{ |
||||||
|
namespace adas |
||||||
|
{ |
||||||
|
|
||||||
|
/* Cumulative sum by rows */ |
||||||
|
static void cumsum(const Mat_<float>& src, Mat_<float> dst) |
||||||
|
{ |
||||||
|
CV_Assert(src.cols > 0); |
||||||
|
|
||||||
|
for( int row = 0; row < src.rows; ++row ) |
||||||
|
{ |
||||||
|
dst(row, 0) = src(row, 0); |
||||||
|
for( int col = 1; col < src.cols; ++col ) |
||||||
|
{ |
||||||
|
dst(row, col) = dst(row, col - 1) + src(row, col); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
||||||
|
{ |
||||||
|
CV_Assert(labels.rows == 1 && labels.cols == data.cols); |
||||||
|
CV_Assert(weights.rows == 1 && weights.cols == data.cols); |
||||||
|
/* Assert that data and labels have int type */ |
||||||
|
/* Assert that weights have float type */ |
||||||
|
|
||||||
|
|
||||||
|
/* Prepare labels for each feature rearranged according to sorted order */ |
||||||
|
Mat sorted_labels(data.rows, data.cols, labels.type()); |
||||||
|
Mat sorted_weights(data.rows, data.cols, weights.type()); |
||||||
|
Mat indices; |
||||||
|
sortIdx(data, indices, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
||||||
|
for( int row = 0; row < indices.rows; ++row ) |
||||||
|
{ |
||||||
|
for( int col = 0; col < indices.cols; ++col ) |
||||||
|
{ |
||||||
|
sorted_labels.at<int>(row, col) = |
||||||
|
labels.at<int>(0, indices.at<int>(row, col)); |
||||||
|
sorted_weights.at<float>(row, col) = |
||||||
|
weights.at<float>(0, indices.at<float>(row, col)); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/* Sort feature values */ |
||||||
|
Mat sorted_data(data.rows, data.cols, data.type()); |
||||||
|
sort(data, sorted_data, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
||||||
|
|
||||||
|
/* Split positive and negative weights */ |
||||||
|
Mat pos_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||||
|
sorted_weights.type()); |
||||||
|
Mat neg_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||||
|
sorted_weights.type()); |
||||||
|
for( int row = 0; row < data.rows; ++row ) |
||||||
|
{ |
||||||
|
for( int col = 0; col < data.cols; ++col ) |
||||||
|
{ |
||||||
|
if( sorted_labels.at<int>(row, col) == +1 ) |
||||||
|
{ |
||||||
|
pos_weights.at<float>(row, col) = |
||||||
|
sorted_weights.at<float>(row, col); |
||||||
|
} |
||||||
|
else |
||||||
|
{ |
||||||
|
neg_weights.at<float>(row, col) = |
||||||
|
sorted_weights.at<float>(row, col); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/* Compute cumulative sums for fast stump error computation */ |
||||||
|
Mat pos_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||||
|
sorted_weights.type()); |
||||||
|
Mat neg_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||||
|
sorted_weights.type()); |
||||||
|
cumsum(pos_weights, pos_cum_weights); |
||||||
|
cumsum(neg_weights, neg_cum_weights); |
||||||
|
|
||||||
|
/* Compute total weights of positive and negative samples */ |
||||||
|
float pos_total_weight = pos_cum_weights.at<float>(0, weights.cols - 1); |
||||||
|
float neg_total_weight = neg_cum_weights.at<float>(0, weights.cols - 1); |
||||||
|
|
||||||
|
|
||||||
|
float eps = 1. / 4 * labels.cols; |
||||||
|
|
||||||
|
/* Compute minimal error */ |
||||||
|
float min_err = FLT_MAX; |
||||||
|
int min_row = -1; |
||||||
|
int min_col = -1; |
||||||
|
int min_polarity = 0; |
||||||
|
float min_pos_value = 1, min_neg_value = -1; |
||||||
|
|
||||||
|
for( int row = 0; row < sorted_weights.rows; ++row ) |
||||||
|
{ |
||||||
|
for( int col = 0; col < sorted_weights.cols - 1; ++col ) |
||||||
|
{ |
||||||
|
float err, h_pos, h_neg; |
||||||
|
|
||||||
|
// Direct polarity
|
||||||
|
|
||||||
|
float pos_wrong = pos_cum_weights.at<float>(row, col); |
||||||
|
float pos_right = pos_total_weight - pos_wrong; |
||||||
|
|
||||||
|
float neg_right = neg_cum_weights.at<float>(row, col); |
||||||
|
float neg_wrong = neg_total_weight - neg_right; |
||||||
|
|
||||||
|
h_pos = .5 * log((pos_right + eps) / (pos_wrong + eps)); |
||||||
|
h_neg = .5 * log((neg_wrong + eps) / (neg_right + eps)); |
||||||
|
|
||||||
|
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right); |
||||||
|
|
||||||
|
if( err < min_err ) |
||||||
|
{ |
||||||
|
min_err = err; |
||||||
|
min_row = row; |
||||||
|
min_col = col; |
||||||
|
min_polarity = +1; |
||||||
|
min_pos_value = h_pos; |
||||||
|
min_neg_value = h_neg; |
||||||
|
} |
||||||
|
|
||||||
|
// Opposite polarity
|
||||||
|
swap(pos_right, pos_wrong); |
||||||
|
swap(neg_right, neg_wrong); |
||||||
|
|
||||||
|
h_pos = -h_pos; |
||||||
|
h_neg = -h_neg; |
||||||
|
|
||||||
|
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right); |
||||||
|
|
||||||
|
|
||||||
|
if( err < min_err ) |
||||||
|
{ |
||||||
|
min_err = err; |
||||||
|
min_row = row; |
||||||
|
min_col = col; |
||||||
|
min_polarity = -1; |
||||||
|
min_pos_value = h_pos; |
||||||
|
min_neg_value = h_neg; |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/* Compute threshold, store found values in fields */ |
||||||
|
threshold_ = ( sorted_data.at<int>(min_row, min_col) + |
||||||
|
sorted_data.at<int>(min_row, min_col + 1) ) / 2; |
||||||
|
polarity_ = min_polarity; |
||||||
|
pos_value_ = min_pos_value; |
||||||
|
neg_value_ = min_neg_value; |
||||||
|
|
||||||
|
return min_row; |
||||||
|
} |
||||||
|
|
||||||
|
float Stump::predict(int value) |
||||||
|
{ |
||||||
|
return polarity_ * (value - threshold_) > 0 ? pos_value_ : neg_value_; |
||||||
|
} |
||||||
|
|
||||||
|
} /* namespace adas */ |
||||||
|
} /* namespace cv */ |
@ -0,0 +1,58 @@ |
|||||||
|
#ifndef __OPENCV_ADAS_STUMP_HPP__ |
||||||
|
#define __OPENCV_ADAS_STUMP_HPP__ |
||||||
|
|
||||||
|
#include <opencv2/core.hpp> |
||||||
|
|
||||||
|
namespace cv |
||||||
|
{ |
||||||
|
namespace adas |
||||||
|
{ |
||||||
|
|
||||||
|
class Stump |
||||||
|
{ |
||||||
|
public: |
||||||
|
|
||||||
|
/* Initialize zero stump */ |
||||||
|
Stump(): threshold_(0), polarity_(1), pos_value_(1), neg_value_(-1) {} |
||||||
|
|
||||||
|
/* Initialize stump with given threshold, polarity
|
||||||
|
and classification values */ |
||||||
|
Stump(int threshold, int polarity, float pos_value, float neg_value): |
||||||
|
threshold_(threshold), polarity_(polarity), |
||||||
|
pos_value_(pos_value), neg_value_(neg_value) {} |
||||||
|
|
||||||
|
/* Train stump for given data
|
||||||
|
|
||||||
|
data — matrix of feature values, size M x N, one feature per row |
||||||
|
|
||||||
|
labels — matrix of sample class labels, size 1 x N. Labels can be from |
||||||
|
{-1, +1} |
||||||
|
|
||||||
|
weights — matrix of sample weights, size 1 x N |
||||||
|
|
||||||
|
Returns chosen feature index. Feature enumeration starts from 0 |
||||||
|
*/ |
||||||
|
int train(const Mat& data, const Mat& labels, const Mat& weights); |
||||||
|
|
||||||
|
/* Predict object class given
|
||||||
|
|
||||||
|
value — feature value. Feature must be the same as was chosen |
||||||
|
during training stump |
||||||
|
|
||||||
|
Returns real value, sign(value) means class
|
||||||
|
*/ |
||||||
|
float predict(int value); |
||||||
|
|
||||||
|
private: |
||||||
|
/* Stump decision threshold */ |
||||||
|
int threshold_; |
||||||
|
/* Stump polarity, can be from {-1, +1} */ |
||||||
|
int polarity_; |
||||||
|
/* Classification values for positive and negative classes */ |
||||||
|
float pos_value_, neg_value_; |
||||||
|
}; |
||||||
|
|
||||||
|
} /* namespace adas */ |
||||||
|
} /* namespace cv */ |
||||||
|
|
||||||
|
#endif /* __OPENCV_ADAS_STUMP_HPP__ */ |
Loading…
Reference in new issue