|
|
|
@ -74,30 +74,35 @@ class StatModel(object): |
|
|
|
|
class KNearest(StatModel): |
|
|
|
|
def __init__(self, k = 3): |
|
|
|
|
self.k = k |
|
|
|
|
self.model = cv2.KNearest() |
|
|
|
|
self.model = cv2.ml.KNearest_create() |
|
|
|
|
|
|
|
|
|
def train(self, samples, responses): |
|
|
|
|
self.model = cv2.KNearest() |
|
|
|
|
self.model.train(samples, responses) |
|
|
|
|
self.model = cv2.ml.KNearest_create() |
|
|
|
|
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses) |
|
|
|
|
|
|
|
|
|
def predict(self, samples): |
|
|
|
|
retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k) |
|
|
|
|
retval, results, neigh_resp, dists = self.model.findNearest(samples, self.k) |
|
|
|
|
return results.ravel() |
|
|
|
|
|
|
|
|
|
class SVM(StatModel): |
|
|
|
|
def __init__(self, C = 1, gamma = 0.5): |
|
|
|
|
self.params = dict( kernel_type = cv2.SVM_RBF, |
|
|
|
|
svm_type = cv2.SVM_C_SVC, |
|
|
|
|
self.params = dict( kernel_type = cv2.ml.SVM_RBF, |
|
|
|
|
svm_type = cv2.ml.SVM_C_SVC, |
|
|
|
|
C = C, |
|
|
|
|
gamma = gamma ) |
|
|
|
|
self.model = cv2.SVM() |
|
|
|
|
self.model = cv2.ml.SVM_create() |
|
|
|
|
|
|
|
|
|
def train(self, samples, responses): |
|
|
|
|
self.model = cv2.SVM() |
|
|
|
|
self.model.train(samples, responses, params = self.params) |
|
|
|
|
self.model = cv2.ml.SVM_create() |
|
|
|
|
""" original code """ |
|
|
|
|
#self.model.train(samples, responses, params = self.params) |
|
|
|
|
""" but it's either this """ |
|
|
|
|
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses) |
|
|
|
|
""" or this """ |
|
|
|
|
#self.model.train(samples, params = self.params) |
|
|
|
|
|
|
|
|
|
def predict(self, samples): |
|
|
|
|
return self.model.predict_all(samples).ravel() |
|
|
|
|
return self.model.predict(samples)[1][0].ravel() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_model(model, digits, samples, labels): |
|
|
|
|