From 51086d9bb82dfff2f7bba51ce96c7720e3594d8e Mon Sep 17 00:00:00 2001 From: Gursimar Singh Date: Fri, 15 Nov 2024 16:10:10 +0530 Subject: [PATCH] Fixed download_models.py, made vit as default alias --- samples/dnn/download_models.py | 29 ++++++++++++++++++++--------- samples/dnn/models.yml | 1 + samples/dnn/object_tracker.cpp | 20 ++++++++++++++++---- samples/dnn/object_tracker.py | 18 ++++++++++-------- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/samples/dnn/download_models.py b/samples/dnn/download_models.py index f4d69f2eef..57243f9e04 100644 --- a/samples/dnn/download_models.py +++ b/samples/dnn/download_models.py @@ -331,21 +331,32 @@ def parseYAMLFile(yaml_filepath, save_dir, model_name): with open(yaml_filepath, 'r') as stream: data_loaded = yaml.safe_load(stream) for name, params in data_loaded.items(): - if name != model_name: + if model_name != "" and name != model_name: continue for key in params.keys(): if key.endswith("load_info"): prefix = key[:-len('load_info')] load_info = params.get(prefix+"load_info", None) if load_info: - fname = os.path.basename(params.get(prefix+"model")) - hash_sum = load_info.get(prefix+"sha1") - url = load_info.get(prefix+"url") - download_sha = load_info.get(prefix+"download_sha") - download_name = load_info.get(prefix+"download_name") - archive_member = load_info.get(prefix+"member") - models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir, - download_name=download_name, download_sha=download_sha, archive_member=archive_member)) + print(prefix) + if prefix == "config_": + fname = os.path.basename(params.get("config")) + hash_sum = load_info.get("sha1") + url = load_info.get("url") + download_sha = load_info.get("download_sha") + download_name = load_info.get("download_name") + archive_member = load_info.get("member") + models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir, + download_name=download_name, download_sha=download_sha, archive_member=archive_member)) + else: + fname = os.path.basename(params.get(prefix+"model")) + hash_sum = load_info.get(prefix+"sha1") + url = load_info.get(prefix+"url") + download_sha = load_info.get(prefix+"download_sha") + download_name = load_info.get(prefix+"download_name") + archive_member = load_info.get(prefix+"member") + models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir, + download_name=download_name, download_sha=download_sha, archive_member=archive_member)) return models diff --git a/samples/dnn/models.yml b/samples/dnn/models.yml index 15dc3ed4c5..bac49349af 100644 --- a/samples/dnn/models.yml +++ b/samples/dnn/models.yml @@ -380,6 +380,7 @@ reid: sha1: "d4316b100db40f8840aa82626e1cf3f519a7f1ae" model: "person_reid_youtu_2021nov.onnx" yolo_load_info: + yolo_url: "https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov8n.onnx" yolo_sha1: "68f864475d06e2ec4037181052739f268eeac38d" yolo_model: "yolov8n.onnx" mean: [0.485, 0.456, 0.406] diff --git a/samples/dnn/object_tracker.cpp b/samples/dnn/object_tracker.cpp index 38430f7369..b2b28933ea 100644 --- a/samples/dnn/object_tracker.cpp +++ b/samples/dnn/object_tracker.cpp @@ -26,7 +26,7 @@ const string about = "Use this script for testing Object Tracking using OpenCV. const string param_keys = "{ help h | | Print help message }" - "{ @alias | | An alias name of model to extract preprocessing parameters from models.yml file. }" + "{ @alias | vit | An alias name of model to extract preprocessing parameters from models.yml file. }" "{ zoo | ../dnn/models.yml | An optional path to file with preprocessing parameters }" "{ input i | | Full path to input video folder, the specific camera index. (empty for camera 0) }" "{ tracking_thrs | 0.3 | Tracking score threshold. If a bbox of score >= 0.3, it is considered as found }"; @@ -205,7 +205,14 @@ int main(int argc, char** argv) if (key == ' ') { selectRect = selectROI(windowName, image); - break; + if (selectRect.width > 0 && selectRect.height > 0) + { + break; + } + else + { + cout << "No valid selection made. Please select again." << endl; + } } else if (key == 27) // ESC key to exit { @@ -253,7 +260,12 @@ int main(int argc, char** argv) if (key == ' '){ putText(render_image, "Select the new target", Point(10, 2*fontSize), Scalar(0,0,0), fontFace, fontSize, fontWeight); selectRect = selectROI(windowName, render_image); - tracker->init(image, selectRect); + if (selectRect.width > 0 && selectRect.height > 0){ + tracker->init(image, selectRect); + } + else{ + cout<<"New target is not selected, switching to previous target"<