|
|
|
@ -6,13 +6,6 @@ def load_base(fn): |
|
|
|
|
samples, responses = a[:,1:], a[:,0] |
|
|
|
|
return samples, responses |
|
|
|
|
|
|
|
|
|
# TODO move these to cv2 |
|
|
|
|
CV_ROW_SAMPLE = 1 |
|
|
|
|
CV_VAR_NUMERICAL = 0 |
|
|
|
|
CV_VAR_ORDERED = 0 |
|
|
|
|
CV_VAR_CATEGORICAL = 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LetterStatModel(object): |
|
|
|
|
train_ratio = 0.5 |
|
|
|
|
def load(self, fn): |
|
|
|
@ -26,10 +19,10 @@ class RTrees(LetterStatModel): |
|
|
|
|
|
|
|
|
|
def train(self, samples, responses): |
|
|
|
|
sample_n, var_n = samples.shape |
|
|
|
|
var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL], np.uint8) |
|
|
|
|
var_types = np.array([cv2.CV_VAR_NUMERICAL] * var_n + [cv2.CV_VAR_CATEGORICAL], np.uint8) |
|
|
|
|
#CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER)); |
|
|
|
|
params = dict(max_depth=10 ) |
|
|
|
|
self.model.train(samples, CV_ROW_SAMPLE, responses, varType = var_types, params = params) |
|
|
|
|
self.model.train(samples, cv2.CV_ROW_SAMPLE, responses, varType = var_types, params = params) |
|
|
|
|
|
|
|
|
|
def predict(self, samples): |
|
|
|
|
return np.float32( [self.model.predict(s) for s in samples] ) |
|
|
|
@ -56,10 +49,10 @@ class Boost(LetterStatModel): |
|
|
|
|
sample_n, var_n = samples.shape |
|
|
|
|
new_samples = self.unroll_samples(samples) |
|
|
|
|
new_responses = self.unroll_responses(responses) |
|
|
|
|
var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL, CV_VAR_CATEGORICAL], np.uint8) |
|
|
|
|
var_types = np.array([cv2.CV_VAR_NUMERICAL] * var_n + [cv2.CV_VAR_CATEGORICAL, cv2.CV_VAR_CATEGORICAL], np.uint8) |
|
|
|
|
#CvBoostParams(CvBoost::REAL, 100, 0.95, 5, false, 0 ) |
|
|
|
|
params = dict(max_depth=5) #, use_surrogates=False) |
|
|
|
|
self.model.train(new_samples, CV_ROW_SAMPLE, new_responses, varType = var_types, params=params) |
|
|
|
|
self.model.train(new_samples, cv2.CV_ROW_SAMPLE, new_responses, varType = var_types, params=params) |
|
|
|
|
|
|
|
|
|
def predict(self, samples): |
|
|
|
|
new_samples = self.unroll_samples(samples) |
|
|
|
@ -105,7 +98,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
|
parser.add_argument('-model', default='rtrees', choices=models.keys()) |
|
|
|
|
parser.add_argument('-data', nargs=1, default='letter-recognition.data') |
|
|
|
|
parser.add_argument('-data', nargs=1, default='../cpp/letter-recognition.data') |
|
|
|
|
parser.add_argument('-load', nargs=1) |
|
|
|
|
parser.add_argument('-save', nargs=1) |
|
|
|
|
args = parser.parse_args() |
|
|
|
|