samples(python): update tracker.py

This commit is contained in:
Alexander Alekhin 2021-12-29 00:44:56 +00:00
parent 3b0ed61826
commit 7f075b0b15

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
''' '''
Tracker demo Tracker demo
@ -36,43 +35,49 @@ class App(object):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.trackerAlgorithm = args.tracker_algo
self.tracker = self.createTracker()
def initializeTracker(self, image, trackerAlgorithm): def createTracker(self):
if self.trackerAlgorithm == 'mil':
tracker = cv.TrackerMIL_create()
elif self.trackerAlgorithm == 'goturn':
params = cv.TrackerGOTURN_Params()
params.modelTxt = self.args.goturn
params.modelBin = self.args.goturn_model
tracker = cv.TrackerGOTURN_create(params)
elif self.trackerAlgorithm == 'dasiamrpn':
params = cv.TrackerDaSiamRPN_Params()
params.model = self.args.dasiamrpn_net
params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1
params.kernel_r1 = self.args.dasiamrpn_kernel_r1
tracker = cv.TrackerDaSiamRPN_create(params)
else:
sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(self.trackerAlgorithm))
return tracker
def initializeTracker(self, image):
while True: while True:
if trackerAlgorithm == 'mil':
tracker = cv.TrackerMIL_create()
elif trackerAlgorithm == 'goturn':
params = cv.TrackerGOTURN_Params()
params.modelTxt = self.args.goturn
params.modelBin = self.args.goturn_model
tracker = cv.TrackerGOTURN_create(params)
elif trackerAlgorithm == 'dasiamrpn':
params = cv.TrackerDaSiamRPN_Params()
params.model = self.args.dasiamrpn_net
params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1
params.kernel_r1 = self.args.dasiamrpn_kernel_r1
tracker = cv.TrackerDaSiamRPN_create(params)
else:
sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(trackerAlgorithm))
print('==> Select object ROI for tracker ...') print('==> Select object ROI for tracker ...')
bbox = cv.selectROI('tracking', image) bbox = cv.selectROI('tracking', image)
print('ROI: {}'.format(bbox)) print('ROI: {}'.format(bbox))
if bbox[2] <= 0 or bbox[3] <= 0:
sys.exit("ROI selection cancelled. Exiting...")
try: try:
tracker.init(image, bbox) self.tracker.init(image, bbox)
except Exception as e: except Exception as e:
print('Unable to initialize tracker with requested bounding box. Is there any object?') print('Unable to initialize tracker with requested bounding box. Is there any object?')
print(e) print(e)
print('Try again ...') print('Try again ...')
continue continue
return tracker return
def run(self): def run(self):
videoPath = self.args.input videoPath = self.args.input
trackerAlgorithm = self.args.tracker_algo print('Using video: {}'.format(videoPath))
camera = create_capture(videoPath, presets['cube']) camera = create_capture(cv.samples.findFileOrKeep(videoPath), presets['cube'])
if not camera.isOpened(): if not camera.isOpened():
sys.exit("Can't open video stream: {}".format(videoPath)) sys.exit("Can't open video stream: {}".format(videoPath))
@ -82,7 +87,7 @@ class App(object):
assert image is not None assert image is not None
cv.namedWindow('tracking') cv.namedWindow('tracking')
tracker = self.initializeTracker(image, trackerAlgorithm) self.initializeTracker(image)
print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...") print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...")
@ -92,7 +97,7 @@ class App(object):
print("Can't read frame") print("Can't read frame")
break break
ok, newbox = tracker.update(image) ok, newbox = self.tracker.update(image)
#print(ok, newbox) #print(ok, newbox)
if ok: if ok:
@ -101,7 +106,7 @@ class App(object):
cv.imshow("tracking", image) cv.imshow("tracking", image)
k = cv.waitKey(1) k = cv.waitKey(1)
if k == 32: # SPACE if k == 32: # SPACE
tracker = self.initializeTracker(image) self.initializeTracker(image)
if k == 27: # ESC if k == 27: # ESC
break break
@ -112,22 +117,13 @@ if __name__ == '__main__':
print(__doc__) print(__doc__)
parser = argparse.ArgumentParser(description="Run tracker") parser = argparse.ArgumentParser(description="Run tracker")
parser.add_argument("--input", type=str, default="vtest.avi", help="Path to video source") parser.add_argument("--input", type=str, default="vtest.avi", help="Path to video source")
parser.add_argument("--tracker_algo", type=str, default="mil", help="One of three available tracking algorithms: mil, goturn, dasiamrpn") parser.add_argument("--tracker_algo", type=str, default="mil", help="One of available tracking algorithms: mil, goturn, dasiamrpn")
parser.add_argument("--goturn", type=str, default="goturn.prototxt", help="Path to GOTURN architecture") parser.add_argument("--goturn", type=str, default="goturn.prototxt", help="Path to GOTURN architecture")
parser.add_argument("--goturn_model", type=str, default="goturn.caffemodel", help="Path to GOTERN model") parser.add_argument("--goturn_model", type=str, default="goturn.caffemodel", help="Path to GOTERN model")
parser.add_argument("--dasiamrpn_net", type=str, default="dasiamrpn_model.onnx", help="Path to onnx model of DaSiamRPN net") parser.add_argument("--dasiamrpn_net", type=str, default="dasiamrpn_model.onnx", help="Path to onnx model of DaSiamRPN net")
parser.add_argument("--dasiamrpn_kernel_r1", type=str, default="dasiamrpn_kernel_r1.onnx", help="Path to onnx model of DaSiamRPN kernel_r1") parser.add_argument("--dasiamrpn_kernel_r1", type=str, default="dasiamrpn_kernel_r1.onnx", help="Path to onnx model of DaSiamRPN kernel_r1")
parser.add_argument("--dasiamrpn_kernel_cls1", type=str, default="dasiamrpn_kernel_cls1.onnx", help="Path to onnx model of DaSiamRPN kernel_cls1") parser.add_argument("--dasiamrpn_kernel_cls1", type=str, default="dasiamrpn_kernel_cls1.onnx", help="Path to onnx model of DaSiamRPN kernel_cls1")
parser.add_argument("--dasiamrpn_backend", type=int, default=0, help="Choose one of computation backends:\
0: automatically (by default),\
1: Halide language (http://halide-lang.org/),\
2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit),\
3: OpenCV implementation")
parser.add_argument("--dasiamrpn_target", type=int, default=0, help="Choose one of target computation devices:\
0: CPU target (by default),\
1: OpenCL,\
2: OpenCL fp16 (half-float precision),\
3: VPU")
args = parser.parse_args() args = parser.parse_args()
App(args).run() App(args).run()
cv.destroyAllWindows() cv.destroyAllWindows()