From 16f50dbe50ce58d3780399bbbd9ee8b8537f8144 Mon Sep 17 00:00:00 2001 From: "P. Druzhkov" Date: Mon, 29 Nov 2010 21:58:52 +0000 Subject: [PATCH] bug with negative class labels is fixed --- modules/ml/src/gbt.cpp | 44 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/modules/ml/src/gbt.cpp b/modules/ml/src/gbt.cpp index f0e00f2b8a..2b8959fe87 100644 --- a/modules/ml/src/gbt.cpp +++ b/modules/ml/src/gbt.cpp @@ -215,6 +215,7 @@ CvGBTrees::train( const CvMat* _train_data, int _tflag, cvCopy( _responses, orig_response); orig_response->step = CV_ELEM_SIZE(_responses->type); + /* if (!is_regression) { int max_label = -1; @@ -231,6 +232,38 @@ CvGBTrees::train( const CvMat* _train_data, int _tflag, if (class_labels->data.i[i]) class_labels->data.i[i] = ++class_count; } + */ + if (!is_regression) + { + class_count = 0; + unsigned char * mask = new unsigned char[get_len(orig_response)]; + for (int i=0; idata.fl[j]) == int(orig_response->data.fl[i])) + mask[j] = 1; + } + delete[] mask; + + class_labels = cvCreateMat(1, class_count, CV_32S); + class_labels->data.i[0] = int(orig_response->data.fl[0]); + int j = 1; + for (int i=1; idata.fl[i]) - class_labels->data.i[k]) && (kdata.i[k] = int(orig_response->data.fl[i]); + j++; + } + } + } data->is_classifier = false; @@ -443,8 +476,16 @@ void CvGBTrees::find_gradient(const int k) exp_sfi += res; } int orig_label = int(resp_data[idx]); + /* grad_data[idx] = (float)(!(k-class_labels->data.i[orig_label]+1)) - (float)(exp_fk / exp_sfi); + */ + int ensemble_label = 0; + while (class_labels->data.i[ensemble_label] - orig_label) + ensemble_label++; + + grad_data[idx] = (float)(!(k-ensemble_label)) - + (float)(exp_fk / exp_sfi); } }; break; @@ -772,10 +813,13 @@ float CvGBTrees::predict( const CvMat* _sample, const CvMat* _missing, delete[] sum; + /* int orig_class_label = -1; for (int i=0; idata.i[i] == class_label+1) orig_class_label = i; + */ + int orig_class_label = class_labels->data.i[class_label]; return float(orig_class_label); }