Check sure that we're not already below required leaf false alarm rate before continuing to get negative samples.

pull/3236/head
Stephen Mell 11 years ago committed by Vadim Pisarevsky
parent a67484668e
commit 5947519ff4
  1. 13
      apps/traincascade/cascadeclassifier.cpp
  2. 4
      apps/traincascade/cascadeclassifier.h

@ -198,7 +198,7 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
cout << endl << "===== TRAINING " << i << "-stage =====" << endl; cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
cout << "<BEGIN" << endl; cout << "<BEGIN" << endl;
if ( !updateTrainingSet( tempLeafFARate ) ) if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
{ {
cout << "Train dataset for temp stage can not be filled. " cout << "Train dataset for temp stage can not be filled. "
"Branch training terminated." << endl; "Branch training terminated." << endl;
@ -284,17 +284,17 @@ int CvCascadeClassifier::predict( int sampleIdx )
return 1; return 1;
} }
bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio) bool CvCascadeClassifier::updateTrainingSet( double minimumAcceptanceRatio, double& acceptanceRatio)
{ {
int64 posConsumed = 0, negConsumed = 0; int64 posConsumed = 0, negConsumed = 0;
imgReader.restart(); imgReader.restart();
int posCount = fillPassedSamples( 0, numPos, true, posConsumed ); int posCount = fillPassedSamples( 0, numPos, true, 0, posConsumed );
if( !posCount ) if( !posCount )
return false; return false;
cout << "POS count : consumed " << posCount << " : " << (int)posConsumed << endl; cout << "POS count : consumed " << posCount << " : " << (int)posConsumed << endl;
int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible
int negCount = fillPassedSamples( posCount, proNumNeg, false, negConsumed ); int negCount = fillPassedSamples( posCount, proNumNeg, false, minimumAcceptanceRatio, negConsumed );
if ( !negCount ) if ( !negCount )
return false; return false;
@ -304,7 +304,7 @@ bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio)
return true; return true;
} }
int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, int64& consumed ) int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, double minimumAcceptanceRatio, int64& consumed )
{ {
int getcount = 0; int getcount = 0;
Mat img(cascadeParams.winSize, CV_8UC1); Mat img(cascadeParams.winSize, CV_8UC1);
@ -312,6 +312,9 @@ int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositiv
{ {
for( ; ; ) for( ; ; )
{ {
if( consumed != 0 && ((double)getcount+1)/(double)(int64)consumed <= minimumAcceptanceRatio )
return getcount;
bool isGetImg = isPositive ? imgReader.getPos( img ) : bool isGetImg = isPositive ? imgReader.getPos( img ) :
imgReader.getNeg( img ); imgReader.getNeg( img );
if( !isGetImg ) if( !isGetImg )

@ -101,8 +101,8 @@ private:
int predict( int sampleIdx ); int predict( int sampleIdx );
void save( const std::string cascadeDirName, bool baseFormat = false ); void save( const std::string cascadeDirName, bool baseFormat = false );
bool load( const std::string cascadeDirName ); bool load( const std::string cascadeDirName );
bool updateTrainingSet( double& acceptanceRatio ); bool updateTrainingSet( double minimumAcceptanceRatio, double& acceptanceRatio );
int fillPassedSamples( int first, int count, bool isPositive, int64& consumed ); int fillPassedSamples( int first, int count, bool isPositive, double requiredAcceptanceRatio, int64& consumed );
void writeParams( cv::FileStorage &fs ) const; void writeParams( cv::FileStorage &fs ) const;
void writeStages( cv::FileStorage &fs, const cv::Mat& featureMap ) const; void writeStages( cv::FileStorage &fs, const cv::Mat& featureMap ) const;

Loading…
Cancel
Save