#!/usr/bin/env python
# OpenCV (Open Source Computer Vision Library) sample.
'''
Digit recognition adjustment.
Grid search is used to find the best parameters for SVM and KNearest classifiers.
SVM adjustment follows the guidelines given in
"A Practical Guide to Support Vector Classification" (Hsu, Chang and Lin),
http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf

Usage:
  digits_adjust.py [--model {svm|knearest}]

  --model {svm|knearest}  - select the classifier (SVM is the default)
'''
# Python 2/3 compatibility
from __future__ import print_function
import sys
PY3 = sys.version_info[0] == 3

if PY3:
    # Python 3 has no xrange; alias it so the rest of the module can use it.
    xrange = range

import numpy as np
import cv2
from multiprocessing.pool import ThreadPool
from digits import *
def cross_validate(model_class, params, samples, labels, kfold=3, pool=None):
    """Estimate a classifier's error rate with k-fold cross-validation.

    The sample indices are split into ``kfold`` contiguous folds, so shuffle
    the data beforehand if its order is meaningful. For each fold a fresh
    ``model_class(**params)`` is trained on all *other* folds and evaluated
    on the held-out one.

    Args:
        model_class: classifier factory; instances must provide
            ``train(samples, labels)`` and ``predict(samples)``.
        params: keyword arguments passed to ``model_class``.
        samples, labels: arrays that support integer-array indexing; assumed
            (not checked) to have the same length.
        kfold: number of folds; must be at least 2.
        pool: optional object with a ``map`` method (e.g. a ThreadPool).
            When given, folds are evaluated with ``pool.map``; otherwise
            they are evaluated serially.

    Returns:
        Mean misclassification rate over the folds (fraction in 0..1).

    Raises:
        ValueError: if ``kfold`` < 2, because no training data would remain.
    """
    if kfold < 2:
        raise ValueError('kfold must be at least 2, got %r' % (kfold,))
    n = len(samples)
    folds = np.array_split(np.arange(n), kfold)

    def score_fold(i):
        """Train on every fold except ``i`` and return the error on fold ``i``."""
        model = model_class(**params)
        test_idx = folds[i]
        # The held-out fold must be excluded from training; otherwise the test
        # samples leak into the training set and the estimate is meaningless.
        train_idx = np.hstack([fold for j, fold in enumerate(folds) if j != i])
        train_samples, train_labels = samples[train_idx], labels[train_idx]
        test_samples, test_labels = samples[test_idx], labels[test_idx]
        model.train(train_samples, train_labels)
        resp = model.predict(test_samples)
        score = (resp != test_labels).mean()
        print(".", end='')  # progress indicator: one dot per finished fold
        return score

    if pool is None:
        scores = list(map(score_fold, range(kfold)))
    else:
        scores = pool.map(score_fold, range(kfold))
    return np.mean(scores)
class App(object):
    """Hyper-parameter search for the digit classifiers.

    ``__init__`` loads and preprocesses the digit dataset once. The
    ``adjust_SVM`` / ``adjust_KNearest`` methods then evaluate candidate
    parameters with ``cross_validate`` on a thread pool.
    """

    def __init__(self):
        self._samples, self._labels = self.preprocess()

    def preprocess(self):
        """Load the digits and turn them into feature vectors.

        The images are shuffled, so that the contiguous folds used by
        ``cross_validate`` do not depend on the order returned by
        ``load_digits``. They are then deskewed and passed through
        ``preprocess_hog``.

        Returns:
            ``(samples, labels)`` in the same shuffled order.
        """
        digits, labels = load_digits(DIGITS_FN)
        shuffle = np.random.permutation(len(digits))
        digits, labels = digits[shuffle], labels[shuffle]
        digits2 = list(map(deskew, digits))
        samples = preprocess_hog(digits2)
        return samples, labels

    def get_dataset(self):
        """Return the ``(samples, labels)`` pair prepared in ``__init__``."""
        return self._samples, self._labels

    def run_jobs(self, f, jobs, processes=None):
        """Apply ``f`` to every item of ``jobs`` on a thread pool.

        This is a generator. Results are yielded in completion order, not
        submission order. The pool is shut down once the results are
        exhausted, or when the caller stops iterating early, so repeated
        calls do not leak worker threads.

        Args:
            f: callable applied to each job.
            jobs: iterable of job arguments.
            processes: number of worker threads; defaults to
                ``cv2.getNumberOfCPUs()``.
        """
        if processes is None:
            processes = cv2.getNumberOfCPUs()
        pool = ThreadPool(processes=processes)
        try:
            for result in pool.imap_unordered(f, jobs):
                yield result
        finally:
            # Either all work is done or the consumer abandoned the results:
            # stop the workers and wait for their threads to exit.
            pool.terminate()
            pool.join()

    def adjust_SVM(self):
        """Grid-search ``C`` and ``gamma`` for the SVM classifier.

        Evaluates a 15 x 15 log-spaced grid (``C`` in 2**0..2**10, ``gamma``
        in 2**-7..2**4) with ``cross_validate`` (default 3 folds) and prints
        progress. It saves the error table to ``svm_scores.npz`` and returns
        the best parameters as ``dict(C=..., gamma=...)``.
        """
        Cs = np.logspace(0, 10, 15, base=2)
        gammas = np.logspace(-7, 4, 15, base=2)
        # NaN marks grid cells not evaluated yet; np.nanmin ignores them.
        scores = np.zeros((len(Cs), len(gammas)))
        scores[:] = np.nan

        print('adjusting SVM (may take a long time) ...')

        def f(job):
            i, j = job
            samples, labels = self.get_dataset()
            params = dict(C=Cs[i], gamma=gammas[j])
            score = cross_validate(SVM, params, samples, labels)
            return i, j, score

        # Results arrive in completion order; (i, j) says which cell each fills.
        ires = self.run_jobs(f, np.ndindex(*scores.shape))
        for count, (i, j, score) in enumerate(ires):
            scores[i, j] = score
            print('%d / %d (best error: %.2f %%, last: %.2f %%)' %
                  (count+1, scores.size, np.nanmin(scores)*100, score*100))

        print('writing score table to "svm_scores.npz"')
        np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)

        i, j = np.unravel_index(scores.argmin(), scores.shape)
        best_params = dict(C=Cs[i], gamma=gammas[j])
        print('best params:', best_params)
        print('best error: %.2f %%' % (scores.min()*100))
        return best_params

    def adjust_KNearest(self):
        """Pick ``k`` in 1..8 for KNearest via ``cross_validate``.

        Prints each candidate's error and returns ``dict(k=best_k)``.
        """
        print('adjusting KNearest ...')

        def f(k):
            samples, labels = self.get_dataset()
            err = cross_validate(KNearest, dict(k=k), samples, labels)
            return k, err

        best_err, best_k = np.inf, -1
        # Jobs finish in arbitrary order, so the printed k values are unordered
        # and ties are resolved in favour of whichever finished first.
        for k, err in self.run_jobs(f, range(1, 9)):
            if err < best_err:
                best_err, best_k = err, k
            print('k = %d, error: %.2f %%' % (k, err*100))
        best_params = dict(k=best_k)
        print('best params:', best_params, 'err: %.2f' % (best_err*100))
        return best_params
if __name__ == '__main__':
    import getopt
    import sys

    try:
        args, _ = getopt.getopt(sys.argv[1:], '', ['model='])
    except getopt.GetoptError as err:
        print(err)
        sys.exit(1)
    args = dict(args)
    args.setdefault('--model', 'svm')
    args.setdefault('--env', '')
    if args['--model'] not in ['svm', 'knearest']:
        print('unknown model "%s"' % args['--model'])
        # Stop here instead of falling through and tuning a default model.
        sys.exit(1)

    # NOTE(review): ``clock`` is not defined in this file; it is presumably
    # provided by ``from digits import *`` -- confirm.
    t = clock()
    app = App()
    if args['--model'] == 'knearest':
        app.adjust_KNearest()
    else:
        app.adjust_SVM()
    print('work time: %f s' % (clock() - t))