|
|
|
@ -105,6 +105,7 @@ int str_to_ann_activation_function(String& str) |
|
|
|
|
void ann_check_data( Ptr<TrainData> _data ) |
|
|
|
|
{ |
|
|
|
|
CV_TRACE_FUNCTION(); |
|
|
|
|
CV_Assert(!_data.empty()); |
|
|
|
|
Mat values = _data->getSamples(); |
|
|
|
|
Mat var_idx = _data->getVarIdx(); |
|
|
|
|
int nvars = (int)var_idx.total(); |
|
|
|
@ -118,6 +119,7 @@ void ann_check_data( Ptr<TrainData> _data ) |
|
|
|
|
Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map ) |
|
|
|
|
{ |
|
|
|
|
CV_TRACE_FUNCTION(); |
|
|
|
|
CV_Assert(!_data.empty()); |
|
|
|
|
Mat train_sidx = _data->getTrainSampleIdx(); |
|
|
|
|
int* train_sidx_ptr = train_sidx.ptr<int>(); |
|
|
|
|
Mat responses = _data->getResponses(); |
|
|
|
@ -150,6 +152,8 @@ Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map ) |
|
|
|
|
float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels ) |
|
|
|
|
{ |
|
|
|
|
CV_TRACE_FUNCTION(); |
|
|
|
|
CV_Assert(!ann.empty()); |
|
|
|
|
CV_Assert(!_data.empty()); |
|
|
|
|
float err = 0; |
|
|
|
|
Mat samples = _data->getSamples(); |
|
|
|
|
Mat responses = _data->getResponses(); |
|
|
|
@ -264,13 +268,15 @@ TEST_P(ML_ANN_METHOD, Test) |
|
|
|
|
String dataname = folder + "waveform" + '_' + methodName; |
|
|
|
|
|
|
|
|
|
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0); |
|
|
|
|
ASSERT_FALSE(tdata2.empty()) << "Could not find test data file : " << original_path; |
|
|
|
|
|
|
|
|
|
Mat samples = tdata2->getSamples()(Range(0, N), Range::all()); |
|
|
|
|
Mat responses(N, 3, CV_32FC1, Scalar(0)); |
|
|
|
|
for (int i = 0; i < N; i++) |
|
|
|
|
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1; |
|
|
|
|
Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses); |
|
|
|
|
ASSERT_FALSE(tdata.empty()); |
|
|
|
|
|
|
|
|
|
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path; |
|
|
|
|
RNG& rng = theRNG(); |
|
|
|
|
rng.state = 0; |
|
|
|
|
tdata->setTrainTestSplitRatio(0.8); |
|
|
|
|