export SVM::trainAuto to python #7224 (#8373)

* export SVM::trainAuto to python #7224

* workaround for ABI compatibility of SVM::trainAuto

* add parameter comments to new SVM::trainAuto function

* Export ParamGrid member variables
pull/8443/head
Julian Tanke 8 years ago committed by Vadim Pisarevsky
parent 1857aa22b3
commit f70cc29edb
  1. modules/ml/include/opencv2/ml.hpp (66 changes)
  2. modules/ml/src/inner_functions.cpp (4 changes)
  3. modules/ml/src/svm.cpp (37 changes)

@@ -104,7 +104,7 @@ enum SampleTypes
 It is used for optimizing statmodel accuracy by varying model parameters, the accuracy estimate
 being computed by cross-validation.
 */
-class CV_EXPORTS ParamGrid
+class CV_EXPORTS_W ParamGrid
 {
 public:
     /** @brief Default constructor */
@@ -112,8 +112,8 @@ public:
     /** @brief Constructor with parameters */
     ParamGrid(double _minVal, double _maxVal, double _logStep);
-    double minVal; //!< Minimum value of the statmodel parameter. Default value is 0.
-    double maxVal; //!< Maximum value of the statmodel parameter. Default value is 0.
+    CV_PROP_RW double minVal; //!< Minimum value of the statmodel parameter. Default value is 0.
+    CV_PROP_RW double maxVal; //!< Maximum value of the statmodel parameter. Default value is 0.
     /** @brief Logarithmic step for iterating the statmodel parameter.
     The grid determines the following iteration sequence of the statmodel parameter values:
@@ -122,7 +122,15 @@ public:
     \f[\texttt{minVal} * \texttt{logStep} ^n < \texttt{maxVal}\f]
     The grid is logarithmic, so logStep must always be greater then 1. Default value is 1.
     */
-    double logStep;
+    CV_PROP_RW double logStep;
+
+    /** @brief Creates a ParamGrid Ptr that can be given to the %SVM::trainAuto method
+    @param minVal minimum value of the parameter grid
+    @param maxVal maximum value of the parameter grid
+    @param logstep Logarithmic step for iterating the statmodel parameter
+    */
+    CV_WRAP static Ptr<ParamGrid> create(double minVal=0., double maxVal=0., double logstep=1.);
 };

 /** @brief Class encapsulating training data.
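With the class now tagged CV_EXPORTS_W and its fields CV_PROP_RW, ParamGrid becomes constructible and inspectable from Python. A minimal sketch of what that looks like once the bindings are regenerated (the cv2.ml.ParamGrid_create name assumes OpenCV's usual wrapper naming):

    import cv2

    # Grid for the C parameter: candidates are minVal, minVal*logStep, ... while < maxVal
    c_grid = cv2.ml.ParamGrid_create(0.1, 1000.0, 10.0)

    print(c_grid.minVal, c_grid.maxVal, c_grid.logStep)  # CV_PROP_RW -> readable from Python
    c_grid.logStep = 5.0                                 # ...and writable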
@@ -691,6 +699,46 @@ public:
                    ParamGrid degreeGrid = getDefaultGrid(DEGREE),
                    bool balanced=false) = 0;

+    /** @brief Trains an %SVM with optimal parameters
+
+    @param samples training samples
+    @param layout See ml::SampleTypes.
+    @param responses vector of responses associated with the training samples.
+    @param kFold Cross-validation parameter. The training set is divided into kFold subsets. One
+        subset is used to test the model, the others form the train set. So, the %SVM algorithm is
+        executed kFold times.
+    @param Cgrid grid for C
+    @param gammaGrid grid for gamma
+    @param pGrid grid for p
+    @param nuGrid grid for nu
+    @param coeffGrid grid for coeff
+    @param degreeGrid grid for degree
+    @param balanced If true and the problem is 2-class classification then the method creates more
+        balanced cross-validation subsets, that is, proportions between classes in subsets are close
+        to such proportion in the whole train dataset.
+
+    The method trains the %SVM model automatically by choosing the optimal parameters C, gamma, p,
+    nu, coef0, degree. Parameters are considered optimal when the cross-validation
+    estimate of the test set error is minimal.
+
+    This function only makes use of SVM::getDefaultGrid for parameter optimization and thus only
+    offers rudimentary parameter options.
+
+    This function works for the classification (SVM::C_SVC or SVM::NU_SVC) as well as for the
+    regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
+    the usual %SVM with parameters specified in params is executed.
+    */
+    CV_WRAP bool trainAuto(InputArray samples,
+            int layout,
+            InputArray responses,
+            int kFold = 10,
+            Ptr<ParamGrid> Cgrid = SVM::getDefaultGridPtr(SVM::C),
+            Ptr<ParamGrid> gammaGrid = SVM::getDefaultGridPtr(SVM::GAMMA),
+            Ptr<ParamGrid> pGrid = SVM::getDefaultGridPtr(SVM::P),
+            Ptr<ParamGrid> nuGrid = SVM::getDefaultGridPtr(SVM::NU),
+            Ptr<ParamGrid> coeffGrid = SVM::getDefaultGridPtr(SVM::COEF),
+            Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
+            bool balanced=false);
+
     /** @brief Retrieves all the support vectors

     The method returns all the support vectors as a floating-point matrix, where support vectors are
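The CV_WRAP overload above is the entry point the Python generator exposes. A hedged usage sketch on synthetic data (keyword-argument names are assumed to mirror the C++ parameter names):

    import numpy as np
    import cv2

    samples = np.random.rand(100, 2).astype(np.float32)
    responses = (samples[:, 0] > samples[:, 1]).astype(np.int32)

    svm = cv2.ml.SVM_create()
    svm.setKernel(cv2.ml.SVM_RBF)

    # Unspecified grids fall back to SVM::getDefaultGridPtr; only the C grid is overridden here.
    svm.trainAuto(samples, cv2.ml.ROW_SAMPLE, responses,
                  kFold=5,
                  Cgrid=cv2.ml.ParamGrid_create(0.1, 1000.0, 10.0),
                  balanced=True)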
@@ -733,6 +781,16 @@ public:
     */
     static ParamGrid getDefaultGrid( int param_id );

+    /** @brief Generates a grid for %SVM parameters.
+
+    @param param_id %SVM parameters IDs that must be one of the SVM::ParamTypes. The grid is
+    generated for the parameter with this ID.
+
+    The function generates a grid pointer for the specified parameter of the %SVM algorithm.
+    The grid may be passed to the function SVM::trainAuto.
+    */
+    CV_WRAP static Ptr<ParamGrid> getDefaultGridPtr( int param_id );
+
     /** Creates empty model.
     Use StatModel::train to train the model. Since %SVM has several parameters, you may want to
     find the best parameters for your problem, it can be done with SVM::trainAuto. */
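Since getDefaultGridPtr is a CV_WRAP static method, it should surface in Python under the usual SVM_ prefix (an assumption worth verifying against the rebuilt module); the returned grid can then be tweaked before handing it to trainAuto:

    import cv2

    gamma_grid = cv2.ml.SVM_getDefaultGridPtr(cv2.ml.SVM_GAMMA)
    gamma_grid.maxVal = 10.0   # narrow the default gamma search range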

@@ -50,6 +50,10 @@ ParamGrid::ParamGrid(double _minVal, double _maxVal, double _logStep)
     logStep = std::max(_logStep, 1.);
 }

+Ptr<ParamGrid> ParamGrid::create(double minval, double maxval, double logstep) {
+  return makePtr<ParamGrid>(minval, maxval, logstep);
+}
+
 bool StatModel::empty() const { return !isTrained(); }

 int StatModel::getVarCount() const { return 0; }
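ParamGrid::create simply forwards to the parameterized constructor, so the clamping shown above (logStep is forced to be at least 1) also applies to grids created from Python. A small check, assuming the generated ParamGrid_create name:

    import cv2

    grid = cv2.ml.ParamGrid_create(1.0, 100.0, 0.5)
    print(grid.logStep)   # expected 1.0: the constructor applies std::max(_logStep, 1.)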

@@ -362,6 +362,12 @@ static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,

 //////////////////////// SVM implementation //////////////////////////////

+Ptr<ParamGrid> SVM::getDefaultGridPtr( int param_id)
+{
+  ParamGrid grid = getDefaultGrid(param_id); // this is not a nice solution..
+  return makePtr<ParamGrid>(grid.minVal, grid.maxVal, grid.logStep);
+}
+
 ParamGrid SVM::getDefaultGrid( int param_id )
 {
     ParamGrid grid;
@@ -1920,6 +1926,24 @@ public:
         bool returnDFVal;
     };

+    bool trainAuto_(InputArray samples, int layout,
+            InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
+            Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
+            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
+    {
+        Ptr<TrainData> data = TrainData::create(samples, layout, responses);
+        return this->trainAuto(
+                data, kfold,
+                *Cgrid.get(),
+                *gammaGrid.get(),
+                *pGrid.get(),
+                *nuGrid.get(),
+                *coeffGrid.get(),
+                *degreeGrid.get(),
+                balanced);
+    }
+
     float predict( InputArray _samples, OutputArray _results, int flags ) const
     {
         float result = 0;
@@ -2281,6 +2305,19 @@ Mat SVM::getUncompressedSupportVectors() const
     return this_->getUncompressedSupportVectors_();
 }

+bool SVM::trainAuto(InputArray samples, int layout,
+        InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
+        Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
+        Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
+{
+    SVMImpl* this_ = dynamic_cast<SVMImpl*>(this);
+    if (!this_) {
+        CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
+    }
+    return this_->trainAuto_(samples, layout, responses,
+            kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
+}
+
 }
 }
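Putting the pieces together, an end-to-end sketch of the new Python-facing trainAuto on a toy two-class problem (module and method names assume OpenCV's standard cv2.ml mangling; the data here is illustrative only):

    import numpy as np
    import cv2

    # Toy 2-class data: the label depends on which side of the diagonal a point falls.
    rng = np.random.RandomState(0)
    samples = rng.rand(200, 2).astype(np.float32)
    responses = (samples[:, 0] > samples[:, 1]).astype(np.int32)

    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)
    svm.setKernel(cv2.ml.SVM_RBF)

    # One call: every grid defaults to SVM::getDefaultGridPtr for its parameter.
    svm.trainAuto(samples, cv2.ml.ROW_SAMPLE, responses, kFold=10)

    print("optimal C:", svm.getC(), "optimal gamma:", svm.getGamma())
    _, predictions = svm.predict(samples)

Because the new trainAuto is a non-virtual overload taking InputArray and Ptr<ParamGrid> defaults rather than a change to the existing pure-virtual signature, existing C++ callers and the library ABI are left untouched, which is what the "workaround for ABI compatibility" commit above refers to.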
