mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
dnn(pytest/test_input_3d): reload model between switching targets
This commit is contained in:
parent
c63aa7f085
commit
646924fce8
@ -323,20 +323,22 @@ class dnn_test(NewOpenCVTests):
|
|||||||
raise unittest.SkipTest("Missing DNN test files (dnn/onnx/data/{input/output}_hidden_lstm.npy). "
|
raise unittest.SkipTest("Missing DNN test files (dnn/onnx/data/{input/output}_hidden_lstm.npy). "
|
||||||
"Verify OPENCV_DNN_TEST_DATA_PATH configuration parameter.")
|
"Verify OPENCV_DNN_TEST_DATA_PATH configuration parameter.")
|
||||||
|
|
||||||
net = cv.dnn.readNet(model)
|
|
||||||
input = np.load(input_file)
|
input = np.load(input_file)
|
||||||
# we have to expand the shape of input tensor because Python bindings cut 3D tensors to 2D
|
# we have to expand the shape of input tensor because Python bindings cut 3D tensors to 2D
|
||||||
# it should be fixed in future. see : https://github.com/opencv/opencv/issues/19091
|
# it should be fixed in future. see : https://github.com/opencv/opencv/issues/19091
|
||||||
# please remove `expand_dims` after that
|
# please remove `expand_dims` after that
|
||||||
input = np.expand_dims(input, axis=3)
|
input = np.expand_dims(input, axis=3)
|
||||||
gold_output = np.load(output_file)
|
gold_output = np.load(output_file)
|
||||||
net.setInput(input)
|
|
||||||
|
|
||||||
for backend, target in self.dnnBackendsAndTargets:
|
for backend, target in self.dnnBackendsAndTargets:
|
||||||
printParams(backend, target)
|
printParams(backend, target)
|
||||||
|
|
||||||
|
net = cv.dnn.readNet(model)
|
||||||
|
|
||||||
net.setPreferableBackend(backend)
|
net.setPreferableBackend(backend)
|
||||||
net.setPreferableTarget(target)
|
net.setPreferableTarget(target)
|
||||||
|
|
||||||
|
net.setInput(input)
|
||||||
real_output = net.forward()
|
real_output = net.forward()
|
||||||
|
|
||||||
normAssert(self, real_output, gold_output, "", getDefaultThreshold(target))
|
normAssert(self, real_output, gold_output, "", getDefaultThreshold(target))
|
||||||
|
Loading…
Reference in New Issue
Block a user