mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 13:13:26 +08:00

Improved Tracker Samples (#26202). Relates to #25006. This sample has been rewritten to track a selected target in a video or camera stream. It combines the ViT, NanoTrack and DaSiamRPN trackers into a single tracker sample. ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There are accuracy tests, performance tests and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name. - [x] The feature is well documented and the sample code can be built with the project CMake
201 lines
8.9 KiB
Python
201 lines
8.9 KiB
Python
#!/usr/bin/env python
|
|
import sys
|
|
import cv2 as cv
|
|
import argparse
|
|
from common import *
|
|
|
|
def help():
    """Print how to download the tracker models and launch this sample.

    One section is printed per supported tracker (nanotrack, vit, dasiamrpn).
    """
    usage = '''
        Use this script for testing Object Tracking using OpenCV.

        Firstly, download required models using the download_models.py.

        To run:
            nanotrack:
                Download Model: python download_models.py nanotrack
                Example: python object_tracker.py nanotrack

            vit:
                Download Model: python download_models.py vit
                Example: python object_tracker.py vit
                        or
                        python object_tracker.py

            dasiamrpn:
                Download Model: python download_models.py dasiamrpn
                Example: python object_tracker.py dasiamrpn

        To switch between models in runtime, make sure all the models are downloaded using download_models.py'''
    print(usage)
|
|
|
|
def load_parser(model_name):
    """Build the command-line parser for *model_name* and parse ``sys.argv``.

    The ``--zoo`` file (default: ``models.yml`` next to this script) supplies
    the model options through ``add_preproc_args`` for the ``object_tracker``
    sample. Trackers made of several networks register one option set per
    network, each under its own argument prefix.

    Args:
        model_name: tracker alias, one of "vit", "nanotrack" or "dasiamrpn".

    Returns:
        argparse.Namespace with the parsed arguments.

    Exits:
        With status 1 (after printing the valid choices to stderr) when
        *model_name* is not a supported tracker alias.
    """
    # Imported explicitly; otherwise `os` is only reachable through the
    # star import from `common`.
    import os

    # Validate before touching the zoo file so a bad alias never reaches
    # add_preproc_args, and report it as a failure (non-zero status).
    if model_name not in ("vit", "nanotrack", "dasiamrpn"):
        print("Pass the valid alias. Choices are { nanotrack, vit, dasiamrpn }", file=sys.stderr)
        sys.exit(1)

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--zoo', default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models.yml'),
                        help='An optional path to file with preprocessing parameters.')
    parser.add_argument("--input", type=str, help="Path to video source")
    # Only --zoo is needed at this point; leave every other flag for the
    # final parser below.
    args, _ = parser.parse_known_args()

    add_preproc_args(args.zoo, parser, 'object_tracker', alias=model_name)
    extra_prefixes = {
        "dasiamrpn": ("dasiamrpn_", "dasiamrpn_kernel_r1_", "dasiamrpn_kernel_cls_"),
        "nanotrack": ("nanotrack_back_", "nanotrack_head_"),
    }
    for prefix in extra_prefixes.get(model_name, ()):
        add_preproc_args(args.zoo, parser, 'object_tracker', prefix=prefix, alias=model_name)

    parser = argparse.ArgumentParser(parents=[parser],
                                     description='''
        Firstly, download required models using `python download_models.py {modelName}`
        Run using python object_tracker.py {modelName}.
        ''',
                                     formatter_class=argparse.RawTextHelpFormatter)
    return parser.parse_args()
|
|
|
|
def createTracker(model_name, args):
    """Instantiate the OpenCV tracker named by *model_name*.

    Model files come from the ``*_model`` / ``*_sha1`` attributes of *args*
    and are resolved with ``findModel`` before the tracker is created.

    Args:
        model_name: "dasiamrpn", "nanotrack" or "vit".
        args: namespace returned by ``load_parser(model_name)``.

    Returns:
        The newly created tracker object.

    Exits:
        For any other name, prints the usage text and calls ``exit(-1)``.
    """
    if model_name == 'vit':
        print("Using Vit Tracker.")
        vit_params = cv.TrackerVit_Params()
        vit_params.net = findModel(args.model, args.sha1)
        return cv.TrackerVit_create(vit_params)

    if model_name == 'nanotrack':
        print("Using Nano Tracker.")
        nano_params = cv.TrackerNano_Params()
        nano_params.backbone = findModel(args.nanotrack_back_model, args.nanotrack_back_sha1)
        nano_params.neckhead = findModel(args.nanotrack_head_model, args.nanotrack_head_sha1)
        return cv.TrackerNano_create(nano_params)

    if model_name == 'dasiamrpn':
        print("Using Dasiamrpn Tracker.")
        siam_params = cv.TrackerDaSiamRPN_Params()
        siam_params.model = findModel(args.dasiamrpn_model, args.dasiamrpn_sha1)
        siam_params.kernel_cls1 = findModel(args.dasiamrpn_kernel_cls_model, args.dasiamrpn_kernel_cls_sha1)
        siam_params.kernel_r1 = findModel(args.dasiamrpn_kernel_r1_model, args.dasiamrpn_kernel_r1_sha1)
        return cv.TrackerDaSiamRPN_create(siam_params)

    # Unknown alias: show usage and abort.
    help()
    exit(-1)
|
|
|
|
def main(model_name, args):
    """Open the video source, let the user pick a target, then track it.

    Args:
        model_name: tracker alias ("vit", "nanotrack" or "dasiamrpn").
        args: namespace from ``load_parser(model_name)``. ``args.input`` is
            the video path; when empty, camera 0 is used.

    Returns:
        -1 when the stream ends before a target is chosen or the tracker
        cannot be initialized; None otherwise.
    """
    tracker = createTracker(model_name, args)

    videoPath = args.input
    print('Using video: {}'.format(videoPath))
    cap = cv.VideoCapture(cv.samples.findFile(args.input) if args.input else 0)
    if not cap.isOpened():
        print("Can't open video stream: {}".format(videoPath))
        sys.exit(-1)

    try:
        return _run_tracking_session(cap, tracker, model_name)
    finally:
        # Release the camera/file handle on every exit path, including 'q'/Esc.
        cap.release()


def _run_tracking_session(cap, tracker, model_name):
    """Run target selection and then the tracking loop on the opened *cap*.

    Phase 1 plays the stream until space is pressed and a non-empty ROI is
    drawn. Phase 2 tracks frame by frame and overlays FPS, tracking score and
    algorithm name. Keys during tracking: space re-selects the target,
    'v' / 'n' / 'd' switch to ViT / NanoTrack / DaSiamRPN, 'q' or Esc quits.
    """
    stdSize = 0.6
    stdWeight = 2
    stdImgSize = 512
    imgWidth = -1  # Set from the first frame read.
    fontSize = 1.5
    fontThickness = 1
    alpha = 0.5
    windowName = "TRACKING"
    font = cv.FONT_HERSHEY_SIMPLEX
    black = (0, 0, 0)
    white = (255, 255, 255)
    cv.namedWindow(windowName, cv.WINDOW_NORMAL)

    # Phase 1: play the stream until the user selects a non-empty ROI.
    while True:
        ret, image = cap.read()
        if not ret:
            print("Video completed!!")
            return -1
        if imgWidth == -1:
            # Scale text relative to the shorter image side (512 px reference).
            imgWidth = min(image.shape[:2])
            fontSize = min(fontSize, (stdSize*imgWidth)/stdImgSize)
            fontThickness = max(fontThickness, (stdWeight*imgWidth)//stdImgSize)
            label = "Press space bar to pause video to draw bounding box."
            labelSize, _ = cv.getTextSize(label, font, fontSize, fontThickness)

        org_img = image.copy()
        cv.rectangle(image, (0, 0), (labelSize[0]+10, labelSize[1]+int(40*fontSize)), white, cv.FILLED)
        cv.addWeighted(image, alpha, org_img, 1 - alpha, 0, image)
        cv.putText(image, label, (10, int(25*fontSize)), font, fontSize, black, fontThickness)
        cv.putText(image, "Press space bar after selecting.", (10, int(55*fontSize)), font, fontSize, black, fontThickness)
        cv.imshow(windowName, image)

        key = cv.waitKey(30) & 0xFF
        if key == ord(' '):
            bbox = cv.selectROI(windowName, image)
            print('ROI: {}'.format(bbox))
            if bbox != (0, 0, 0, 0):
                break
        if key == ord('q') or key == 27:
            return None

    # Initialize on the clean frame: `image` carries the semi-transparent
    # instruction overlay, which would leak into the target template if the
    # ROI overlaps it.
    try:
        tracker.init(org_img, bbox)
    except Exception as e:
        print('Unable to initialize tracker with requested bounding box. Is there any object?')
        print(e)
        # Updating an uninitialized tracker cannot succeed; stop here.
        return -1

    # Phase 2 hint. Font metrics were fixed in phase 1; only the text changes.
    label = "Press space bar to select new target"
    labelSize, _ = cv.getTextSize(label, font, fontSize, fontThickness)
    switch_keys = {ord('v'): "vit", ord('n'): "nanotrack", ord('d'): "dasiamrpn"}

    tick_meter = cv.TickMeter()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        tick_meter.reset()
        tick_meter.start()
        ok, newbox = tracker.update(frame)
        tick_meter.stop()
        score = tracker.getTrackingScore()
        render_image = frame.copy()
        key = cv.waitKey(30) & 0xFF
        if key == ord('q') or key == 27:
            return None

        h, w = frame.shape[:2]
        cv.rectangle(render_image, (0, 0), (labelSize[0]+10, labelSize[1]+int(100*fontSize)), white, cv.FILLED)
        cv.rectangle(render_image, (0, int(h-45*fontSize)), (w, h), white, cv.FILLED)
        cv.addWeighted(render_image, alpha, frame, 1 - alpha, 0, render_image)
        cv.putText(render_image, label, (10, int(25*fontSize)), font, fontSize, black, fontThickness)
        cv.putText(render_image, "For switching between trackers: press 'v' for ViT, 'n' for Nanotrack, and 'd' for DaSiamRPN.", (10, h-10), font, 0.8*fontSize, black, fontThickness)

        if key == ord(' '):
            # Re-selection is allowed even when the target has been lost.
            cv.putText(render_image, "Select the new target", (10, int(55*fontSize)), font, fontSize, black, fontThickness)
            bbox = cv.selectROI(windowName, render_image)
            print('ROI:', bbox)
            if bbox != (0, 0, 0, 0):
                tracker.init(frame, bbox)
                newbox = bbox  # Draw the fresh selection on this frame.
        elif ok and key in switch_keys:
            # The new tracker starts from the current box, so that box must
            # come from a successful update on this frame.
            model_name = switch_keys[key]
            args = load_parser(model_name)
            tracker = createTracker(model_name, args)
            tracker.init(frame, newbox)

        cv.rectangle(render_image, newbox, (200, 0, 0), thickness=2)
        time_label = f"FPS: {tick_meter.getFPS():.2f}"
        score_label = f"Tracking score: {score:.2f}"
        algo_label = f"Algorithm: {model_name}"
        cv.putText(render_image, time_label, (10, int(55*fontSize)), font, fontSize, black, fontThickness)
        cv.putText(render_image, score_label, (10, int(85*fontSize)), font, fontSize, black, fontThickness)
        cv.putText(render_image, algo_label, (10, int(115*fontSize)), font, fontSize, black, fontThickness)

        cv.imshow(windowName, render_image)
    return None
|
|
|
|
def get_model_name(argv=None):
    """Return the tracker alias requested on the command line.

    The alias is the first positional argument. If it is missing, or the
    first argument is an option such as ``-h`` or ``--input``, "vit" is used.
    Any option must fall through so argparse can handle it; in particular,
    ``-h`` must reach argparse instead of being treated as an alias.

    Args:
        argv: argument vector to inspect; defaults to ``sys.argv``.

    Returns:
        The alias string.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2 or argv[1].startswith("-"):
        return "vit"
    return argv[1]


if __name__ == '__main__':
    help()
    model_name = get_model_name()
    args = load_parser(model_name)
    try:
        main(model_name, args)
    finally:
        # Close the window even if main exits early or raises.
        cv.destroyAllWindows()
|