|
|
|
@ -91,31 +91,34 @@ class SVM(LetterStatModel): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
import argparse |
|
|
|
|
import getopt |
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
|
models = [RTrees, KNearest, Boost, SVM] # MLP, NBayes |
|
|
|
|
models = dict( [(cls.__name__.lower(), cls) for cls in models] ) |
|
|
|
|
|
|
|
|
|
print 'USAGE: letter_recog.py [--model <model>] [--data <data fn>] [--load <model fn>] [--save <model fn>]' |
|
|
|
|
print 'Models: ', ', '.join(models) |
|
|
|
|
print |
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
|
parser.add_argument('-model', default='rtrees', choices=models.keys()) |
|
|
|
|
parser.add_argument('-data', nargs=1, default='../cpp/letter-recognition.data') |
|
|
|
|
parser.add_argument('-load', nargs=1) |
|
|
|
|
parser.add_argument('-save', nargs=1) |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
print 'loading data %s ...' % args.data |
|
|
|
|
samples, responses = load_base(args.data) |
|
|
|
|
Model = models[args.model] |
|
|
|
|
args, dummy = getopt.getopt(sys.argv[1:], '', ['model=', 'data=', 'load=', 'save=']) |
|
|
|
|
args = dict(args) |
|
|
|
|
args.setdefault('--model', 'rtrees') |
|
|
|
|
args.setdefault('--data', '../cpp/letter-recognition.data') |
|
|
|
|
|
|
|
|
|
print 'loading data %s ...' % args['--data'] |
|
|
|
|
samples, responses = load_base(args['--data']) |
|
|
|
|
Model = models[args['--model']] |
|
|
|
|
model = Model() |
|
|
|
|
|
|
|
|
|
train_n = int(len(samples)*model.train_ratio) |
|
|
|
|
if args.load is None: |
|
|
|
|
print 'training %s ...' % Model.__name__ |
|
|
|
|
model.train(samples[:train_n], responses[:train_n]) |
|
|
|
|
else: |
|
|
|
|
fn = args.load[0] |
|
|
|
|
if '--load' in args: |
|
|
|
|
fn = args['--load'] |
|
|
|
|
print 'loading model from %s ...' % fn |
|
|
|
|
model.load(fn) |
|
|
|
|
else: |
|
|
|
|
print 'training %s ...' % Model.__name__ |
|
|
|
|
model.train(samples[:train_n], responses[:train_n]) |
|
|
|
|
|
|
|
|
|
print 'testing...' |
|
|
|
|
train_rate = np.mean(model.predict(samples[:train_n]) == responses[:train_n]) |
|
|
|
@ -123,7 +126,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
print 'train rate: %f test rate: %f' % (train_rate*100, test_rate*100) |
|
|
|
|
|
|
|
|
|
if args.save is not None: |
|
|
|
|
fn = args.save[0] |
|
|
|
|
if '--save' in args: |
|
|
|
|
fn = args['--save'] |
|
|
|
|
print 'saving model to %s ...' % fn |
|
|
|
|
model.save(fn) |
|
|
|
|