mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 06:03:15 +08:00
work on MLP sample in letter_recog.py (in progress...)
This commit is contained in:
parent
638f3d31cf
commit
622bd42224
@ -7,11 +7,27 @@ def load_base(fn):
|
||||
return samples, responses
|
||||
|
||||
class LetterStatModel(object):
|
||||
class_n = 26
|
||||
train_ratio = 0.5
|
||||
|
||||
def load(self, fn):
|
||||
self.model.load(fn)
|
||||
def save(self, fn):
|
||||
self.model.save(fn)
|
||||
|
||||
def unroll_samples(self, samples):
|
||||
sample_n, var_n = samples.shape
|
||||
new_samples = np.zeros((sample_n * self.class_n, var_n+1), np.float32)
|
||||
new_samples[:,:-1] = np.repeat(samples, self.class_n, axis=0)
|
||||
new_samples[:,-1] = np.tile(np.arange(self.class_n), sample_n)
|
||||
return new_samples
|
||||
|
||||
def unroll_responses(self, responses):
|
||||
sample_n = len(responses)
|
||||
new_responses = np.zeros(sample_n*self.class_n, np.int32)
|
||||
resp_idx = np.int32( responses + np.arange(sample_n)*self.class_n )
|
||||
new_responses[resp_idx] = 1
|
||||
return new_responses
|
||||
|
||||
class RTrees(LetterStatModel):
|
||||
def __init__(self):
|
||||
@ -43,7 +59,6 @@ class KNearest(LetterStatModel):
|
||||
class Boost(LetterStatModel):
|
||||
def __init__(self):
|
||||
self.model = cv2.Boost()
|
||||
self.class_n = 26
|
||||
|
||||
def train(self, samples, responses):
|
||||
sample_n, var_n = samples.shape
|
||||
@ -60,20 +75,6 @@ class Boost(LetterStatModel):
|
||||
pred = pred.reshape(-1, self.class_n).argmax(1)
|
||||
return pred
|
||||
|
||||
def unroll_samples(self, samples):
|
||||
sample_n, var_n = samples.shape
|
||||
new_samples = np.zeros((sample_n * self.class_n, var_n+1), np.float32)
|
||||
new_samples[:,:-1] = np.repeat(samples, self.class_n, axis=0)
|
||||
new_samples[:,-1] = np.tile(np.arange(self.class_n), sample_n)
|
||||
return new_samples
|
||||
|
||||
def unroll_responses(self, responses):
|
||||
sample_n = len(responses)
|
||||
new_responses = np.zeros(sample_n*self.class_n, np.int32)
|
||||
resp_idx = np.int32( responses + np.arange(sample_n)*self.class_n )
|
||||
new_responses[resp_idx] = 1
|
||||
return new_responses
|
||||
|
||||
|
||||
class SVM(LetterStatModel):
|
||||
train_ratio = 0.1
|
||||
@ -89,12 +90,36 @@ class SVM(LetterStatModel):
|
||||
def predict(self, samples):
|
||||
return np.float32( [self.model.predict(s) for s in samples] )
|
||||
|
||||
class MLP(LetterStatModel):
|
||||
def __init__(self):
|
||||
self.model = cv2.ANN_MLP()
|
||||
|
||||
def train(self, samples, responses):
|
||||
sample_n, var_n = samples.shape
|
||||
new_responses = self.unroll_responses(responses).reshape(-1, self.class_n)
|
||||
|
||||
layer_sizes = np.int32([var_n, 100, 100, self.class_n])
|
||||
self.model.create(layer_sizes)
|
||||
|
||||
# CvANN_MLP_TrainParams::BACKPROP,0.001
|
||||
params = dict( term_crit = (cv2.TERM_CRITERIA_COUNT, 300, 0.01),
|
||||
train_method = cv2.ANN_MLP_TRAIN_PARAMS_BACKPROP,
|
||||
bp_dw_scale = 0.001,
|
||||
bp_moment_scale = 0.0 )
|
||||
self.model.train(samples, np.float32(new_responses), None, params = params)
|
||||
|
||||
def predict(self, samples):
|
||||
pass
|
||||
#return np.float32( [self.model.predict(s) for s in samples] )
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import getopt
|
||||
import sys
|
||||
|
||||
models = [RTrees, KNearest, Boost, SVM] # MLP, NBayes
|
||||
models = [RTrees, KNearest, Boost, SVM, MLP] # NBayes
|
||||
models = dict( [(cls.__name__.lower(), cls) for cls in models] )
|
||||
|
||||
print 'USAGE: letter_recog.py [--model <model>] [--data <data fn>] [--load <model fn>] [--save <model fn>]'
|
||||
|
Loading…
Reference in New Issue
Block a user