@ -349,6 +349,60 @@ public:
}
}
void getVotes ( InputArray input , OutputArray output , int flags ) const
{
CV_Assert ( ! roots . empty ( ) ) ;
int nclasses = ( int ) classLabels . size ( ) , ntrees = ( int ) roots . size ( ) ;
Mat samples = input . getMat ( ) , results ;
int i , j , nsamples = samples . rows ;
int predictType = flags & PREDICT_MASK ;
if ( predictType = = PREDICT_AUTO )
{
predictType = ! _isClassifier | | ( classLabels . size ( ) = = 2 & & ( flags & RAW_OUTPUT ) ! = 0 ) ?
PREDICT_SUM : PREDICT_MAX_VOTE ;
}
if ( predictType = = PREDICT_SUM )
{
output . create ( nsamples , ntrees , CV_32F ) ;
results = output . getMat ( ) ;
for ( i = 0 ; i < nsamples ; i + + )
{
for ( j = 0 ; j < ntrees ; j + + )
{
float val = predictTrees ( Range ( j , j + 1 ) , samples . row ( i ) , flags ) ;
results . at < float > ( i , j ) = val ;
}
}
} else
{
vector < int > votes ;
output . create ( nsamples + 1 , nclasses , CV_32S ) ;
results = output . getMat ( ) ;
for ( j = 0 ; j < nclasses ; j + + )
{
results . at < int > ( 0 , j ) = classLabels [ j ] ;
}
for ( i = 0 ; i < nsamples ; i + + )
{
votes . clear ( ) ;
for ( j = 0 ; j < ntrees ; j + + )
{
int val = ( int ) predictTrees ( Range ( j , j + 1 ) , samples . row ( i ) , flags ) ;
votes . push_back ( val ) ;
}
for ( j = 0 ; j < nclasses ; j + + )
{
results . at < int > ( i + 1 , j ) = ( int ) std : : count ( votes . begin ( ) , votes . end ( ) , classLabels [ j ] ) ;
}
}
}
}
RTreeParams rparams ;
double oobError ;
vector < float > varImportance ;
@ -401,6 +455,11 @@ public:
impl . read ( fn ) ;
}
void getVotes_ ( InputArray samples , OutputArray results , int flags ) const
{
impl . getVotes ( samples , results , flags ) ;
}
Mat getVarImportance ( ) const { return Mat_ < float > ( impl . varImportance , true ) ; }
int getVarCount ( ) const { return impl . getVarCount ( ) ; }
@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return Algorithm : : load < RTrees > ( filepath , nodeName ) ;
}
void RTrees : : getVotes ( InputArray input , OutputArray output , int flags ) const
{
const RTreesImpl * this_ = dynamic_cast < const RTreesImpl * > ( this ) ;
if ( ! this_ )
CV_Error ( Error : : StsNotImplemented , " the class is not RTreesImpl " ) ;
return this_ - > getVotes_ ( input , output , flags ) ;
}
} }
// End of file.