Fixed download_models.py, made vit as default alias

This commit is contained in:
Gursimar Singh 2024-11-15 16:10:10 +05:30
parent 3c79143af2
commit 51086d9bb8
4 changed files with 47 additions and 21 deletions

View File

@ -331,21 +331,32 @@ def parseYAMLFile(yaml_filepath, save_dir, model_name):
with open(yaml_filepath, 'r') as stream:
data_loaded = yaml.safe_load(stream)
for name, params in data_loaded.items():
if name != model_name:
if model_name != "" and name != model_name:
continue
for key in params.keys():
if key.endswith("load_info"):
prefix = key[:-len('load_info')]
load_info = params.get(prefix+"load_info", None)
if load_info:
fname = os.path.basename(params.get(prefix+"model"))
hash_sum = load_info.get(prefix+"sha1")
url = load_info.get(prefix+"url")
download_sha = load_info.get(prefix+"download_sha")
download_name = load_info.get(prefix+"download_name")
archive_member = load_info.get(prefix+"member")
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
print(prefix)
if prefix == "config_":
fname = os.path.basename(params.get("config"))
hash_sum = load_info.get("sha1")
url = load_info.get("url")
download_sha = load_info.get("download_sha")
download_name = load_info.get("download_name")
archive_member = load_info.get("member")
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
else:
fname = os.path.basename(params.get(prefix+"model"))
hash_sum = load_info.get(prefix+"sha1")
url = load_info.get(prefix+"url")
download_sha = load_info.get(prefix+"download_sha")
download_name = load_info.get(prefix+"download_name")
archive_member = load_info.get(prefix+"member")
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
return models

View File

@ -380,6 +380,7 @@ reid:
sha1: "d4316b100db40f8840aa82626e1cf3f519a7f1ae"
model: "person_reid_youtu_2021nov.onnx"
yolo_load_info:
yolo_url: "https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov8n.onnx"
yolo_sha1: "68f864475d06e2ec4037181052739f268eeac38d"
yolo_model: "yolov8n.onnx"
mean: [0.485, 0.456, 0.406]

View File

@ -26,7 +26,7 @@ const string about = "Use this script for testing Object Tracking using OpenCV.
const string param_keys =
"{ help h | | Print help message }"
"{ @alias | | An alias name of model to extract preprocessing parameters from models.yml file. }"
"{ @alias | vit | An alias name of model to extract preprocessing parameters from models.yml file. }"
"{ zoo | ../dnn/models.yml | An optional path to file with preprocessing parameters }"
"{ input i | | Full path to input video folder, the specific camera index. (empty for camera 0) }"
"{ tracking_thrs | 0.3 | Tracking score threshold. If a bbox of score >= 0.3, it is considered as found }";
@ -205,7 +205,14 @@ int main(int argc, char** argv)
if (key == ' ')
{
selectRect = selectROI(windowName, image);
break;
if (selectRect.width > 0 && selectRect.height > 0)
{
break;
}
else
{
cout << "No valid selection made. Please select again." << endl;
}
}
else if (key == 27) // ESC key to exit
{
@ -253,7 +260,12 @@ int main(int argc, char** argv)
if (key == ' '){
putText(render_image, "Select the new target", Point(10, 2*fontSize), Scalar(0,0,0), fontFace, fontSize, fontWeight);
selectRect = selectROI(windowName, render_image);
tracker->init(image, selectRect);
if (selectRect.width > 0 && selectRect.height > 0){
tracker->init(image, selectRect);
}
else{
cout<<"New target is not selected, switching to previous target"<<endl;
}
}
else if (key == 'v'){
modelName = "vit";
@ -279,7 +291,7 @@ int main(int argc, char** argv)
rectangle(render_image, rect, Scalar(0, 255, 0), 2);
}
string timeLabel = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
string timeLabel = format("FPS: %.2f", tickMeter.getFPS());
string scoreLabel = format("Score: %f", score);
string algoLabel = "Algorithm: " + modelName;
putText(render_image, timeLabel, Point(10, 2*fontSize), Scalar(0,0,0), fontFace, fontSize, fontWeight);

View File

@ -16,6 +16,8 @@ def help():
vit:
Download Model: python download_models.py vit
Example: python object_tracker.py vit
or
python object_tracker.py
dasiamrpn:
Download Model: python download_models.py dasiamrpn
Example: python object_tracker.py dasiamrpn
@ -113,7 +115,7 @@ def main(model_name, args):
if key == ord(' '):
bbox = cv.selectROI(windowName, image)
print('ROI: {}'.format(bbox))
if bbox:
if bbox != (0, 0, 0, 0):
break
if key == ord('q') or key == 27:
@ -154,7 +156,7 @@ def main(model_name, args):
cv.putText(render_image, "Select the new target", (10, int(55*fontSize)), cv.FONT_HERSHEY_SIMPLEX, fontSize, (0, 0, 0), fontThickness)
bbox = cv.selectROI(windowName, render_image)
print('ROI:', bbox)
if bbox:
if bbox != (0, 0, 0, 0):
tracker.init(frame, bbox)
elif key == ord('v'):
model_name = "vit"
@ -175,7 +177,7 @@ def main(model_name, args):
return
cv.rectangle(render_image, newbox, (200, 0, 0), thickness=2)
time_label = f"Inference time: {tick_meter.getTimeMilli():.2f} ms"
time_label = f"FPS: {tick_meter.getFPS():.2f}"
score_label = f"Tracking score: {score:.2f}"
algo_label = f"Algorithm: {model_name}"
cv.putText(render_image, time_label, (10, int(55*fontSize)), cv.FONT_HERSHEY_SIMPLEX, fontSize, (0, 0, 0), fontThickness)
@ -187,11 +189,11 @@ def main(model_name, args):
break
if __name__ == '__main__':
if len(sys.argv) < 2:
help()
exit(-1)
model_name = sys.argv[1]
help()
if len(sys.argv) < 2 or sys.argv[1].startswith("--"):
model_name = "vit"
else:
model_name = sys.argv[1]
args = load_parser(model_name)
main(model_name, args)