diff --git a/modules/video/include/opencv2/video/tracking.hpp b/modules/video/include/opencv2/video/tracking.hpp
index eb5a6c7030..7f93a79a72 100644
--- a/modules/video/include/opencv2/video/tracking.hpp
+++ b/modules/video/include/opencv2/video/tracking.hpp
@@ -887,6 +887,43 @@ public:
     //bool update(InputArray image, CV_OUT Rect& boundingBox) CV_OVERRIDE;
 };
 
+/** @brief The VIT tracker is a super-lightweight DNN-based general object tracker.
+ *
+ *  The VIT tracker is much faster and extremely lightweight due to its special model structure;
+ *  the model file is about 767 KB.
+ *  Model download link: https://github.com/opencv/opencv_zoo/tree/main/models/object_tracking_vittrack
+ *  Author: PengyuLiu, 1872918507@qq.com
+ */
+class CV_EXPORTS_W TrackerVit : public Tracker
+{
+protected:
+    TrackerVit();  // use ::create()
+public:
+    virtual ~TrackerVit() CV_OVERRIDE;
+
+    struct CV_EXPORTS_W_SIMPLE Params
+    {
+        CV_WRAP Params();
+        CV_PROP_RW std::string net;
+        CV_PROP_RW int backend;
+        CV_PROP_RW int target;
+        CV_PROP_RW Scalar meanvalue;
+        CV_PROP_RW Scalar stdvalue;
+    };
+
+    /** @brief Constructor
+    @param parameters VIT tracker parameters TrackerVit::Params
+    */
+    static CV_WRAP
+    Ptr<TrackerVit> create(const TrackerVit::Params& parameters = TrackerVit::Params());
+
+    /** @brief Return tracking score
+    */
+    CV_WRAP virtual float getTrackingScore() = 0;
+
+    // void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
+    // bool update(InputArray image, CV_OUT Rect& boundingBox) CV_OVERRIDE;
+};
+
 //! @} video_track
 
 } // cv
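For readers skimming the new public interface above, here is a minimal usage sketch. It is not part of the patch; it assumes the vitTracker.onnx model from the OpenCV Zoo link in the Doxygen comment sits next to the executable and that a camera is available.

// Minimal usage sketch for the TrackerVit API declared above (not part of this patch).
// Assumes vitTracker.onnx has been downloaded from the OpenCV Zoo link.
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/videoio.hpp>
#include <opencv2/video.hpp>

int main()
{
    cv::TrackerVit::Params params;
    params.net = "vitTracker.onnx";
    cv::Ptr<cv::TrackerVit> tracker = cv::TrackerVit::create(params);

    cv::VideoCapture cap(0);
    cv::Mat frame;
    cap >> frame;
    if (frame.empty())
        return 1;

    cv::Rect roi = cv::selectROI("select target", frame);  // draw the initial box
    tracker->init(frame, roi);

    while (cap.read(frame))
    {
        cv::Rect box;
        if (tracker->update(frame, box))
            cv::rectangle(frame, box, cv::Scalar(0, 255, 0), 2);
        cv::putText(frame, cv::format("score: %.2f", tracker->getTrackingScore()),
                    cv::Point(20, 30), cv::FONT_HERSHEY_SIMPLEX, 0.8, cv::Scalar(0, 255, 0));
        cv::imshow("vitTracker", frame);
        if (cv::waitKey(1) == 27)  // Esc quits
            break;
    }
    return 0;
}

The sketch mirrors the generic cv::Tracker workflow (init once with a user-selected box, then update per frame); getTrackingScore() is the only TrackerVit-specific call.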
diff --git a/modules/video/src/tracking/tracker_vit.cpp b/modules/video/src/tracking/tracker_vit.cpp
new file mode 100644
index 0000000000..7611e184ef
--- /dev/null
+++ b/modules/video/src/tracking/tracker_vit.cpp
@@ -0,0 +1,219 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Author: PengyuLiu, 1872918507@qq.com
+
+#include "../precomp.hpp"
+#ifdef HAVE_OPENCV_DNN
+#include "opencv2/dnn.hpp"
+#endif
+
+namespace cv {
+
+TrackerVit::TrackerVit()
+{
+    // nothing
+}
+
+TrackerVit::~TrackerVit()
+{
+    // nothing
+}
+
+TrackerVit::Params::Params()
+{
+    net = "vitTracker.onnx";
+    meanvalue = Scalar{0.485, 0.456, 0.406};
+    stdvalue = Scalar{0.229, 0.224, 0.225};
+#ifdef HAVE_OPENCV_DNN
+    backend = dnn::DNN_BACKEND_DEFAULT;
+    target = dnn::DNN_TARGET_CPU;
+#else
+    backend = -1;  // invalid value
+    target = -1;  // invalid value
+#endif
+}
+
+#ifdef HAVE_OPENCV_DNN
+
+class TrackerVitImpl : public TrackerVit
+{
+public:
+    TrackerVitImpl(const TrackerVit::Params& parameters)
+        : params(parameters)
+    {
+        net = dnn::readNet(params.net);
+        CV_Assert(!net.empty());
+        // Apply the requested DNN backend/target from Params.
+        net.setPreferableBackend(params.backend);
+        net.setPreferableTarget(params.target);
+    }
+
+    void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
+    bool update(InputArray image, Rect& boundingBox) CV_OVERRIDE;
+    float getTrackingScore() CV_OVERRIDE;
+
+    Rect rect_last;
+    float tracking_score;
+
+    TrackerVit::Params params;
+
+protected:
+    void preprocess(const Mat& src, Mat& dst, Size size);
+
+    const Size searchSize{256, 256};
+    const Size templateSize{128, 128};
+
+    Mat hanningWindow;
+
+    dnn::Net net;
+    Mat image;
+};
+
+// Crop a square region of side ceil(sqrt(w * h) * factor) centered on `box`,
+// zero-padding wherever the region leaves the image.
+static void crop_image(const Mat& src, Mat& dst, Rect box, int factor)
+{
+    int x = box.x, y = box.y, w = box.width, h = box.height;
+    int crop_sz = ceil(sqrt(w * h) * factor);
+
+    int x1 = round(x + 0.5 * w - crop_sz * 0.5);
+    int x2 = x1 + crop_sz;
+    int y1 = round(y + 0.5 * h - crop_sz * 0.5);
+    int y2 = y1 + crop_sz;
+
+    int x1_pad = std::max(0, -x1);
+    int y1_pad = std::max(0, -y1);
+    int x2_pad = std::max(x2 - src.size[1] + 1, 0);
+    int y2_pad = std::max(y2 - src.size[0] + 1, 0);
+
+    Rect roi(x1 + x1_pad, y1 + y1_pad, x2 - x2_pad - x1 - x1_pad, y2 - y2_pad - y1 - y1_pad);
+    Mat im_crop = src(roi);
+    copyMakeBorder(im_crop, dst, y1_pad, y2_pad, x1_pad, x2_pad, BORDER_CONSTANT);
+}
+
+// Resize to the network input size and normalize: x / 255, then (x - mean) / std.
+void TrackerVitImpl::preprocess(const Mat& src, Mat& dst, Size size)
+{
+    Mat mean = Mat(size, CV_32FC3, params.meanvalue);
+    Mat std = Mat(size, CV_32FC3, params.stdvalue);
+    mean = dnn::blobFromImage(mean, 1.0, Size(), Scalar(), false);
+    std = dnn::blobFromImage(std, 1.0, Size(), Scalar(), false);
+
+    Mat img;
+    resize(src, img, size);
+
+    dst = dnn::blobFromImage(img, 1.0, Size(), Scalar(), false);
+    dst /= 255;
+    dst = (dst - mean) / std;
+}
+
+static Mat hann1d(int sz, bool centered = true)
+{
+    Mat hanningWindow(sz, 1, CV_32FC1);
+    float* data = hanningWindow.ptr<float>(0);
+
+    if (centered) {
+        for (int i = 0; i < sz; i++) {
+            float val = 0.5 * (1 - std::cos((2 * M_PI / (sz + 1)) * (i + 1)));
+            data[i] = val;
+        }
+    }
+    else {
+        int half_sz = sz / 2;
+        for (int i = 0; i <= half_sz; i++) {
+            float val = 0.5 * (1 + std::cos((2 * M_PI / (sz + 2)) * i));
+            data[i] = val;
+            data[sz - 1 - i] = val;
+        }
+    }
+
+    return hanningWindow;
+}
+
+// 2-D Hann window as the outer product of two 1-D Hann windows.
+static Mat hann2d(Size size, bool centered = true)
+{
+    int rows = size.height;
+    int cols = size.width;
+
+    Mat hanningWindowRows = hann1d(rows, centered);
+    Mat hanningWindowCols = hann1d(cols, centered);
+
+    Mat hanningWindow = hanningWindowRows * hanningWindowCols.t();
+
+    return hanningWindow;
+}
+
+// Map a box that is normalized to the search crop (a square of side
+// 4 * sqrt(w * h) around the previous result) back into image coordinates.
+static Rect returnfromcrop(float x, float y, float w, float h, Rect res_Last)
+{
+    int cropwindowwh = 4 * sqrt(res_Last.width * res_Last.height);
+    int x0 = res_Last.x + 0.5 * res_Last.width - 0.5 * cropwindowwh;
+    int y0 = res_Last.y + 0.5 * res_Last.height - 0.5 * cropwindowwh;
+    Rect finalres;
+    finalres.x = x * cropwindowwh + x0;
+    finalres.y = y * cropwindowwh + y0;
+    finalres.width = w * cropwindowwh;
+    finalres.height = h * cropwindowwh;
+    return finalres;
+}
+
+void TrackerVitImpl::init(InputArray image_, const Rect& boundingBox_)
+{
+    image = image_.getMat().clone();
+    Mat crop;
+    crop_image(image, crop, boundingBox_, 2);
+    Mat blob;
+    preprocess(crop, blob, templateSize);
+    net.setInput(blob, "template");
+    Size size(16, 16);
+    hanningWindow = hann2d(size, false);
+    rect_last = boundingBox_;
+}
+
+bool TrackerVitImpl::update(InputArray image_, Rect& boundingBoxRes)
+{
+    image = image_.getMat().clone();
+    Mat crop;
+    crop_image(image, crop, rect_last, 4);
+    Mat blob;
+    preprocess(crop, blob, searchSize);
+    net.setInput(blob, "search");
+    std::vector<String> outputName = {"output1", "output2", "output3"};
+    std::vector<Mat> outs;
+    net.forward(outs, outputName);
+    CV_Assert(outs.size() == 3);
+
+    // 16x16 confidence map plus per-cell size and center-offset maps,
+    // all normalized to the search crop.
+    Mat conf_map = outs[0].reshape(0, {16, 16});
+    Mat size_map = outs[1].reshape(0, {2, 16, 16});
+    Mat offset_map = outs[2].reshape(0, {2, 16, 16});
+
+    // Penalize cells far from the previous position with the Hanning window.
+    multiply(conf_map, (1.0 - hanningWindow), conf_map);
+
+    double maxVal;
+    Point maxLoc;
+    minMaxLoc(conf_map, nullptr, &maxVal, nullptr, &maxLoc);
+    tracking_score = static_cast<float>(maxVal);
+
+    float cx = (maxLoc.x + offset_map.at<float>(0, maxLoc.y, maxLoc.x)) / 16;
+    float cy = (maxLoc.y + offset_map.at<float>(1, maxLoc.y, maxLoc.x)) / 16;
+    float w = size_map.at<float>(0, maxLoc.y, maxLoc.x);
+    float h = size_map.at<float>(1, maxLoc.y, maxLoc.x);
+
+    Rect finalres = returnfromcrop(cx - w / 2, cy - h / 2, w, h, rect_last);
+    rect_last = finalres;
+    boundingBoxRes = finalres;
+    return true;
+}
+
+float TrackerVitImpl::getTrackingScore()
+{
+    return tracking_score;
+}
+
+Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters)
+{
+    return makePtr<TrackerVitImpl>(parameters);
+}
+
+#else  // !HAVE_OPENCV_DNN
+
+Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters)
+{
+    CV_UNUSED(parameters);
+    CV_Error(Error::StsNotImplemented, "to use vittrack, the tracking module needs to be built with opencv_dnn !");
+}
+
+#endif  // HAVE_OPENCV_DNN
+
+}  // namespace cv
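A note on the decoding above, since it is the least obvious part of the patch: update() picks the peak of the 16x16 confidence map, adds the predicted sub-cell offset, and returnfromcrop() maps the normalized box back from the search crop (a square of side 4 * sqrt(w * h) around the previous result) into image coordinates. The following standalone trace, not part of the patch and using made-up numbers, walks through that arithmetic.

// Standalone trace of the box decoding in TrackerVitImpl::update() /
// returnfromcrop(); all values are invented for illustration only.
#include <cmath>
#include <cstdio>

int main()
{
    // Previous tracking result: a 100x50 box at (300, 200).
    int lastX = 300, lastY = 200, lastW = 100, lastH = 50;

    // The search region is a square of side 4 * sqrt(w * h) around its center.
    int crop = int(4 * std::sqrt(double(lastW) * lastH));    // ~282
    int x0 = int(lastX + 0.5 * lastW - 0.5 * crop);          // crop origin x
    int y0 = int(lastY + 0.5 * lastH - 0.5 * crop);          // crop origin y

    // Pretend network outputs (normalized to the crop): peak at cell (8, 8)
    // of the 16x16 map, zero offsets, size 0.25 x 0.125 of the crop side.
    float cx = 8.0f / 16, cy = 8.0f / 16, w = 0.25f, h = 0.125f;

    int outX = int((cx - w / 2) * crop + x0);
    int outY = int((cy - h / 2) * crop + y0);
    int outW = int(w * crop);
    int outH = int(h * crop);
    std::printf("decoded box: %d %d %dx%d\n", outX, outY, outW, outH);  // 314 207 70x35
    return 0;
}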
diff --git a/modules/video/test/test_trackers.cpp b/modules/video/test/test_trackers.cpp
index 6ede40896c..4970343099 100644
--- a/modules/video/test/test_trackers.cpp
+++ b/modules/video/test/test_trackers.cpp
@@ -160,4 +160,13 @@ TEST(NanoTrack, accuracy_NanoTrack_V2)
     checkTrackingAccuracy(tracker, 0.69);
 }
 
+TEST(vittrack, accuracy_vittrack)
+{
+    std::string model = cvtest::findDataFile("dnn/onnx/models/vitTracker.onnx", false);
+    cv::TrackerVit::Params params;
+    params.net = model;
+    cv::Ptr<TrackerVit> tracker = TrackerVit::create(params);
+    checkTrackingAccuracy(tracker, 0.67);
+}
+
 }} // namespace opencv_test::
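The sample added below drives the tracker from a video file or camera and forwards the raw --backend/--target integers into TrackerVit::Params. Callers using the C++ API directly can pass the cv::dnn enum values instead; a hedged sketch, assuming an OpenCV build with the CUDA DNN backend available (the helper name is hypothetical):

// Sketch (not part of the patch): request CUDA inference through the new
// Params fields; requires an OpenCV build with the CUDA DNN backend.
#include <string>
#include <opencv2/dnn.hpp>
#include <opencv2/video.hpp>

cv::Ptr<cv::TrackerVit> makeCudaVitTracker(const std::string& modelPath)
{
    cv::TrackerVit::Params params;
    params.net = modelPath;
    params.backend = cv::dnn::DNN_BACKEND_CUDA;  // instead of the magic int 5
    params.target = cv::dnn::DNN_TARGET_CUDA;    // instead of the magic int 6
    return cv::TrackerVit::create(params);
}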
diff --git a/samples/dnn/vit_tracker.cpp b/samples/dnn/vit_tracker.cpp
new file mode 100644
index 0000000000..02e5cea83f
--- /dev/null
+++ b/samples/dnn/vit_tracker.cpp
@@ -0,0 +1,176 @@
+// VitTracker
+// model: https://github.com/opencv/opencv_zoo/tree/main/models/object_tracking_vittrack
+
+#include <iostream>
+#include <cmath>
+
+#include <opencv2/dnn.hpp>
+#include <opencv2/imgproc.hpp>
+#include <opencv2/highgui.hpp>
+#include <opencv2/video.hpp>
+
+using namespace cv;
+using namespace cv::dnn;
+
+const char *keys =
+    "{ help h | | Print help message }"
+    "{ input i | | Full path to the input video, or a single-digit camera index (empty for camera 0) }"
+    "{ net | vitTracker.onnx | Path to the vitTracker.onnx model }"
+    "{ backend | 0 | Choose one of computation backends: "
+        "0: automatically (by default), "
+        "1: Halide language (http://halide-lang.org/), "
+        "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
+        "3: OpenCV implementation, "
+        "4: VKCOM, "
+        "5: CUDA },"
+    "{ target | 0 | Choose one of target computation devices: "
+        "0: CPU target (by default), "
+        "1: OpenCL, "
+        "2: OpenCL fp16 (half-float precision), "
+        "3: VPU, "
+        "4: Vulkan, "
+        "6: CUDA, "
+        "7: CUDA fp16 (half-float preprocess) }"
+;
+
+static
+int run(int argc, char** argv)
+{
+    // Parse command line arguments.
+    CommandLineParser parser(argc, argv, keys);
+
+    if (parser.has("help"))
+    {
+        parser.printMessage();
+        return 0;
+    }
+
+    std::string inputName = parser.get<String>("input");
+    std::string net = parser.get<String>("net");
+    int backend = parser.get<int>("backend");
+    int target = parser.get<int>("target");
+
+    Ptr<TrackerVit> tracker;
+    try
+    {
+        TrackerVit::Params params;
+        params.net = samples::findFile(net);
+        params.backend = backend;
+        params.target = target;
+        tracker = TrackerVit::create(params);
+    }
+    catch (const cv::Exception& ee)
+    {
+        std::cerr << "Exception: " << ee.what() << std::endl;
+        std::cout << "Can't load the network by using the following files:" << std::endl;
+        std::cout << "net : " << net << std::endl;
+        return 2;
+    }
+
+    const std::string winName = "vitTracker";
+    namedWindow(winName, WINDOW_AUTOSIZE);
+
+    // Open a video file or an image file or a camera stream.
+    VideoCapture cap;
+
+    if (inputName.empty() || (isdigit(inputName[0]) && inputName.size() == 1))
+    {
+        int c = inputName.empty() ? 0 : inputName[0] - '0';
+        std::cout << "Trying to open camera #" << c << " ..." << std::endl;
+        if (!cap.open(c))
+        {
+            std::cout << "Capture from camera #" << c << " didn't work. Specify -i=