parent
4fae9410e2
commit
0ecbe87c74
2 changed files with 310 additions and 0 deletions
@ -0,0 +1,166 @@ |
||||
#include "waldboost.hpp" |
||||
|
||||
using cv::Mat; |
||||
using cv::Mat_; |
||||
using cv::sort; |
||||
using cv::sortIdx; |
||||
using cv::adas::Stump; |
||||
using cv::adas::WaldBoost; |
||||
|
||||
/* Cumulative sum by rows */ |
||||
static void cumsum(const Mat_<float>& src, Mat_<float> dst) |
||||
{ |
||||
CV_Assert(src.cols > 0); |
||||
|
||||
for( int row = 0; row < src.rows; ++row ) |
||||
{ |
||||
dst(row, 0) = src(row, 0); |
||||
for( int col = 1; col < src.cols; ++col ) |
||||
{ |
||||
dst(row, col) = dst(row, col - 1) + src(row, col); |
||||
} |
||||
} |
||||
} |
||||
|
||||
#include <iostream> |
||||
using std::cout; |
||||
using std::endl; |
||||
|
||||
int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) |
||||
{ |
||||
CV_Assert(labels.rows == 1 && labels.cols == data.cols); |
||||
CV_Assert(weights.rows == 1 && weights.cols == data.cols); |
||||
/* Assert that data and labels have int type */ |
||||
/* Assert that weights have float type */ |
||||
|
||||
|
||||
/* Prepare labels for each feature rearranged according to sorted order */ |
||||
Mat sorted_labels(data.rows, data.cols, labels.type()); |
||||
Mat sorted_weights(data.rows, data.cols, weights.type()); |
||||
Mat indices; |
||||
sortIdx(data, indices, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
||||
for( int row = 0; row < indices.rows; ++row ) |
||||
{ |
||||
for( int col = 0; col < indices.cols; ++col ) |
||||
{ |
||||
sorted_labels.at<int>(row, col) = |
||||
labels.at<int>(0, indices.at<int>(row, col)); |
||||
sorted_weights.at<float>(row, col) = |
||||
weights.at<float>(0, indices.at<float>(row, col)); |
||||
} |
||||
} |
||||
|
||||
/* Sort feature values */ |
||||
Mat sorted_data(data.rows, data.cols, data.type()); |
||||
sort(data, sorted_data, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING); |
||||
|
||||
/* Split positive and negative weights */ |
||||
Mat pos_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||
sorted_weights.type()); |
||||
Mat neg_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||
sorted_weights.type()); |
||||
for( int row = 0; row < data.rows; ++row ) |
||||
{ |
||||
for( int col = 0; col < data.cols; ++col ) |
||||
{ |
||||
if( sorted_labels.at<int>(row, col) == +1 ) |
||||
{ |
||||
pos_weights.at<float>(row, col) = |
||||
sorted_weights.at<float>(row, col); |
||||
} |
||||
else |
||||
{ |
||||
neg_weights.at<float>(row, col) = |
||||
sorted_weights.at<float>(row, col); |
||||
} |
||||
} |
||||
} |
||||
|
||||
/* Compute cumulative sums for fast stump error computation */ |
||||
Mat pos_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||
sorted_weights.type()); |
||||
Mat neg_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols, |
||||
sorted_weights.type()); |
||||
cumsum(pos_weights, pos_cum_weights); |
||||
cumsum(neg_weights, neg_cum_weights); |
||||
|
||||
/* Compute total weights of positive and negative samples */ |
||||
int pos_total_weight = 0, neg_total_weight = 0; |
||||
for( int col = 0; col < labels.cols; ++col ) |
||||
{ |
||||
if( labels.at<int>(0, col) == +1) |
||||
pos_total_weight += weights.at<float>(0, col); |
||||
else |
||||
neg_total_weight += weights.at<float>(0, col); |
||||
} |
||||
|
||||
cout << pos_total_weight << endl; |
||||
cout << neg_total_weight << endl; |
||||
|
||||
cout << pos_weights << endl; |
||||
cout << neg_weights << endl; |
||||
|
||||
cout << pos_cum_weights << endl; |
||||
cout << neg_cum_weights << endl; |
||||
|
||||
/* Compute minimal error */ |
||||
float min_err = FLT_MAX; |
||||
int min_row = -1; |
||||
int min_col = -1; |
||||
int min_polarity = 0; |
||||
for( int row = 0; row < sorted_weights.rows; ++row ) |
||||
{ |
||||
for( int col = 0; col < sorted_weights.cols - 1; ++col ) |
||||
{ |
||||
float err; |
||||
|
||||
err = pos_cum_weights.at<float>(row, col) + |
||||
(neg_total_weight - neg_cum_weights.at<float>(row, col)); |
||||
|
||||
cout << "row " << row << "err " << err << endl; |
||||
|
||||
if( err < min_err ) |
||||
{ |
||||
min_err = err; |
||||
min_row = row; |
||||
min_col = col; |
||||
min_polarity = +1; |
||||
} |
||||
|
||||
|
||||
err = (pos_total_weight - pos_cum_weights.at<float>(row, col)) + |
||||
neg_cum_weights.at<float>(row, col); |
||||
|
||||
cout << "row " << row << "err " << err << endl; |
||||
|
||||
if( err < min_err ) |
||||
{ |
||||
min_err = err; |
||||
min_row = row; |
||||
min_col = col; |
||||
min_polarity = -1; |
||||
} |
||||
} |
||||
} |
||||
|
||||
cout << "min_err: " << min_err << endl; |
||||
/* Compute threshold, store found values in fields */ |
||||
threshold_ = ( sorted_data.at<int>(min_row, min_col) + |
||||
sorted_data.at<int>(min_row, min_col + 1) ) / 2; |
||||
cout << "threshold: " << threshold_ << endl; |
||||
polarity_ = min_polarity; |
||||
|
||||
return min_row; |
||||
} |
||||
|
||||
static inline int sign(int value) |
||||
{ |
||||
if (value > 0) |
||||
return +1; |
||||
return -1; |
||||
} |
||||
|
||||
int Stump::predict(int value) |
||||
{ |
||||
return polarity_ * sign(value - threshold_); |
||||
} |
@ -0,0 +1,144 @@ |
||||
/*
|
||||
|
||||
By downloading, copying, installing or using the software you agree to this |
||||
license. If you do not agree to this license, do not download, install, |
||||
copy or use the software. |
||||
|
||||
|
||||
License Agreement |
||||
For Open Source Computer Vision Library |
||||
(3-clause BSD License) |
||||
|
||||
Copyright (C) 2013, OpenCV Foundation, all rights reserved. |
||||
Third party copyrights are property of their respective owners. |
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, |
||||
are permitted provided that the following conditions are met: |
||||
|
||||
* Redistributions of source code must retain the above copyright notice, |
||||
this list of conditions and the following disclaimer. |
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice, |
||||
this list of conditions and the following disclaimer in the documentation |
||||
and/or other materials provided with the distribution. |
||||
|
||||
* Neither the names of the copyright holders nor the names of the contributors |
||||
may be used to endorse or promote products derived from this software |
||||
without specific prior written permission. |
||||
|
||||
This software is provided by the copyright holders and contributors "as is" and |
||||
any express or implied warranties, including, but not limited to, the implied |
||||
warranties of merchantability and fitness for a particular purpose are |
||||
disclaimed. In no event shall copyright holders or contributors be liable for |
||||
any direct, indirect, incidental, special, exemplary, or consequential damages |
||||
(including, but not limited to, procurement of substitute goods or services; |
||||
loss of use, data, or profits; or business interruption) however caused |
||||
and on any theory of liability, whether in contract, strict liability, |
||||
or tort (including negligence or otherwise) arising in any way out of |
||||
the use of this software, even if advised of the possibility of such damage. |
||||
|
||||
*/ |
||||
|
||||
#ifndef __OPENCV_ADAS_WALDBOOST_HPP__ |
||||
#define __OPENCV_ADAS_WALDBOOST_HPP__ |
||||
|
||||
#include <opencv2/core.hpp> |
||||
|
||||
#include "acffeature.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace adas |
||||
{ |
||||
|
||||
class Stump |
||||
{ |
||||
public: |
||||
|
||||
/* Initialize zero stump */ |
||||
Stump(): threshold_(0), polarity_(1) {}; |
||||
|
||||
/* Initialize stump with given threshold and polarity */ |
||||
Stump(int threshold, int polarity): threshold_(threshold), |
||||
polarity_(polarity) {}; |
||||
|
||||
/* Train stump for given data
|
||||
|
||||
data — matrix of feature values, size M x N, one feature per row |
||||
|
||||
labels — matrix of sample class labels, size 1 x N. Labels can be from |
||||
{-1, +1} |
||||
|
||||
weights — matrix of sample weights, size 1 x N |
||||
|
||||
Returns chosen feature index. Feature enumeration starts from 0 |
||||
*/ |
||||
int train(const Mat& data, const Mat& labels, const Mat& weights); |
||||
|
||||
/* Predict object class given
|
||||
|
||||
value — feature value. Feature must be the same as chose during training |
||||
stump |
||||
|
||||
Returns object class from {-1, +1} |
||||
*/ |
||||
int predict(int value); |
||||
|
||||
private: |
||||
/* Stump decision threshold */ |
||||
int threshold_; |
||||
/* Stump polarity, can be from {-1, +1} */ |
||||
int polarity_; |
||||
/* Stump decision rule:
|
||||
h(value) = polarity * sign(value - threshold) |
||||
*/ |
||||
}; |
||||
|
||||
struct WaldBoostParams |
||||
{ |
||||
int weak_count; |
||||
}; |
||||
|
||||
class WaldBoost |
||||
{ |
||||
public: |
||||
/* Initialize WaldBoost cascade with default of specified parameters */ |
||||
WaldBoost(const WaldBoostParams& params); |
||||
|
||||
/* Train WaldBoost cascade for given data
|
||||
|
||||
data — matrix of feature values, size M x N, one feature per row |
||||
|
||||
labels — matrix of sample class labels, size 1 x N. Labels can be from |
||||
{-1, +1} |
||||
|
||||
Returns feature indices chosen for cascade. |
||||
Feature enumeration starts from 0 |
||||
*/ |
||||
std::vector<int> train(const Mat& data, |
||||
const Mat& labels); |
||||
|
||||
/* Predict object class given object that can compute object features
|
||||
|
||||
feature_evaluator — object that can compute features by demand |
||||
|
||||
Returns confidence_value — measure of confidense that object |
||||
is from class +1 |
||||
*/ |
||||
float predict(const Ptr<ACFFeatureEvaluator>& feature_evaluator); |
||||
|
||||
private: |
||||
/* Parameters for cascade training */ |
||||
WaldBoostParams params_; |
||||
/* Stumps in cascade */ |
||||
std::vector<Stump> stumps_; |
||||
/* Weight for stumps in cascade linear combination */ |
||||
std::vector<float> stump_weights_; |
||||
/* Rejection thresholds for linear combination at every stump evaluation */ |
||||
std::vector<float> thresholds_; |
||||
}; |
||||
|
||||
} /* namespace adas */ |
||||
} /* namespace cv */ |
||||
|
||||
#endif /* __OPENCV_ADAS_WALDBOOST_HPP__ */ |
Loading…
Reference in new issue