mirror of
https://github.com/opencv/opencv.git
synced 2025-06-10 19:24:07 +08:00
Merge pull request #6190 from berak:fix_ml_python_tutorial
This commit is contained in:
commit
cb012010c6
@ -42,9 +42,9 @@ train_labels = np.repeat(k,250)[:,np.newaxis]
|
|||||||
test_labels = train_labels.copy()
|
test_labels = train_labels.copy()
|
||||||
|
|
||||||
# Initiate kNN, train the data, then test it with test data for k=1
|
# Initiate kNN, train the data, then test it with test data for k=1
|
||||||
knn = cv2.KNearest()
|
knn = cv2.ml.KNearest_create()
|
||||||
knn.train(train,train_labels)
|
knn.train(train, cv2.ml.ROW_SAMPLE, train_labels)
|
||||||
ret,result,neighbours,dist = knn.find_nearest(test,k=5)
|
ret,result,neighbours,dist = knn.findNearest(test,k=5)
|
||||||
|
|
||||||
# Now we check the accuracy of classification
|
# Now we check the accuracy of classification
|
||||||
# For that, compare the result with test_labels and check which are wrong
|
# For that, compare the result with test_labels and check which are wrong
|
||||||
@ -103,9 +103,9 @@ responses, trainData = np.hsplit(train,[1])
|
|||||||
labels, testData = np.hsplit(test,[1])
|
labels, testData = np.hsplit(test,[1])
|
||||||
|
|
||||||
# Initiate the kNN, classify, measure accuracy.
|
# Initiate the kNN, classify, measure accuracy.
|
||||||
knn = cv2.KNearest()
|
knn = cv2.ml.KNearest_create()
|
||||||
knn.train(trainData, responses)
|
knn.train(trainData, cv2.ml.ROW_SAMPLE, responses)
|
||||||
ret, result, neighbours, dist = knn.find_nearest(testData, k=5)
|
ret, result, neighbours, dist = knn.findNearest(testData, k=5)
|
||||||
|
|
||||||
correct = np.count_nonzero(result == labels)
|
correct = np.count_nonzero(result == labels)
|
||||||
accuracy = correct*100.0/10000
|
accuracy = correct*100.0/10000
|
||||||
|
@ -114,9 +114,9 @@ So let's see how it works. New comer is marked in green color.
|
|||||||
newcomer = np.random.randint(0,100,(1,2)).astype(np.float32)
|
newcomer = np.random.randint(0,100,(1,2)).astype(np.float32)
|
||||||
plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o')
|
plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o')
|
||||||
|
|
||||||
knn = cv2.KNearest()
|
knn = cv2.ml.KNearest_create()
|
||||||
knn.train(trainData,responses)
|
knn.train(trainData, cv2.ml.ROW_SAMPLE, responses)
|
||||||
ret, results, neighbours ,dist = knn.find_nearest(newcomer, 3)
|
ret, results, neighbours ,dist = knn.findNearest(newcomer, 3)
|
||||||
|
|
||||||
print "result: ", results,"\n"
|
print "result: ", results,"\n"
|
||||||
print "neighbours: ", neighbours,"\n"
|
print "neighbours: ", neighbours,"\n"
|
||||||
@ -140,7 +140,7 @@ obtained as arrays.
|
|||||||
@code{.py}
|
@code{.py}
|
||||||
# 10 new comers
|
# 10 new comers
|
||||||
newcomers = np.random.randint(0,100,(10,2)).astype(np.float32)
|
newcomers = np.random.randint(0,100,(10,2)).astype(np.float32)
|
||||||
ret, results,neighbours,dist = knn.find_nearest(newcomer, 3)
|
ret, results,neighbours,dist = knn.findNearest(newcomer, 3)
|
||||||
# The results also will contain 10 labels.
|
# The results also will contain 10 labels.
|
||||||
@endcode
|
@endcode
|
||||||
Additional Resources
|
Additional Resources
|
||||||
|
@ -64,9 +64,6 @@ import numpy as np
|
|||||||
SZ=20
|
SZ=20
|
||||||
bin_n = 16 # Number of bins
|
bin_n = 16 # Number of bins
|
||||||
|
|
||||||
svm_params = dict( kernel_type = cv2.SVM_LINEAR,
|
|
||||||
svm_type = cv2.SVM_C_SVC,
|
|
||||||
C=2.67, gamma=5.383 )
|
|
||||||
|
|
||||||
affine_flags = cv2.WARP_INVERSE_MAP|cv2.INTER_LINEAR
|
affine_flags = cv2.WARP_INVERSE_MAP|cv2.INTER_LINEAR
|
||||||
|
|
||||||
@ -105,8 +102,13 @@ hogdata = [map(hog,row) for row in deskewed]
|
|||||||
trainData = np.float32(hogdata).reshape(-1,64)
|
trainData = np.float32(hogdata).reshape(-1,64)
|
||||||
responses = np.float32(np.repeat(np.arange(10),250)[:,np.newaxis])
|
responses = np.float32(np.repeat(np.arange(10),250)[:,np.newaxis])
|
||||||
|
|
||||||
svm = cv2.SVM()
|
svm = cv2.ml.SVM_create()
|
||||||
svm.train(trainData,responses, params=svm_params)
|
svm.setKernel(cv2.ml.SVM_LINEAR)
|
||||||
|
svm.setType(cv2.ml.SVM_C_SVC)
|
||||||
|
svm.setC(2.67)
|
||||||
|
svm.setGamma(5.383)
|
||||||
|
|
||||||
|
svm.train(trainData, cv2.ml.ROW_SAMPLE, responses)
|
||||||
svm.save('svm_data.dat')
|
svm.save('svm_data.dat')
|
||||||
|
|
||||||
###### Now testing ########################
|
###### Now testing ########################
|
||||||
@ -114,7 +116,7 @@ svm.save('svm_data.dat')
|
|||||||
deskewed = [map(deskew,row) for row in test_cells]
|
deskewed = [map(deskew,row) for row in test_cells]
|
||||||
hogdata = [map(hog,row) for row in deskewed]
|
hogdata = [map(hog,row) for row in deskewed]
|
||||||
testData = np.float32(hogdata).reshape(-1,bin_n*4)
|
testData = np.float32(hogdata).reshape(-1,bin_n*4)
|
||||||
result = svm.predict_all(testData)
|
result = svm.predict(testData)
|
||||||
|
|
||||||
####### Check Accuracy ########################
|
####### Check Accuracy ########################
|
||||||
mask = result==responses
|
mask = result==responses
|
||||||
|
Loading…
Reference in New Issue
Block a user