2013-03-06 14:41:02 +08:00
|
|
|
#!/usr/bin/env python
|
2012-11-24 02:57:22 +08:00
|
|
|
|
2012-10-17 07:18:30 +08:00
|
|
|
'''
|
|
|
|
The sample demonstrates how to train Random Trees classifier
|
|
|
|
(or a Boosting classifier, MLP, KNearest, or Support Vector Machine) using the provided dataset.
|
|
|
|
|
|
|
|
We use the sample database letter-recognition.data
|
|
|
|
from UCI Repository, here is the link:
|
|
|
|
|
|
|
|
Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).
|
|
|
|
UCI Repository of machine learning databases
|
|
|
|
[http://www.ics.uci.edu/~mlearn/MLRepository.html].
|
|
|
|
Irvine, CA: University of California, Department of Information and Computer Science.
|
|
|
|
|
|
|
|
The dataset consists of 20000 feature vectors along with the
|
|
|
|
responses: capital Latin letters A..Z.
|
|
|
|
The first 10000 samples are used for training
|
|
|
|
and the remaining 10000 are used to test the classifier.
|
|
|
|
======================================================
|
|
|
|
USAGE:
|
|
|
|
letter_recog.py [--model <model>]
|
|
|
|
[--data <data fn>]
|
|
|
|
[--load <model fn>] [--save <model fn>]
|
|
|
|
|
|
|
|
Models: RTrees, KNearest, Boost, SVM, MLP
|
|
|
|
'''
|
|
|
|
|
2015-12-13 09:43:58 +08:00
|
|
|
# Python 2/3 compatibility
|
|
|
|
from __future__ import print_function
|
|
|
|
|
2012-10-17 07:18:30 +08:00
|
|
|
import numpy as np
|
2017-12-11 17:55:03 +08:00
|
|
|
import cv2 as cv
|
2012-10-17 07:18:30 +08:00
|
|
|
|
|
|
|
def load_base(fn):
    """Load the letter-recognition table stored in *fn*.

    Every row is comma-separated; its first field is a capital letter that
    is converted to a class index ('A' -> 0, 'B' -> 1, ...), and the
    remaining fields are numeric features.

    Returns:
        (samples, responses): float32 arrays holding, respectively, the
        feature columns and the class-index column.
    """
    def letter_to_index(ch):
        # Offset of the letter from 'A'; ord() accepts both str and 1-byte bytes.
        return ord(ch) - ord('A')

    table = np.loadtxt(fn, np.float32, delimiter=',', converters={0: letter_to_index})
    responses = table[:, 0]
    samples = table[:, 1:]
    return samples, responses
|
|
|
|
|
|
|
|
class LetterStatModel(object):
    """Common base for the letter classifiers below.

    Subclasses create an OpenCV ``cv.ml`` model in ``self.model``; this base
    provides persistence plus helpers that "unroll" a multi-class problem
    into a binary one (one row per (sample, candidate class) pair).
    """

    class_n = 26        # number of classes: the letters A..Z
    train_ratio = 0.5   # fraction of the dataset the script uses for training

    def load(self, fn):
        """Replace ``self.model`` with the model that ``self.model.load`` reads from *fn*."""
        self.model = self.model.load(fn)

    def save(self, fn):
        """Serialize the wrapped model to the file *fn*."""
        self.model.save(fn)

    def unroll_samples(self, samples):
        """Replicate each sample once per class and append the class index.

        For an (N, V) input, returns an (N * class_n, V + 1) float32 array:
        rows ``i * class_n .. i * class_n + class_n - 1`` are copies of
        sample ``i`` whose last column is 0, 1, ..., class_n - 1.
        """
        rows, cols = samples.shape
        unrolled = np.zeros((rows * self.class_n, cols + 1), np.float32)
        unrolled[:, :cols] = np.repeat(samples, self.class_n, axis=0)
        unrolled[:, cols] = np.tile(np.arange(self.class_n), rows)
        return unrolled

    def unroll_responses(self, responses):
        """Flattened one-hot encoding matching :meth:`unroll_samples`.

        Returns an int32 vector of length ``len(responses) * class_n`` that is
        1 at ``i * class_n + responses[i]`` and 0 elsewhere. Responses are
        cast to int32 (truncated) for indexing, so they are presumably
        integral class indices in ``[0, class_n)``.
        """
        count = len(responses)
        one_hot = np.zeros(count * self.class_n, np.int32)
        offsets = np.arange(count) * self.class_n
        one_hot[np.int32(responses + offsets)] = 1
        return one_hot
|
|
|
|
|
|
|
|
class RTrees(LetterStatModel):
    """Random Trees classifier backed by ``cv.ml.RTrees``."""

    def __init__(self):
        self.model = cv.ml.RTrees_create()

    def train(self, samples, responses):
        """Fit on row-major *samples*; trees are limited to depth 20."""
        labels = responses.astype(int)
        self.model.setMaxDepth(20)
        self.model.train(samples, cv.ml.ROW_SAMPLE, labels)

    def predict(self, samples):
        """Return the model's predictions for *samples* as a flat array."""
        return self.model.predict(samples)[1].ravel()
|
2012-10-17 07:18:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
class KNearest(LetterStatModel):
    """k-nearest-neighbours classifier backed by ``cv.ml.KNearest``."""

    def __init__(self):
        self.model = cv.ml.KNearest_create()

    def train(self, samples, responses):
        """Store the row-major training set; *responses* are passed as given."""
        self.model.train(samples, cv.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        """Classify each row of *samples* using its 10 nearest neighbours.

        ``findNearest`` returns (retval, results, neighbour responses,
        distances); only ``results`` is used, flattened to 1-D.
        """
        neighbours = 10
        found = self.model.findNearest(samples, k=neighbours)
        return found[1].ravel()
|
|
|
|
|
|
|
|
|
|
|
|
class Boost(LetterStatModel):
    """Boosted-trees classifier trained on the "unrolled" binary problem.

    Each sample is replicated once per class with the candidate class index
    appended as an extra feature (see ``unroll_samples``); the booster learns
    a 0/1 response saying whether that candidate is the true class.
    """

    def __init__(self):
        self.model = cv.ml.Boost_create()

    def train(self, samples, responses):
        """Fit 15 weak trees of max depth 10 on the unrolled data."""
        var_n = samples.shape[1]
        unrolled_x = self.unroll_samples(samples)
        unrolled_y = self.unroll_responses(responses).astype(int)
        # Types cover the var_n original features (numerical), then the
        # appended class index and the 0/1 response (both categorical).
        types = np.array([cv.ml.VAR_NUMERICAL] * var_n + [cv.ml.VAR_CATEGORICAL] * 2, np.uint8)
        self.model.setWeakCount(15)
        self.model.setMaxDepth(10)
        data = cv.ml.TrainData_create(unrolled_x, cv.ml.ROW_SAMPLE, unrolled_y, varType=types)
        self.model.train(data)

    def predict(self, samples):
        """Return one class index per row of *samples*.

        The model scores every unrolled (sample, class) copy; the scores are
        regrouped into one row of class_n values per sample and argmax picks
        the first class holding the largest value.
        """
        # NOTE(review): with plain predict() each copy scores 0 or 1, so ties
        # resolve to the lowest class index (and all-zero rows to class 0);
        # a raw/summed-vote output might rank classes better - confirm the
        # right cv.ml flag before changing this.
        _, votes = self.model.predict(self.unroll_samples(samples))
        per_class = votes.ravel().reshape(-1, self.class_n)
        return per_class.argmax(1)
|
2012-10-17 07:18:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
class SVM(LetterStatModel):
    """C-SVC classifier with an RBF kernel, backed by ``cv.ml.SVM``."""

    def __init__(self):
        self.model = cv.ml.SVM_create()

    def train(self, samples, responses):
        """Configure C=1, gamma=0.1 RBF C-SVC and fit on row-major *samples*."""
        svm = self.model
        svm.setType(cv.ml.SVM_C_SVC)
        svm.setC(1)
        svm.setKernel(cv.ml.SVM_RBF)
        svm.setGamma(.1)
        svm.train(samples, cv.ml.ROW_SAMPLE, responses.astype(int))

    def predict(self, samples):
        """Return the SVM's predictions for *samples* as a flat array."""
        return self.model.predict(samples)[1].ravel()
|
2012-10-17 07:18:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MLP(LetterStatModel):
    """Multi-layer perceptron classifier backed by ``cv.ml.ANN_MLP``."""

    def __init__(self):
        self.model = cv.ml.ANN_MLP_create()

    def train(self, samples, responses):
        """Train a var_n-100-100-class_n network on one-hot targets.

        Uses plain backprop (momentum 0, weight scale 0.001) and stops after
        20 iterations (COUNT-only termination criterion).
        """
        n_inputs = samples.shape[1]
        # One row of class_n 0/1 values per sample, as float32 for ANN_MLP.
        targets = np.float32(self.unroll_responses(responses).reshape(-1, self.class_n))
        net = self.model
        net.setLayerSizes(np.int32([n_inputs, 100, 100, self.class_n]))
        net.setTrainMethod(cv.ml.ANN_MLP_BACKPROP)
        net.setBackpropMomentumScale(0.0)
        net.setBackpropWeightScale(0.001)
        net.setTermCriteria((cv.TERM_CRITERIA_COUNT, 20, 0.01))
        net.setActivationFunction(cv.ml.ANN_MLP_SIGMOID_SYM, 2, 1)
        net.train(samples, cv.ml.ROW_SAMPLE, targets)

    def predict(self, samples):
        """Return, per row of *samples*, the index of the largest network output."""
        _, outputs = self.model.predict(samples)
        return outputs.argmax(-1)
|
|
|
|
|
|
|
|
|
2016-02-03 16:22:32 +08:00
|
|
|
|
2012-10-17 07:18:30 +08:00
|
|
|
if __name__ == '__main__':
    import getopt
    import sys

    print(__doc__)

    # Selectable classifiers, keyed by lower-case class name.  (NBayes is not wired up.)
    models = {cls.__name__.lower(): cls for cls in (RTrees, KNearest, Boost, SVM, MLP)}

    try:
        args, _positional = getopt.getopt(sys.argv[1:], '', ['model=', 'data=', 'load=', 'save='])
    except getopt.GetoptError as err:
        # Report bad/unknown options as a short message instead of a traceback.
        sys.exit('error: %s' % err)
    args = dict(args)
    args.setdefault('--model', 'svm')
    args.setdefault('--data', 'letter-recognition.data')

    # The usage text spells model names in mixed case ("SVM", "RTrees", ...)
    # while the registry is lower-case, so normalise before looking up, and
    # validate before spending time loading the dataset.
    model_key = args['--model'].lower()
    if model_key not in models:
        sys.exit('error: unknown model %r (choose from: %s)'
                 % (args['--model'], ', '.join(sorted(models))))
    Model = models[model_key]

    datafile = cv.samples.findFile(args['--data'])
    print('loading data %s ...' % datafile)
    samples, responses = load_base(datafile)

    model = Model()
    # The first train_ratio of the rows train the model; the rest are held out.
    train_n = int(len(samples)*model.train_ratio)
    if '--load' in args:
        fn = args['--load']
        print('loading model from %s ...' % fn)
        model.load(fn)
    else:
        print('training %s ...' % Model.__name__)
        model.train(samples[:train_n], responses[:train_n])

    print('testing...')
    train_rate = np.mean(model.predict(samples[:train_n]) == responses[:train_n].astype(int))
    test_rate = np.mean(model.predict(samples[train_n:]) == responses[train_n:].astype(int))

    print('train rate: %f test rate: %f' % (train_rate*100, test_rate*100))

    if '--save' in args:
        fn = args['--save']
        print('saving model to %s ...' % fn)
        model.save(fn)
|