mirror of
https://github.com/opencv/opencv.git
synced 2024-11-27 12:40:05 +08:00
Fixed download_models.py, made vit as default alias
This commit is contained in:
parent
3c79143af2
commit
51086d9bb8
@ -331,21 +331,32 @@ def parseYAMLFile(yaml_filepath, save_dir, model_name):
|
||||
with open(yaml_filepath, 'r') as stream:
|
||||
data_loaded = yaml.safe_load(stream)
|
||||
for name, params in data_loaded.items():
|
||||
if name != model_name:
|
||||
if model_name != "" and name != model_name:
|
||||
continue
|
||||
for key in params.keys():
|
||||
if key.endswith("load_info"):
|
||||
prefix = key[:-len('load_info')]
|
||||
load_info = params.get(prefix+"load_info", None)
|
||||
if load_info:
|
||||
fname = os.path.basename(params.get(prefix+"model"))
|
||||
hash_sum = load_info.get(prefix+"sha1")
|
||||
url = load_info.get(prefix+"url")
|
||||
download_sha = load_info.get(prefix+"download_sha")
|
||||
download_name = load_info.get(prefix+"download_name")
|
||||
archive_member = load_info.get(prefix+"member")
|
||||
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
|
||||
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
|
||||
print(prefix)
|
||||
if prefix == "config_":
|
||||
fname = os.path.basename(params.get("config"))
|
||||
hash_sum = load_info.get("sha1")
|
||||
url = load_info.get("url")
|
||||
download_sha = load_info.get("download_sha")
|
||||
download_name = load_info.get("download_name")
|
||||
archive_member = load_info.get("member")
|
||||
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
|
||||
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
|
||||
else:
|
||||
fname = os.path.basename(params.get(prefix+"model"))
|
||||
hash_sum = load_info.get(prefix+"sha1")
|
||||
url = load_info.get(prefix+"url")
|
||||
download_sha = load_info.get(prefix+"download_sha")
|
||||
download_name = load_info.get(prefix+"download_name")
|
||||
archive_member = load_info.get(prefix+"member")
|
||||
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
|
||||
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
|
||||
|
||||
return models
|
||||
|
||||
|
@ -380,6 +380,7 @@ reid:
|
||||
sha1: "d4316b100db40f8840aa82626e1cf3f519a7f1ae"
|
||||
model: "person_reid_youtu_2021nov.onnx"
|
||||
yolo_load_info:
|
||||
yolo_url: "https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov8n.onnx"
|
||||
yolo_sha1: "68f864475d06e2ec4037181052739f268eeac38d"
|
||||
yolo_model: "yolov8n.onnx"
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
|
@ -26,7 +26,7 @@ const string about = "Use this script for testing Object Tracking using OpenCV.
|
||||
|
||||
const string param_keys =
|
||||
"{ help h | | Print help message }"
|
||||
"{ @alias | | An alias name of model to extract preprocessing parameters from models.yml file. }"
|
||||
"{ @alias | vit | An alias name of model to extract preprocessing parameters from models.yml file. }"
|
||||
"{ zoo | ../dnn/models.yml | An optional path to file with preprocessing parameters }"
|
||||
"{ input i | | Full path to input video folder, the specific camera index. (empty for camera 0) }"
|
||||
"{ tracking_thrs | 0.3 | Tracking score threshold. If a bbox of score >= 0.3, it is considered as found }";
|
||||
@ -205,7 +205,14 @@ int main(int argc, char** argv)
|
||||
if (key == ' ')
|
||||
{
|
||||
selectRect = selectROI(windowName, image);
|
||||
break;
|
||||
if (selectRect.width > 0 && selectRect.height > 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
cout << "No valid selection made. Please select again." << endl;
|
||||
}
|
||||
}
|
||||
else if (key == 27) // ESC key to exit
|
||||
{
|
||||
@ -253,7 +260,12 @@ int main(int argc, char** argv)
|
||||
if (key == ' '){
|
||||
putText(render_image, "Select the new target", Point(10, 2*fontSize), Scalar(0,0,0), fontFace, fontSize, fontWeight);
|
||||
selectRect = selectROI(windowName, render_image);
|
||||
tracker->init(image, selectRect);
|
||||
if (selectRect.width > 0 && selectRect.height > 0){
|
||||
tracker->init(image, selectRect);
|
||||
}
|
||||
else{
|
||||
cout<<"New target is not selected, switching to previous target"<<endl;
|
||||
}
|
||||
}
|
||||
else if (key == 'v'){
|
||||
modelName = "vit";
|
||||
@ -279,7 +291,7 @@ int main(int argc, char** argv)
|
||||
rectangle(render_image, rect, Scalar(0, 255, 0), 2);
|
||||
}
|
||||
|
||||
string timeLabel = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
|
||||
string timeLabel = format("FPS: %.2f", tickMeter.getFPS());
|
||||
string scoreLabel = format("Score: %f", score);
|
||||
string algoLabel = "Algorithm: " + modelName;
|
||||
putText(render_image, timeLabel, Point(10, 2*fontSize), Scalar(0,0,0), fontFace, fontSize, fontWeight);
|
||||
|
@ -16,6 +16,8 @@ def help():
|
||||
vit:
|
||||
Download Model: python download_models.py vit
|
||||
Example: python object_tracker.py vit
|
||||
or
|
||||
python object_tracker.py
|
||||
dasiamrpn:
|
||||
Download Model: python download_models.py dasiamrpn
|
||||
Example: python object_tracker.py dasiamrpn
|
||||
@ -113,7 +115,7 @@ def main(model_name, args):
|
||||
if key == ord(' '):
|
||||
bbox = cv.selectROI(windowName, image)
|
||||
print('ROI: {}'.format(bbox))
|
||||
if bbox:
|
||||
if bbox != (0, 0, 0, 0):
|
||||
break
|
||||
|
||||
if key == ord('q') or key == 27:
|
||||
@ -154,7 +156,7 @@ def main(model_name, args):
|
||||
cv.putText(render_image, "Select the new target", (10, int(55*fontSize)), cv.FONT_HERSHEY_SIMPLEX, fontSize, (0, 0, 0), fontThickness)
|
||||
bbox = cv.selectROI(windowName, render_image)
|
||||
print('ROI:', bbox)
|
||||
if bbox:
|
||||
if bbox != (0, 0, 0, 0):
|
||||
tracker.init(frame, bbox)
|
||||
elif key == ord('v'):
|
||||
model_name = "vit"
|
||||
@ -175,7 +177,7 @@ def main(model_name, args):
|
||||
return
|
||||
|
||||
cv.rectangle(render_image, newbox, (200, 0, 0), thickness=2)
|
||||
time_label = f"Inference time: {tick_meter.getTimeMilli():.2f} ms"
|
||||
time_label = f"FPS: {tick_meter.getFPS():.2f}"
|
||||
score_label = f"Tracking score: {score:.2f}"
|
||||
algo_label = f"Algorithm: {model_name}"
|
||||
cv.putText(render_image, time_label, (10, int(55*fontSize)), cv.FONT_HERSHEY_SIMPLEX, fontSize, (0, 0, 0), fontThickness)
|
||||
@ -187,11 +189,11 @@ def main(model_name, args):
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) < 2:
|
||||
help()
|
||||
exit(-1)
|
||||
|
||||
model_name = sys.argv[1]
|
||||
help()
|
||||
if len(sys.argv) < 2 or sys.argv[1].startswith("--"):
|
||||
model_name = "vit"
|
||||
else:
|
||||
model_name = sys.argv[1]
|
||||
args = load_parser(model_name)
|
||||
|
||||
main(model_name, args)
|
||||
|
Loading…
Reference in New Issue
Block a user