mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1946 lines
54 KiB
1946 lines
54 KiB
/*M/////////////////////////////////////////////////////////////////////////////////////// |
|
// |
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
|
// |
|
// By downloading, copying, installing or using the software you agree to this license. |
|
// If you do not agree to this license, do not download, install, |
|
// copy or use the software. |
|
// |
|
// |
|
// License Agreement |
|
// For Open Source Computer Vision Library |
|
// |
|
// Copyright (C) 2000, Intel Corporation, all rights reserved. |
|
// Copyright (C) 2014, Itseez Inc, all rights reserved. |
|
// Third party copyrights are property of their respective owners. |
|
// |
|
// Redistribution and use in source and binary forms, with or without modification, |
|
// are permitted provided that the following conditions are met: |
|
// |
|
// * Redistribution's of source code must retain the above copyright notice, |
|
// this list of conditions and the following disclaimer. |
|
// |
|
// * Redistribution's in binary form must reproduce the above copyright notice, |
|
// this list of conditions and the following disclaimer in the documentation |
|
// and/or other materials provided with the distribution. |
|
// |
|
// * The name of the copyright holders may not be used to endorse or promote products |
|
// derived from this software without specific prior written permission. |
|
// |
|
// This software is provided by the copyright holders and contributors "as is" and |
|
// any express or implied warranties, including, but not limited to, the implied |
|
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
|
// In no event shall the Intel Corporation or contributors be liable for any direct, |
|
// indirect, incidental, special, exemplary, or consequential damages |
|
// (including, but not limited to, procurement of substitute goods or services; |
|
// loss of use, data, or profits; or business interruption) however caused |
|
// and on any theory of liability, whether in contract, strict liability, |
|
// or tort (including negligence or otherwise) arising in any way out of |
|
// the use of this software, even if advised of the possibility of such damage. |
|
// |
|
//M*/ |
|
|
|
#include "precomp.hpp" |
|
#include <ctype.h> |
|
|
|
namespace cv { |
|
namespace ml { |
|
|
|
using std::vector; |
|
|
|
// Default-constructs the tree parameters with the library defaults listed below.
TreeParams::TreeParams()
{
    // growth limits
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    maxCategories = 10;

    // surrogate splits are disabled by default
    useSurrogates = false;

    // cross-validation / pruning settings
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;

    // empty matrix: no class priors
    priors = Mat();
}
|
|
|
// Constructs the tree parameters from explicit values.
// Each argument is stored in the member of the same name (without the leading
// underscore); _regressionAccuracy is narrowed from double to float.
TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    // growth limits
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = static_cast<float>(_regressionAccuracy);
    maxCategories = _maxCategories;

    // surrogate splits
    useSurrogates = _useSurrogates;

    // cross-validation / pruning settings
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;

    // class priors (Mat assignment)
    priors = _priors;
}
|
|
|
// Creates a detached node: zero value/class index and every link set to -1.
DTrees::Node::Node()
{
    value = 0;
    classIdx = 0;

    // -1 is used for all index-like fields of a fresh node
    parent = -1;
    left = -1;
    right = -1;
    split = -1;
    defaultDir = -1;
}
|
|
|
// Creates an empty split: variable 0, zero threshold and quality,
// no successor in the split list (next == -1).
DTrees::Split::Split()
{
    varIdx = 0;
    c = 0.f;
    quality = 0.f;
    inversed = false;
    subsetOfs = 0;
    next = -1;
}
|
|
|
|
|
// Builds the temporary training state for the given data set.
// sidx receives the indices of the training samples: a sorted copy of
// getTrainSampleIdx() when that is non-empty, otherwise the range produced by
// setRangeVector() for getNSamples() samples.
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    // _data is dereferenced right below; fail with a clear message instead of crashing
    CV_Assert( !_data.empty() );

    data = _data;

    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}
|
|
|
// Empty constructor body; all members are default-constructed.
DTreesImpl::DTreesImpl()
{
}

// Empty destructor body; members are destroyed by their own destructors.
DTreesImpl::~DTreesImpl()
{
}
|
void DTreesImpl::clear() |
|
{ |
|
varIdx.clear(); |
|
compVarIdx.clear(); |
|
varType.clear(); |
|
catOfs.clear(); |
|
catMap.clear(); |
|
roots.clear(); |
|
nodes.clear(); |
|
splits.clear(); |
|
subsets.clear(); |
|
classLabels.clear(); |
|
|
|
w.release(); |
|
_isClassifier = false; |
|
} |
|
|
|
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int ) |
|
{ |
|
clear(); |
|
w = makePtr<WorkData>(data); |
|
|
|
Mat vtype = data->getVarType(); |
|
vtype.copyTo(varType); |
|
|
|
data->getCatOfs().copyTo(catOfs); |
|
data->getCatMap().copyTo(catMap); |
|
data->getDefaultSubstValues().copyTo(missingSubst); |
|
|
|
int nallvars = data->getNAllVars(); |
|
|
|
Mat vidx0 = data->getVarIdx(); |
|
if( !vidx0.empty() ) |
|
vidx0.copyTo(varIdx); |
|
else |
|
setRangeVector(varIdx, nallvars); |
|
|
|
initCompVarIdx(); |
|
|
|
w->maxSubsetSize = 0; |
|
|
|
int i, nvars = (int)varIdx.size(); |
|
for( i = 0; i < nvars; i++ ) |
|
w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i])); |
|
|
|
w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1); |
|
|
|
data->getSampleWeights().copyTo(w->sample_weights); |
|
|
|
_isClassifier = data->getResponseType() == VAR_CATEGORICAL; |
|
|
|
if( _isClassifier ) |
|
{ |
|
data->getNormCatResponses().copyTo(w->cat_responses); |
|
data->getClassLabels().copyTo(classLabels); |
|
int nclasses = (int)classLabels.size(); |
|
|
|
Mat class_weights = params.priors; |
|
if( !class_weights.empty() ) |
|
{ |
|
if( class_weights.type() != CV_64F || !class_weights.isContinuous() ) |
|
{ |
|
Mat temp; |
|
class_weights.convertTo(temp, CV_64F); |
|
class_weights = temp; |
|
} |
|
CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses ); |
|
|
|
int nsamples = (int)w->cat_responses.size(); |
|
const double* cw = class_weights.ptr<double>(); |
|
CV_Assert( (int)w->sample_weights.size() == nsamples ); |
|
|
|
for( i = 0; i < nsamples; i++ ) |
|
{ |
|
int ci = w->cat_responses[i]; |
|
CV_Assert( 0 <= ci && ci < nclasses ); |
|
w->sample_weights[i] *= cw[ci]; |
|
} |
|
} |
|
} |
|
else |
|
data->getResponses().copyTo(w->ord_responses); |
|
} |
|
|
|
|
|
void DTreesImpl::initCompVarIdx() |
|
{ |
|
int nallvars = (int)varType.size(); |
|
compVarIdx.assign(nallvars, -1); |
|
int i, nvars = (int)varIdx.size(), prevIdx = -1; |
|
for( i = 0; i < nvars; i++ ) |
|
{ |
|
int vi = varIdx[i]; |
|
CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx ); |
|
prevIdx = vi; |
|
compVarIdx[vi] = i; |
|
} |
|
} |
|
|
|
void DTreesImpl::endTraining() |
|
{ |
|
w.release(); |
|
} |
|
|
|
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags ) |
|
{ |
|
startTraining(trainData, flags); |
|
bool ok = addTree( w->sidx ) >= 0; |
|
w.release(); |
|
endTraining(); |
|
return ok; |
|
} |
|
|
|
// Returns the list of variable indices considered when searching for splits;
// in this class that is simply varIdx.
const vector<int>& DTreesImpl::getActiveVars()
{
    const vector<int>& active = varIdx;
    return active;
}
|
|
|
int DTreesImpl::addTree(const vector<int>& sidx ) |
|
{ |
|
size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size(); |
|
|
|
w->wnodes.reserve(n); |
|
w->wsplits.reserve(n); |
|
w->wsubsets.reserve(n*w->maxSubsetSize); |
|
w->wnodes.clear(); |
|
w->wsplits.clear(); |
|
w->wsubsets.clear(); |
|
|
|
int cv_n = params.getCVFolds(); |
|
|
|
if( cv_n > 0 ) |
|
{ |
|
w->cv_Tn.resize(n*cv_n); |
|
w->cv_node_error.resize(n*cv_n); |
|
w->cv_node_risk.resize(n*cv_n); |
|
} |
|
|
|
// build the tree recursively |
|
int w_root = addNodeAndTrySplit(-1, sidx); |
|
int maxdepth = INT_MAX;//pruneCV(root); |
|
|
|
int w_nidx = w_root, pidx = -1, depth = 0; |
|
int root = (int)nodes.size(); |
|
|
|
for(;;) |
|
{ |
|
const WNode& wnode = w->wnodes[w_nidx]; |
|
Node node; |
|
node.parent = pidx; |
|
node.classIdx = wnode.class_idx; |
|
node.value = wnode.value; |
|
node.defaultDir = wnode.defaultDir; |
|
|
|
int wsplit_idx = wnode.split; |
|
if( wsplit_idx >= 0 ) |
|
{ |
|
const WSplit& wsplit = w->wsplits[wsplit_idx]; |
|
Split split; |
|
split.c = wsplit.c; |
|
split.quality = wsplit.quality; |
|
split.inversed = wsplit.inversed; |
|
split.varIdx = wsplit.varIdx; |
|
split.subsetOfs = -1; |
|
if( wsplit.subsetOfs >= 0 ) |
|
{ |
|
int ssize = getSubsetSize(split.varIdx); |
|
split.subsetOfs = (int)subsets.size(); |
|
subsets.resize(split.subsetOfs + ssize); |
|
// This check verifies that subsets index is in the correct range |
|
// as in case ssize == 0 no real resize performed. |
|
// Thus memory kept safe. |
|
// Also this skips useless memcpy call when size parameter is zero |
|
if(ssize > 0) |
|
{ |
|
memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int)); |
|
} |
|
} |
|
node.split = (int)splits.size(); |
|
splits.push_back(split); |
|
} |
|
int nidx = (int)nodes.size(); |
|
nodes.push_back(node); |
|
if( pidx >= 0 ) |
|
{ |
|
int w_pidx = w->wnodes[w_nidx].parent; |
|
if( w->wnodes[w_pidx].left == w_nidx ) |
|
{ |
|
nodes[pidx].left = nidx; |
|
} |
|
else |
|
{ |
|
CV_Assert(w->wnodes[w_pidx].right == w_nidx); |
|
nodes[pidx].right = nidx; |
|
} |
|
} |
|
|
|
if( wnode.left >= 0 && depth+1 < maxdepth ) |
|
{ |
|
w_nidx = wnode.left; |
|
pidx = nidx; |
|
depth++; |
|
} |
|
else |
|
{ |
|
int w_pidx = wnode.parent; |
|
while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx ) |
|
{ |
|
w_nidx = w_pidx; |
|
w_pidx = w->wnodes[w_pidx].parent; |
|
nidx = pidx; |
|
pidx = nodes[pidx].parent; |
|
depth--; |
|
} |
|
|
|
if( w_pidx < 0 ) |
|
break; |
|
|
|
w_nidx = w->wnodes[w_pidx].right; |
|
CV_Assert( w_nidx >= 0 ); |
|
} |
|
} |
|
roots.push_back(root); |
|
return root; |
|
} |
|
|
|
void DTreesImpl::setDParams(const TreeParams& _params) |
|
{ |
|
params = _params; |
|
} |
|
|
|
int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx ) |
|
{ |
|
w->wnodes.push_back(WNode()); |
|
int nidx = (int)(w->wnodes.size() - 1); |
|
WNode& node = w->wnodes.back(); |
|
|
|
node.parent = parent; |
|
node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0; |
|
int nfolds = params.getCVFolds(); |
|
|
|
if( nfolds > 0 ) |
|
{ |
|
w->cv_Tn.resize((nidx+1)*nfolds); |
|
w->cv_node_error.resize((nidx+1)*nfolds); |
|
w->cv_node_risk.resize((nidx+1)*nfolds); |
|
} |
|
|
|
int i, n = node.sample_count = (int)sidx.size(); |
|
bool can_split = true; |
|
vector<int> sleft, sright; |
|
|
|
calcValue( nidx, sidx ); |
|
|
|
if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() ) |
|
can_split = false; |
|
else if( _isClassifier ) |
|
{ |
|
const int* responses = &w->cat_responses[0]; |
|
const int* s = &sidx[0]; |
|
int first = responses[s[0]]; |
|
for( i = 1; i < n; i++ ) |
|
if( responses[s[i]] != first ) |
|
break; |
|
if( i == n ) |
|
can_split = false; |
|
} |
|
else |
|
{ |
|
if( sqrt(node.node_risk) < params.getRegressionAccuracy() ) |
|
can_split = false; |
|
} |
|
|
|
if( can_split ) |
|
node.split = findBestSplit( sidx ); |
|
|
|
//printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk); |
|
|
|
if( node.split >= 0 ) |
|
{ |
|
node.defaultDir = calcDir( node.split, sidx, sleft, sright ); |
|
if( params.useSurrogates ) |
|
CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet"); |
|
|
|
int left = addNodeAndTrySplit( nidx, sleft ); |
|
int right = addNodeAndTrySplit( nidx, sright ); |
|
w->wnodes[nidx].left = left; |
|
w->wnodes[nidx].right = right; |
|
CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 ); |
|
} |
|
|
|
return nidx; |
|
} |
|
|
|
int DTreesImpl::findBestSplit( const vector<int>& _sidx ) |
|
{ |
|
const vector<int>& activeVars = getActiveVars(); |
|
int splitidx = -1; |
|
int vi_, nv = (int)activeVars.size(); |
|
AutoBuffer<int> buf(w->maxSubsetSize*2); |
|
int *subset = buf, *best_subset = subset + w->maxSubsetSize; |
|
WSplit split, best_split; |
|
best_split.quality = 0.; |
|
|
|
for( vi_ = 0; vi_ < nv; vi_++ ) |
|
{ |
|
int vi = activeVars[vi_]; |
|
if( varType[vi] == VAR_CATEGORICAL ) |
|
{ |
|
if( _isClassifier ) |
|
split = findSplitCatClass(vi, _sidx, 0, subset); |
|
else |
|
split = findSplitCatReg(vi, _sidx, 0, subset); |
|
} |
|
else |
|
{ |
|
if( _isClassifier ) |
|
split = findSplitOrdClass(vi, _sidx, 0); |
|
else |
|
split = findSplitOrdReg(vi, _sidx, 0); |
|
} |
|
if( split.quality > best_split.quality ) |
|
{ |
|
best_split = split; |
|
std::swap(subset, best_subset); |
|
} |
|
} |
|
|
|
if( best_split.quality > 0 ) |
|
{ |
|
int best_vi = best_split.varIdx; |
|
CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 ); |
|
int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi); |
|
w->wsubsets.resize(prevsz + ssize); |
|
for( i = 0; i < ssize; i++ ) |
|
w->wsubsets[prevsz + i] = best_subset[i]; |
|
best_split.subsetOfs = prevsz; |
|
w->wsplits.push_back(best_split); |
|
splitidx = (int)(w->wsplits.size()-1); |
|
} |
|
|
|
return splitidx; |
|
} |
|
|
|
void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx ) |
|
{ |
|
WNode* node = &w->wnodes[nidx]; |
|
int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds(); |
|
int m = (int)classLabels.size(); |
|
|
|
cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1)); |
|
|
|
if( cv_n > 0 ) |
|
{ |
|
size_t sz = w->cv_Tn.size(); |
|
w->cv_Tn.resize(sz + cv_n); |
|
w->cv_node_risk.resize(sz + cv_n); |
|
w->cv_node_error.resize(sz + cv_n); |
|
} |
|
|
|
if( _isClassifier ) |
|
{ |
|
// in case of classification tree: |
|
// * node value is the label of the class that has the largest weight in the node. |
|
// * node risk is the weighted number of misclassified samples, |
|
// * j-th cross-validation fold value and risk are calculated as above, |
|
// but using the samples with cv_labels(*)!=j. |
|
// * j-th cross-validation fold error is calculated as the weighted number of |
|
// misclassified samples with cv_labels(*)==j. |
|
|
|
// compute the number of instances of each class |
|
double* cls_count = buf; |
|
double* cv_cls_count = cls_count + m; |
|
|
|
double max_val = -1, total_weight = 0; |
|
int max_k = -1; |
|
|
|
for( k = 0; k < m; k++ ) |
|
cls_count[k] = 0; |
|
|
|
if( cv_n == 0 ) |
|
{ |
|
for( i = 0; i < n; i++ ) |
|
{ |
|
int si = _sidx[i]; |
|
cls_count[w->cat_responses[si]] += w->sample_weights[si]; |
|
} |
|
} |
|
else |
|
{ |
|
for( j = 0; j < cv_n; j++ ) |
|
for( k = 0; k < m; k++ ) |
|
cv_cls_count[j*m + k] = 0; |
|
|
|
for( i = 0; i < n; i++ ) |
|
{ |
|
int si = _sidx[i]; |
|
j = w->cv_labels[si]; k = w->cat_responses[si]; |
|
cv_cls_count[j*m + k] += w->sample_weights[si]; |
|
} |
|
|
|
for( j = 0; j < cv_n; j++ ) |
|
for( k = 0; k < m; k++ ) |
|
cls_count[k] += cv_cls_count[j*m + k]; |
|
} |
|
|
|
for( k = 0; k < m; k++ ) |
|
{ |
|
double val = cls_count[k]; |
|
total_weight += val; |
|
if( max_val < val ) |
|
{ |
|
max_val = val; |
|
max_k = k; |
|
} |
|
} |
|
|
|
node->class_idx = max_k; |
|
node->value = classLabels[max_k]; |
|
node->node_risk = total_weight - max_val; |
|
|
|
for( j = 0; j < cv_n; j++ ) |
|
{ |
|
double sum_k = 0, sum = 0, max_val_k = 0; |
|
max_val = -1; max_k = -1; |
|
|
|
for( k = 0; k < m; k++ ) |
|
{ |
|
double val_k = cv_cls_count[j*m + k]; |
|
double val = cls_count[k] - val_k; |
|
sum_k += val_k; |
|
sum += val; |
|
if( max_val < val ) |
|
{ |
|
max_val = val; |
|
max_val_k = val_k; |
|
max_k = k; |
|
} |
|
} |
|
|
|
w->cv_Tn[nidx*cv_n + j] = INT_MAX; |
|
w->cv_node_risk[nidx*cv_n + j] = sum - max_val; |
|
w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k; |
|
} |
|
} |
|
else |
|
{ |
|
// in case of regression tree: |
|
// * node value is 1/n*sum_i(Y_i), where Y_i is i-th response, |
|
// n is the number of samples in the node. |
|
// * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2) |
|
// * j-th cross-validation fold value and risk are calculated as above, |
|
// but using the samples with cv_labels(*)!=j. |
|
// * j-th cross-validation fold error is calculated |
|
// using samples with cv_labels(*)==j as the test subset: |
|
// error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2), |
|
// where node_value_j is the node value calculated |
|
// as described in the previous bullet, and summation is done |
|
// over the samples with cv_labels(*)==j. |
|
double sum = 0, sum2 = 0, sumw = 0; |
|
|
|
if( cv_n == 0 ) |
|
{ |
|
for( i = 0; i < n; i++ ) |
|
{ |
|
int si = _sidx[i]; |
|
double wval = w->sample_weights[si]; |
|
double t = w->ord_responses[si]; |
|
sum += t*wval; |
|
sum2 += t*t*wval; |
|
sumw += wval; |
|
} |
|
} |
|
else |
|
{ |
|
double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n; |
|
double* cv_count = (double*)(cv_sum2 + cv_n); |
|
|
|
for( j = 0; j < cv_n; j++ ) |
|
{ |
|
cv_sum[j] = cv_sum2[j] = 0.; |
|
cv_count[j] = 0; |
|
} |
|
|
|
for( i = 0; i < n; i++ ) |
|
{ |
|
int si = _sidx[i]; |
|
j = w->cv_labels[si]; |
|
double wval = w->sample_weights[si]; |
|
double t = w->ord_responses[si]; |
|
cv_sum[j] += t*wval; |
|
cv_sum2[j] += t*t*wval; |
|
cv_count[j] += wval; |
|
} |
|
|
|
for( j = 0; j < cv_n; j++ ) |
|
{ |
|
sum += cv_sum[j]; |
|
sum2 += cv_sum2[j]; |
|
sumw += cv_count[j]; |
|
} |
|
|
|
for( j = 0; j < cv_n; j++ ) |
|
{ |
|
double s = sum - cv_sum[j], si = sum - s; |
|
double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2; |
|
double c = cv_count[j], ci = sumw - c; |
|
double r = si/std::max(ci, DBL_EPSILON); |
|
w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci; |
|
w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r; |
|
w->cv_Tn[nidx*cv_n + j] = INT_MAX; |
|
} |
|
} |
|
|
|
node->node_risk = sum2 - (sum/sumw)*sum; |
|
node->value = sum/sumw; |
|
} |
|
} |
|
|
|
// Searches for the best threshold split of ordered variable vi for a
// classification node. Samples are visited in increasing order of the variable
// value; a cut between two neighbours is considered only when their values
// differ by more than epsilon. The returned split is left default-constructed
// when no cut beats initQuality.
DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    const double epsilon = FLT_EPSILON*2;
    const int n = (int)_sidx.size();
    const int m = (int)classLabels.size();

    // one buffer, carved into: left/right class weights, variable values, sort order
    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    double* lcw = (double*)(uchar*)buf;
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);

    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    for( int k = 0; k < m; k++ )
        lcw[k] = rcw[k] = 0.;

    w->data->getValues( vi, _sidx, values );

    // initially every sample is on the right side
    for( int i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        const int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // total weight and sum of squared class weights on each side
    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( int k = 0; k < m; k++ )
    {
        const double wval = rcw[k];
        R += wval;
        rsum2 += wval*wval;
    }

    int best_i = -1;
    double best_val = initQuality;

    for( int i = 0; i < n - 1; i++ )
    {
        const int curr = sorted_idx[i];
        const int next = sorted_idx[i+1];
        const int si = sidx[curr];
        const double wval = weights[si], w2 = wval*wval;

        // move the current sample from the right side to the left side
        L += wval; R -= wval;
        const int idx = responses[si];
        const double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        if( values[curr] + epsilon < values[next] )
        {
            const double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        // threshold halfway between the two neighbouring values
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
|
|
|
// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector. |
|
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels ) |
|
{ |
|
int iters = 0, max_iters = 100; |
|
int i, j, idx; |
|
cv::AutoBuffer<double> buf(n + k); |
|
double *v_weights = buf, *c_weights = buf + n; |
|
bool modified = true; |
|
RNG r((uint64)-1); |
|
|
|
// assign labels randomly |
|
for( i = 0; i < n; i++ ) |
|
{ |
|
double sum = 0; |
|
const double* v = vectors + i*m; |
|
labels[i] = i < k ? i : r.uniform(0, k); |
|
|
|
// compute weight of each vector |
|
for( j = 0; j < m; j++ ) |
|
sum += v[j]; |
|
v_weights[i] = sum ? 1./sum : 0.; |
|
} |
|
|
|
for( i = 0; i < n; i++ ) |
|
{ |
|
int i1 = r.uniform(0, n); |
|
int i2 = r.uniform(0, n); |
|
std::swap( labels[i1], labels[i2] ); |
|
} |
|
|
|
for( iters = 0; iters <= max_iters; iters++ ) |
|
{ |
|
// calculate csums |
|
for( i = 0; i < k; i++ ) |
|
{ |
|
for( j = 0; j < m; j++ ) |
|
csums[i*m + j] = 0; |
|
} |
|
|
|
for( i = 0; i < n; i++ ) |
|
{ |
|
const double* v = vectors + i*m; |
|
double* s = csums + labels[i]*m; |
|
for( j = 0; j < m; j++ ) |
|
s[j] += v[j]; |
|
} |
|
|
|
// exit the loop here, when we have up-to-date csums |
|
if( iters == max_iters || !modified ) |
|
break; |
|
|
|
modified = false; |
|
|
|
// calculate weight of each cluster |
|
for( i = 0; i < k; i++ ) |
|
{ |
|
const double* s = csums + i*m; |
|
double sum = 0; |
|
for( j = 0; j < m; j++ ) |
|
sum += s[j]; |
|
c_weights[i] = sum ? 1./sum : 0; |
|
} |
|
|
|
// now for each vector determine the closest cluster |
|
for( i = 0; i < n; i++ ) |
|
{ |
|
const double* v = vectors + i*m; |
|
double alpha = v_weights[i]; |
|
double min_dist2 = DBL_MAX; |
|
int min_idx = -1; |
|
|
|
for( idx = 0; idx < k; idx++ ) |
|
{ |
|
const double* s = csums + idx*m; |
|
double dist2 = 0., beta = c_weights[idx]; |
|
for( j = 0; j < m; j++ ) |
|
{ |
|
double t = v[j]*alpha - s[j]*beta; |
|
dist2 += t*t; |
|
} |
|
if( min_dist2 > dist2 ) |
|
{ |
|
min_dist2 = dist2; |
|
min_idx = idx; |
|
} |
|
} |
|
|
|
if( min_idx != labels[i] ) |
|
modified = true; |
|
labels[i] = min_idx; |
|
} |
|
} |
|
} |
|
|
|
// Searches for the best subset split of categorical variable vi for a
// classification node.
//  * With more than two classes the category subsets are enumerated in
//    Gray-code order; when the variable has more categories than
//    params.getMaxCategories() they are first clustered by clusterCategories().
//  * With exactly two classes the categories are sorted (cmp_lt_ptr over the
//    second-class weights) and only prefixes of that order are tried.
// On success the chosen categories are written as a bit mask into `subset`.
// The returned split is left default-constructed when nothing beats initQuality.
DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    const int _mi = getCatCount(vi);
    int mi = _mi;
    const int n = (int)_sidx.size();
    const int m = (int)classLabels.size();

    // buffer size in doubles; the sample labels are stored after base_size
    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = (double*)buf;
    double* rc = lc + m;
    double* _cjk = rc + m*2;
    double* cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(lc + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    // (row -1 is zeroed as well; it lies inside the buffer, just before row 0)
    for( int j = -1; j < mi; j++ )
        for( int k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( int i = 0; i < n; i++ )
    {
        const int si = _sidx[i];
        const int j = labels[i];
        const int k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            // too many categories: cluster them and work with the cluster sums
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        // sort pointers to the second-class counter of every category
        dbl_ptr = (double**)(c_weights + _mi);
        for( int j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    // per-class totals: everything starts on the right side
    for( int k = 0; k < m; k++ )
    {
        double sum = 0;
        for( int j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    // per-category totals
    for( int j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( int k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;
        int idx;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            const int graycode = (subset_i>>1)^subset_i;
            const int diff = graycode ^ prevcode;

            // determine index of the changed bit.
            // (the exponent of the float representation gives the bit position)
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        const double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            // category idx moves from the right side to the left side
            for( int k = 0; k < m; k++ )
            {
                const double t = crow[k];
                const double lval = lc[k] + t;
                const double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            // category idx moves back from the left side to the right side
            for( int k = 0; k < m; k++ )
            {
                const double t = crow[k];
                const double lval = lc[k] - t;
                const double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            const double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            // the first best_subset+1 categories of the sorted order are selected
            for( int i = 0; i <= best_subset; i++ )
            {
                const int idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            // best_subset is a bit mask over categories (or clusters of categories)
            for( int i = 0; i < _mi; i++ )
            {
                const int idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}
|
|
|
// Searches for the best threshold split of ordered variable vi for a
// regression node. Samples are visited in increasing order of the variable
// value; a cut between two neighbours is considered only when their values
// differ by more than epsilon. The returned split is left default-constructed
// when no cut beats initQuality.
DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const float epsilon = FLT_EPSILON*2;
    const int n = (int)_sidx.size();
    const double* weights = &w->sample_weights[0];

    // one buffer: variable values followed by the sort order
    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));
    float* values = (float*)(uchar*)buf;
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    // total weight and weighted response sum on each side; all samples start on the right
    double L = 0, R = 0;
    double lsum = 0, rsum = 0;
    for( int i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        const int si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    int best_i = -1;
    double best_val = initQuality;

    // find the optimal split
    for( int i = 0; i < n - 1; i++ )
    {
        const int curr = sorted_idx[i];
        const int next = sorted_idx[i+1];
        const int si = _sidx[curr];
        const double wval = weights[si];
        const double t = responses[si]*wval;

        // move the current sample from the right side to the left side
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        if( values[curr] + epsilon < values[next] )
        {
            const double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        // threshold halfway between the two neighbouring values
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
|
|
|
// Searches for the best subset split of categorical variable vi for a
// regression node. Categories are sorted by their average response and only
// prefixes of that order are tried. On success the chosen categories are
// written as a bit mask into `subset`. The returned split is left
// default-constructed when nothing beats initQuality.
DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    const int n = (int)_sidx.size();
    const int mi = getCatCount(vi);

    // buffer layout: sum[-1..mi-1], counts[-1..mi-1], mi pointers, n labels
    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = (double*)buf + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    // index -1 is zeroed too; both arrays reserve a slot in front of element 0
    for( int c = -1; c < mi; c++ )
        sum[c] = counts[c] = 0;

    // calculate sum response and weight of each category of the input var
    for( int i = 0; i < n; i++ )
    {
        const int idx = cat_labels[i];
        const int si = _sidx[i];
        const double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    double L = 0, R = 0, lsum = 0, rsum = 0;

    // calculate average response in each category
    for( int c = 0; c < mi; c++ )
    {
        R += counts[c];
        rsum += sum[c];
        sum[c] = fabs(counts[c]) > DBL_EPSILON ? sum[c]/counts[c] : 0;
        sum_ptr[c] = sum + c;
    }

    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be a very little loss in accuracy)
    for( int c = 0; c < mi; c++ )
        sum[c] *= counts[c];

    int best_subset = -1;
    double best_val = initQuality;

    // move the categories to the left side one by one, in sorted order
    for( int subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        const int idx = (int)(sum_ptr[subset_i] - sum);
        const double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            const double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                const double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        // the first best_subset+1 categories of the sorted order are selected
        for( int i = 0; i <= best_subset; i++ )
        {
            const int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
|
|
|
int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx, |
|
vector<int>& _sleft, vector<int>& _sright ) |
|
{ |
|
WSplit split = w->wsplits[splitidx]; |
|
int i, si, n = (int)_sidx.size(), vi = split.varIdx; |
|
_sleft.reserve(n); |
|
_sright.reserve(n); |
|
_sleft.clear(); |
|
_sright.clear(); |
|
|
|
AutoBuffer<float> buf(n); |
|
int mi = getCatCount(vi); |
|
double wleft = 0, wright = 0; |
|
const double* weights = &w->sample_weights[0]; |
|
|
|
if( mi <= 0 ) // split on an ordered variable |
|
{ |
|
float c = split.c; |
|
float* values = buf; |
|
w->data->getValues(vi, _sidx, values); |
|
|
|
for( i = 0; i < n; i++ ) |
|
{ |
|
si = _sidx[i]; |
|
if( values[i] <= c ) |
|
{ |
|
_sleft.push_back(si); |
|
wleft += weights[si]; |
|
} |
|
else |
|
{ |
|
_sright.push_back(si); |
|
wright += weights[si]; |
|
} |
|
} |
|
} |
|
else |
|
{ |
|
const int* subset = &w->wsubsets[split.subsetOfs]; |
|
int* cat_labels = (int*)(float*)buf; |
|
w->data->getNormCatValues(vi, _sidx, cat_labels); |
|
|
|
for( i = 0; i < n; i++ ) |
|
{ |
|
si = _sidx[i]; |
|
unsigned u = cat_labels[i]; |
|
if( CV_DTREE_CAT_DIR(u, subset) < 0 ) |
|
{ |
|
_sleft.push_back(si); |
|
wleft += weights[si]; |
|
} |
|
else |
|
{ |
|
_sright.push_back(si); |
|
wright += weights[si]; |
|
} |
|
} |
|
} |
|
CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n ); |
|
return wleft > wright ? -1 : 1; |
|
} |
|
|
|
int DTreesImpl::pruneCV( int root ) |
|
{ |
|
vector<double> ab; |
|
|
|
// 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}. |
|
// 2. choose the best tree index (if need, apply 1SE rule). |
|
// 3. store the best index and cut the branches. |
|
|
|
int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count; |
|
// currently, 1SE for regression is not implemented |
|
bool use_1se = params.use1SERule != 0 && _isClassifier; |
|
double min_err = 0, min_err_se = 0; |
|
int min_idx = -1; |
|
|
|
// build the main tree sequence, calculate alpha's |
|
for(;;tree_count++) |
|
{ |
|
double min_alpha = updateTreeRNC(root, tree_count, -1); |
|
if( cutTree(root, tree_count, -1, min_alpha) ) |
|
break; |
|
|
|
ab.push_back(min_alpha); |
|
} |
|
|
|
if( tree_count > 0 ) |
|
{ |
|
ab[0] = 0.; |
|
|
|
for( ti = 1; ti < tree_count-1; ti++ ) |
|
ab[ti] = std::sqrt(ab[ti]*ab[ti+1]); |
|
ab[tree_count-1] = DBL_MAX*0.5; |
|
|
|
Mat err_jk(cv_n, tree_count, CV_64F); |
|
|
|
for( j = 0; j < cv_n; j++ ) |
|
{ |
|
int tj = 0, tk = 0; |
|
for( ; tj < tree_count; tj++ ) |
|
{ |
|
double min_alpha = updateTreeRNC(root, tj, j); |
|
if( cutTree(root, tj, j, min_alpha) ) |
|
min_alpha = DBL_MAX; |
|
|
|
for( ; tk < tree_count; tk++ ) |
|
{ |
|
if( ab[tk] > min_alpha ) |
|
break; |
|
err_jk.at<double>(j, tk) = w->wnodes[root].tree_error; |
|
} |
|
} |
|
} |
|
|
|
for( ti = 0; ti < tree_count; ti++ ) |
|
{ |
|
double sum_err = 0; |
|
for( j = 0; j < cv_n; j++ ) |
|
sum_err += err_jk.at<double>(j, ti); |
|
if( ti == 0 || sum_err < min_err ) |
|
{ |
|
min_err = sum_err; |
|
min_idx = ti; |
|
if( use_1se ) |
|
min_err_se = sqrt( sum_err*(n - sum_err) ); |
|
} |
|
else if( sum_err < min_err + min_err_se ) |
|
min_idx = ti; |
|
} |
|
} |
|
|
|
return min_idx; |
|
} |
|
|
|
double DTreesImpl::updateTreeRNC( int root, double T, int fold ) |
|
{ |
|
int nidx = root, pidx = -1, cv_n = params.getCVFolds(); |
|
double min_alpha = DBL_MAX; |
|
|
|
for(;;) |
|
{ |
|
WNode *node = 0, *parent = 0; |
|
|
|
for(;;) |
|
{ |
|
node = &w->wnodes[nidx]; |
|
double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn; |
|
if( t <= T || node->left < 0 ) |
|
{ |
|
node->complexity = 1; |
|
node->tree_risk = node->node_risk; |
|
node->tree_error = 0.; |
|
if( fold >= 0 ) |
|
{ |
|
node->tree_risk = w->cv_node_risk[nidx*cv_n + fold]; |
|
node->tree_error = w->cv_node_error[nidx*cv_n + fold]; |
|
} |
|
break; |
|
} |
|
nidx = node->left; |
|
} |
|
|
|
for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx; |
|
nidx = pidx, pidx = w->wnodes[pidx].parent ) |
|
{ |
|
node = &w->wnodes[nidx]; |
|
parent = &w->wnodes[pidx]; |
|
parent->complexity += node->complexity; |
|
parent->tree_risk += node->tree_risk; |
|
parent->tree_error += node->tree_error; |
|
|
|
parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk) |
|
- parent->tree_risk)/(parent->complexity - 1); |
|
min_alpha = std::min( min_alpha, parent->alpha ); |
|
} |
|
|
|
if( pidx < 0 ) |
|
break; |
|
|
|
node = &w->wnodes[nidx]; |
|
parent = &w->wnodes[pidx]; |
|
parent->complexity = node->complexity; |
|
parent->tree_risk = node->tree_risk; |
|
parent->tree_error = node->tree_error; |
|
nidx = parent->right; |
|
} |
|
|
|
return min_alpha; |
|
} |
|
|
|
bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha ) |
|
{ |
|
int cv_n = params.getCVFolds(), nidx = root, pidx = -1; |
|
WNode* node = &w->wnodes[root]; |
|
if( node->left < 0 ) |
|
return true; |
|
|
|
for(;;) |
|
{ |
|
for(;;) |
|
{ |
|
node = &w->wnodes[nidx]; |
|
double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn; |
|
if( t <= T || node->left < 0 ) |
|
break; |
|
if( node->alpha <= min_alpha + FLT_EPSILON ) |
|
{ |
|
if( fold >= 0 ) |
|
w->cv_Tn[nidx*cv_n + fold] = T; |
|
else |
|
node->Tn = T; |
|
if( nidx == root ) |
|
return true; |
|
break; |
|
} |
|
nidx = node->left; |
|
} |
|
|
|
for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx; |
|
nidx = pidx, pidx = w->wnodes[pidx].parent ) |
|
; |
|
|
|
if( pidx < 0 ) |
|
break; |
|
|
|
nidx = w->wnodes[pidx].right; |
|
} |
|
|
|
return false; |
|
} |
|
|
|
float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const |
|
{ |
|
CV_Assert( sample.type() == CV_32F ); |
|
|
|
int predictType = flags & PREDICT_MASK; |
|
int nvars = (int)varIdx.size(); |
|
if( nvars == 0 ) |
|
nvars = (int)varType.size(); |
|
int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size(); |
|
int catbufsize = ncats > 0 ? nvars : 0; |
|
AutoBuffer<int> buf(nclasses + catbufsize + 1); |
|
int* votes = buf; |
|
int* catbuf = votes + nclasses; |
|
const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0; |
|
const uchar* vtype = &varType[0]; |
|
const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0; |
|
const int* cmap = !catMap.empty() ? &catMap[0] : 0; |
|
const float* psample = sample.ptr<float>(); |
|
const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0; |
|
size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float); |
|
double sum = 0.; |
|
int lastClassIdx = -1; |
|
const float MISSED_VAL = TrainData::missingValue(); |
|
|
|
for( i = 0; i < catbufsize; i++ ) |
|
catbuf[i] = -1; |
|
|
|
if( predictType == PREDICT_AUTO ) |
|
{ |
|
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ? |
|
PREDICT_SUM : PREDICT_MAX_VOTE; |
|
} |
|
|
|
if( predictType == PREDICT_MAX_VOTE ) |
|
{ |
|
for( i = 0; i < nclasses; i++ ) |
|
votes[i] = 0; |
|
} |
|
|
|
for( int ridx = range.start; ridx < range.end; ridx++ ) |
|
{ |
|
int nidx = roots[ridx], prev = nidx, c = 0; |
|
|
|
for(;;) |
|
{ |
|
prev = nidx; |
|
const Node& node = nodes[nidx]; |
|
if( node.split < 0 ) |
|
break; |
|
const Split& split = splits[node.split]; |
|
int vi = split.varIdx; |
|
int ci = cvidx ? cvidx[vi] : vi; |
|
float val = psample[ci*sstep]; |
|
if( val == MISSED_VAL ) |
|
{ |
|
if( !missingSubstPtr ) |
|
{ |
|
nidx = node.defaultDir < 0 ? node.left : node.right; |
|
continue; |
|
} |
|
val = missingSubstPtr[vi]; |
|
} |
|
|
|
if( vtype[vi] == VAR_ORDERED ) |
|
nidx = val <= split.c ? node.left : node.right; |
|
else |
|
{ |
|
if( flags & PREPROCESSED_INPUT ) |
|
c = cvRound(val); |
|
else |
|
{ |
|
c = catbuf[ci]; |
|
if( c < 0 ) |
|
{ |
|
int a = c = cofs[vi][0]; |
|
int b = cofs[vi][1]; |
|
|
|
int ival = cvRound(val); |
|
if( ival != val ) |
|
CV_Error( CV_StsBadArg, |
|
"one of input categorical variable is not an integer" ); |
|
|
|
while( a < b ) |
|
{ |
|
c = (a + b) >> 1; |
|
if( ival < cmap[c] ) |
|
b = c; |
|
else if( ival > cmap[c] ) |
|
a = c+1; |
|
else |
|
break; |
|
} |
|
|
|
CV_Assert( c >= 0 && ival == cmap[c] ); |
|
|
|
c -= cofs[vi][0]; |
|
catbuf[ci] = c; |
|
} |
|
const int* subset = &subsets[split.subsetOfs]; |
|
unsigned u = c; |
|
nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right; |
|
} |
|
} |
|
} |
|
|
|
if( predictType == PREDICT_SUM ) |
|
sum += nodes[prev].value; |
|
else |
|
{ |
|
lastClassIdx = nodes[prev].classIdx; |
|
votes[lastClassIdx]++; |
|
} |
|
} |
|
|
|
if( predictType == PREDICT_MAX_VOTE ) |
|
{ |
|
int best_idx = lastClassIdx; |
|
if( range.end - range.start > 1 ) |
|
{ |
|
best_idx = 0; |
|
for( i = 1; i < nclasses; i++ ) |
|
if( votes[best_idx] < votes[i] ) |
|
best_idx = i; |
|
} |
|
sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx]; |
|
} |
|
|
|
return (float)sum; |
|
} |
|
|
|
|
|
// Predicts the response for every row of _samples using all trees of the model.
//
// When _results is requested it becomes an nsamples x 1 column: CV_32S for a
// classifier queried with PREDICT_MAX_VOTE, CV_32F otherwise. When it is not
// requested, only the first row (if any) is evaluated. For a non-classifier
// model the summed tree outputs are divided by the number of trees.
// Returns the prediction for the first row (0 if there are no rows).
float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );

    Mat samples = _samples.getMat();
    const bool wantResults = _results.needed();
    const bool classification = isClassifier();
    const int ntrees = (int)roots.size();
    const float scale = classification ? 1.f : 1.f/ntrees;

    int rtype = CV_32F;
    if( classification && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;

    Mat results;
    int count = samples.rows;
    if( !wantResults )
        count = std::min(count, 1);
    else
    {
        _results.create(count, 1, rtype);
        results = _results.getMat();
    }

    float firstValue = 0.f;
    for( int row = 0; row < count; row++ )
    {
        const float val = predictTrees( Range(0, ntrees), samples.row(row), flags )*scale;

        if( wantResults )
        {
            if( rtype != CV_32F )
                results.at<int>(row) = cvRound(val);
            else
                results.at<float>(row) = val;
        }

        if( row == 0 )
            firstValue = val;
    }

    return firstValue;
}
|
|
|
// Stores the training parameters of the model as key/value pairs in `fs`.
// "use_1se_rule" is written only when more than one cv fold is configured,
// "priors" only when the priors are not empty.
void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    const int surrogatesFlag = params.useSurrogates ? 1 : 0;
    const int cvFolds = params.getCVFolds();

    fs << "use_surrogates" << surrogatesFlag;
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();

    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << cvFolds;

    if( cvFolds > 1 )
    {
        const int oneSEFlag = params.use1SERule ? 1 : 0;
        fs << "use_1se_rule" << oneSEFlag;
    }

    if( !params.priors.empty() )
    {
        fs << "priors" << params.priors;
    }
}
|
|
|
// Writes the model description to `fs`: classifier flag, variable counts,
// training parameters and the variable/category tables. Empty tables
// (var_idx, cat_ofs, cat_map, class_labels, missing_subst) are skipped.
void DTreesImpl::writeParams(FileStorage& fs) const
{
    const int totalVars = (int)varType.size();

    fs << "is_classifier" << isClassifier();
    fs << "var_all" << totalVars;
    fs << "var_count" << getVarCount();

    // every variable is either ordered or counted as categorical
    int orderedCount = 0;
    for( int k = 0; k < totalVars; k++ )
    {
        if( varType[k] == VAR_ORDERED )
            orderedCount++;
    }
    const int categoricalCount = totalVars - orderedCount;

    fs << "ord_var_count" << orderedCount;
    fs << "cat_var_count" << categoricalCount;

    fs << "training_params" << "{";
    writeTrainingParams(fs);
    fs << "}";

    if( !varIdx.empty() )
    {
        // the flag is read back by readParams() to pick the variable mapping
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
    {
        fs << "cat_ofs" << catOfs;
    }
    if( !catMap.empty() )
    {
        fs << "cat_map" << catMap;
    }
    if( !classLabels.empty() )
    {
        fs << "class_labels" << classLabels;
    }
    if( !missingSubst.empty() )
    {
        fs << "missing_subst" << missingSubst;
    }
}
|
|
|
// Writes the split `splitidx` as a compact map: "var", "quality" and either
//  * "le"/"gt" with the threshold for a non-categorical variable, or
//  * "in"/"not_in" with a list of category indices for a categorical one.
void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];
    const int var = split.varIdx;

    fs << "{:";
    fs << "var" << var;
    fs << "quality" << split.quality;

    if( varType[var] != VAR_CATEGORICAL )
    {
        fs << (split.inversed ? "gt" : "le") << split.c;
    }
    else
    {
        // split on a categorical var
        const int ncats = getCatCount(var);
        const int* subset = &subsets[split.subsetOfs];

        int rightCount = 0;
        for( int k = 0; k < ncats; k++ )
        {
            if( CV_DTREE_CAT_DIR(k, subset) > 0 )
                rightCount++;
        }

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        const bool fewGoRight = rightCount <= 1 ||
                                rightCount <= std::min(3, ncats/2) ||
                                rightCount <= ncats/3;
        const int defaultDir = fewGoRight ? -1 : 1;
        const int sign = split.inversed ? -1 : 1;

        fs << (defaultDir*sign > 0 ? "in" : "not_in") << "[:";

        // list only the categories that go against the default direction
        for( int k = 0; k < ncats; k++ )
        {
            const int dir = CV_DTREE_CAT_DIR(k, subset);
            if( dir*defaultDir < 0 )
                fs << k;
        }

        fs << "]";
    }

    fs << "}";
}
|
|
|
// Writes one node as a map: "depth", "value", "norm_class_idx" (classifier
// only) and, if the node has a split, the chain of its splits linked through
// Split::next.
void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& node = nodes[nidx];

    fs << "{";
    fs << "depth" << depth;
    fs << "value" << node.value;

    if( _isClassifier )
    {
        fs << "norm_class_idx" << node.classIdx;
    }

    if( node.split >= 0 )
    {
        fs << "splits" << "[";

        int cur = node.split;
        while( cur >= 0 )
        {
            writeSplit( fs, cur );
            cur = splits[cur].next;
        }

        fs << "]";
    }

    fs << "}";
}
|
|
|
// Writes all nodes of the tree rooted at `root` as the "nodes" sequence.
// The nodes are emitted in depth-first order (node, left subtree, right
// subtree) using the parent links instead of recursion.
void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int cur = root, level = 0;

    for(;;)
    {
        // emit the current node and everything along its left spine
        writeNode( fs, cur, level );
        const Node* node = &nodes[cur];
        while( node->left >= 0 )
        {
            cur = node->left;
            level++;
            writeNode( fs, cur, level );
            node = &nodes[cur];
        }

        // climb while arriving from a right child; each step is one level up
        int up = node->parent;
        while( up >= 0 && nodes[up].right == cur )
        {
            level--;
            cur = up;
            up = nodes[up].parent;
        }

        if( up < 0 )
            break;

        // the right sibling is at the same depth as the left child we came from
        cur = nodes[up].right;
    }

    fs << "]";
}
|
|
|
// Serializes the model: parameters first, then the nodes of the first tree.
//
// Raises an error (CV_Assert) when the model has no tree; previously roots[0]
// was read unconditionally, which is undefined behaviour for an untrained
// model. The check runs before anything is written to `fs`.
void DTreesImpl::write( FileStorage& fs ) const
{
    CV_Assert( !roots.empty() );

    writeParams(fs);
    writeTree(fs, roots[0]);
}
|
|
|
void DTreesImpl::readParams( const FileNode& fn ) |
|
{ |
|
_isClassifier = (int)fn["is_classifier"] != 0; |
|
/*int var_all = (int)fn["var_all"]; |
|
int var_count = (int)fn["var_count"]; |
|
int cat_var_count = (int)fn["cat_var_count"]; |
|
int ord_var_count = (int)fn["ord_var_count"];*/ |
|
|
|
FileNode tparams_node = fn["training_params"]; |
|
|
|
TreeParams params0 = TreeParams(); |
|
|
|
if( !tparams_node.empty() ) // training parameters are not necessary |
|
{ |
|
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0; |
|
params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"])); |
|
params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]); |
|
params0.setMaxDepth((int)tparams_node["max_depth"]); |
|
params0.setMinSampleCount((int)tparams_node["min_sample_count"]); |
|
params0.setCVFolds((int)tparams_node["cross_validation_folds"]); |
|
|
|
if( params0.getCVFolds() > 1 ) |
|
{ |
|
params.use1SERule = (int)tparams_node["use_1se_rule"] != 0; |
|
} |
|
|
|
tparams_node["priors"] >> params0.priors; |
|
} |
|
|
|
readVectorOrMat(fn["var_idx"], varIdx); |
|
fn["var_type"] >> varType; |
|
|
|
int format = 0; |
|
fn["format"] >> format; |
|
bool isLegacy = format < 3; |
|
|
|
int varAll = (int)fn["var_all"]; |
|
if (isLegacy && (int)varType.size() <= varAll) |
|
{ |
|
std::vector<uchar> extendedTypes(varAll + 1, 0); |
|
|
|
int i = 0, n; |
|
if (!varIdx.empty()) |
|
{ |
|
n = (int)varIdx.size(); |
|
for (; i < n; ++i) |
|
{ |
|
int var = varIdx[i]; |
|
extendedTypes[var] = varType[i]; |
|
} |
|
} |
|
else |
|
{ |
|
n = (int)varType.size(); |
|
for (; i < n; ++i) |
|
{ |
|
extendedTypes[i] = varType[i]; |
|
} |
|
} |
|
extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED); |
|
extendedTypes.swap(varType); |
|
} |
|
|
|
readVectorOrMat(fn["cat_map"], catMap); |
|
|
|
if (isLegacy) |
|
{ |
|
// generating "catOfs" from "cat_count" |
|
catOfs.clear(); |
|
classLabels.clear(); |
|
std::vector<int> counts; |
|
readVectorOrMat(fn["cat_count"], counts); |
|
unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1; |
|
for (; i < size; ++i) |
|
{ |
|
Vec2i newOffsets(0, 0); |
|
if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap |
|
{ |
|
newOffsets[0] = curShift; |
|
curShift += counts[j]; |
|
newOffsets[1] = curShift; |
|
++j; |
|
} |
|
catOfs.push_back(newOffsets); |
|
} |
|
// other elements in "catMap" are "classLabels" |
|
if (curShift < catMap.size()) |
|
{ |
|
classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end()); |
|
catMap.erase(catMap.begin() + curShift, catMap.end()); |
|
} |
|
} |
|
else |
|
{ |
|
fn["cat_ofs"] >> catOfs; |
|
fn["missing_subst"] >> missingSubst; |
|
fn["class_labels"] >> classLabels; |
|
} |
|
|
|
// init var mapping for node reading (var indexes or varIdx indexes) |
|
bool globalVarIdx = false; |
|
fn["global_var_idx"] >> globalVarIdx; |
|
if (globalVarIdx || varIdx.empty()) |
|
setRangeVector(varMapping, (int)varType.size()); |
|
else |
|
varMapping = varIdx; |
|
|
|
initCompVarIdx(); |
|
setDParams(params0); |
|
} |
|
|
|
int DTreesImpl::readSplit( const FileNode& fn ) |
|
{ |
|
Split split; |
|
|
|
int vi = (int)fn["var"]; |
|
CV_Assert( 0 <= vi && vi <= (int)varType.size() ); |
|
vi = varMapping[vi]; // convert to varIdx if needed |
|
split.varIdx = vi; |
|
|
|
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var |
|
{ |
|
int i, val, ssize = getSubsetSize(vi); |
|
split.subsetOfs = (int)subsets.size(); |
|
for( i = 0; i < ssize; i++ ) |
|
subsets.push_back(0); |
|
int* subset = &subsets[split.subsetOfs]; |
|
FileNode fns = fn["in"]; |
|
if( fns.empty() ) |
|
{ |
|
fns = fn["not_in"]; |
|
split.inversed = true; |
|
} |
|
|
|
if( fns.isInt() ) |
|
{ |
|
val = (int)fns; |
|
subset[val >> 5] |= 1 << (val & 31); |
|
} |
|
else |
|
{ |
|
FileNodeIterator it = fns.begin(); |
|
int n = (int)fns.size(); |
|
for( i = 0; i < n; i++, ++it ) |
|
{ |
|
val = (int)*it; |
|
subset[val >> 5] |= 1 << (val & 31); |
|
} |
|
} |
|
|
|
// for categorical splits we do not use inversed splits, |
|
// instead we inverse the variable set in the split |
|
if( split.inversed ) |
|
{ |
|
for( i = 0; i < ssize; i++ ) |
|
subset[i] ^= -1; |
|
split.inversed = false; |
|
} |
|
} |
|
else |
|
{ |
|
FileNode cmpNode = fn["le"]; |
|
if( cmpNode.empty() ) |
|
{ |
|
cmpNode = fn["gt"]; |
|
split.inversed = true; |
|
} |
|
split.c = (float)cmpNode; |
|
} |
|
|
|
split.quality = (float)fn["quality"]; |
|
splits.push_back(split); |
|
|
|
return (int)(splits.size() - 1); |
|
} |
|
|
|
int DTreesImpl::readNode( const FileNode& fn ) |
|
{ |
|
Node node; |
|
node.value = (double)fn["value"]; |
|
|
|
if( _isClassifier ) |
|
node.classIdx = (int)fn["norm_class_idx"]; |
|
|
|
FileNode sfn = fn["splits"]; |
|
if( !sfn.empty() ) |
|
{ |
|
int i, n = (int)sfn.size(), prevsplit = -1; |
|
FileNodeIterator it = sfn.begin(); |
|
|
|
for( i = 0; i < n; i++, ++it ) |
|
{ |
|
int splitidx = readSplit(*it); |
|
if( splitidx < 0 ) |
|
break; |
|
if( prevsplit < 0 ) |
|
node.split = splitidx; |
|
else |
|
splits[prevsplit].next = splitidx; |
|
prevsplit = splitidx; |
|
} |
|
} |
|
nodes.push_back(node); |
|
return (int)(nodes.size() - 1); |
|
} |
|
|
|
int DTreesImpl::readTree( const FileNode& fn ) |
|
{ |
|
int i, n = (int)fn.size(), root = -1, pidx = -1; |
|
FileNodeIterator it = fn.begin(); |
|
|
|
for( i = 0; i < n; i++, ++it ) |
|
{ |
|
int nidx = readNode(*it); |
|
if( nidx < 0 ) |
|
break; |
|
Node& node = nodes[nidx]; |
|
node.parent = pidx; |
|
if( pidx < 0 ) |
|
root = nidx; |
|
else |
|
{ |
|
Node& parent = nodes[pidx]; |
|
if( parent.left < 0 ) |
|
parent.left = nidx; |
|
else |
|
parent.right = nidx; |
|
} |
|
if( node.split >= 0 ) |
|
pidx = nidx; |
|
else |
|
{ |
|
while( pidx >= 0 && nodes[pidx].right >= 0 ) |
|
pidx = nodes[pidx].parent; |
|
} |
|
} |
|
roots.push_back(root); |
|
return root; |
|
} |
|
|
|
void DTreesImpl::read( const FileNode& fn ) |
|
{ |
|
clear(); |
|
readParams(fn); |
|
|
|
FileNode fnodes = fn["nodes"]; |
|
CV_Assert( !fnodes.empty() ); |
|
readTree(fnodes); |
|
} |
|
|
|
// Factory: returns a new, empty decision tree model.
Ptr<DTrees> DTrees::create()
{
    Ptr<DTrees> model = makePtr<DTreesImpl>();
    return model;
}
|
|
|
} |
|
} |
|
|
|
/* End of file. */
|
|
|