mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #16335 from berak:fix_ml_python_digits_samples_3.4
This commit is contained in:
commit
4cc458eb10
@ -70,13 +70,8 @@ def deskew(img):
|
|||||||
img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
|
img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
class StatModel(object):
|
|
||||||
def load(self, fn):
|
|
||||||
self.model.load(fn) # Known bug: https://github.com/opencv/opencv/issues/4969
|
|
||||||
def save(self, fn):
|
|
||||||
self.model.save(fn)
|
|
||||||
|
|
||||||
class KNearest(StatModel):
|
class KNearest(object):
|
||||||
def __init__(self, k = 3):
|
def __init__(self, k = 3):
|
||||||
self.k = k
|
self.k = k
|
||||||
self.model = cv.ml.KNearest_create()
|
self.model = cv.ml.KNearest_create()
|
||||||
@ -88,7 +83,13 @@ class KNearest(StatModel):
|
|||||||
_retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
|
_retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
|
||||||
return results.ravel()
|
return results.ravel()
|
||||||
|
|
||||||
class SVM(StatModel):
|
def load(self, fn):
|
||||||
|
self.model = cv.ml.KNearest_load(fn)
|
||||||
|
|
||||||
|
def save(self, fn):
|
||||||
|
self.model.save(fn)
|
||||||
|
|
||||||
|
class SVM(object):
|
||||||
def __init__(self, C = 1, gamma = 0.5):
|
def __init__(self, C = 1, gamma = 0.5):
|
||||||
self.model = cv.ml.SVM_create()
|
self.model = cv.ml.SVM_create()
|
||||||
self.model.setGamma(gamma)
|
self.model.setGamma(gamma)
|
||||||
@ -102,6 +103,11 @@ class SVM(StatModel):
|
|||||||
def predict(self, samples):
|
def predict(self, samples):
|
||||||
return self.model.predict(samples)[1].ravel()
|
return self.model.predict(samples)[1].ravel()
|
||||||
|
|
||||||
|
def load(self, fn):
|
||||||
|
self.model = cv.ml.SVM_load(fn)
|
||||||
|
|
||||||
|
def save(self, fn):
|
||||||
|
self.model.save(fn)
|
||||||
|
|
||||||
def evaluate_model(model, digits, samples, labels):
|
def evaluate_model(model, digits, samples, labels):
|
||||||
resp = model.predict(samples)
|
resp = model.predict(samples)
|
||||||
|
@ -1,4 +1,12 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
'''
|
||||||
|
Digit recognition from video.
|
||||||
|
|
||||||
|
Run digits.py before, to train and save the SVM.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
digits_video.py [{camera_id|video_file}]
|
||||||
|
'''
|
||||||
|
|
||||||
# Python 2/3 compatibility
|
# Python 2/3 compatibility
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -28,11 +36,7 @@ def main():
|
|||||||
print('"%s" not found, run digits.py first' % classifier_fn)
|
print('"%s" not found, run digits.py first' % classifier_fn)
|
||||||
return
|
return
|
||||||
|
|
||||||
if True:
|
model = cv.ml.SVM_load(classifier_fn)
|
||||||
model = cv.ml.SVM_load(classifier_fn)
|
|
||||||
else:
|
|
||||||
model = cv.ml.SVM_create()
|
|
||||||
model.load_(classifier_fn) #Known bug: https://github.com/opencv/opencv/issues/4969
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
_ret, frame = cap.read()
|
_ret, frame = cap.read()
|
||||||
|
Loading…
Reference in New Issue
Block a user