mirror of
https://github.com/opencv/opencv.git
synced 2025-07-24 22:16:27 +08:00
Merge pull request #26875 from asmorkalov:as/in_memory_models
Added trackers factory with pre-loaded dnn models #26875 Replaces https://github.com/opencv/opencv/pull/26295 Allows to substitute custom models or initialize tracker from in-memory model. ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [ ] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [ ] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
a8df0a06ac
commit
ae25c3194f
@ -425,6 +425,8 @@ class FuncInfo(GeneralInfo):
|
||||
arg_fix_map = func_fix_map.get(arg[1], {})
|
||||
arg[0] = arg_fix_map.get('ctype', arg[0]) #fixing arg type
|
||||
arg[3] = arg_fix_map.get('attrib', arg[3]) #fixing arg attrib
|
||||
if arg[0] == 'dnn_Net':
|
||||
arg[0] = 'Net'
|
||||
self.args.append(ArgInfo(arg))
|
||||
|
||||
def fullClassJAVA(self):
|
||||
@ -474,7 +476,7 @@ class JavaWrapperGenerator(object):
|
||||
jni_name = "(*("+classinfo.fullNameCPP()+"*)%(n)s_nativeObj)"
|
||||
type_dict.setdefault(name, {}).update(
|
||||
{ "j_type" : classinfo.jname,
|
||||
"jn_type" : "long", "jn_args" : (("__int64", ".nativeObj"),),
|
||||
"jn_type" : "long", "jn_args" : (("__int64", ".getNativeObjAddr()"),),
|
||||
"jni_name" : jni_name,
|
||||
"jni_type" : "jlong",
|
||||
"suffix" : "J",
|
||||
@ -483,7 +485,7 @@ class JavaWrapperGenerator(object):
|
||||
)
|
||||
type_dict.setdefault(name+'*', {}).update(
|
||||
{ "j_type" : classinfo.jname,
|
||||
"jn_type" : "long", "jn_args" : (("__int64", ".nativeObj"),),
|
||||
"jn_type" : "long", "jn_args" : (("__int64", ".getNativeObjAddr()"),),
|
||||
"jni_name" : "&("+jni_name+")",
|
||||
"jni_type" : "jlong",
|
||||
"suffix" : "J",
|
||||
|
@ -24,11 +24,12 @@ class NewOpenCVTests(unittest.TestCase):
|
||||
# path to local repository folder containing 'samples' folder
|
||||
repoPath = None
|
||||
extraTestDataPath = None
|
||||
extraDnnTestDataPath = None
|
||||
# github repository url
|
||||
repoUrl = 'https://raw.github.com/opencv/opencv/4.x'
|
||||
|
||||
def find_file(self, filename, searchPaths=[], required=True):
|
||||
searchPaths = searchPaths if searchPaths else [self.repoPath, self.extraTestDataPath]
|
||||
searchPaths = searchPaths if searchPaths else [self.repoPath, self.extraTestDataPath, self.extraDnnTestDataPath]
|
||||
for path in searchPaths:
|
||||
if path is not None:
|
||||
candidate = path + '/' + filename
|
||||
@ -83,10 +84,17 @@ class NewOpenCVTests(unittest.TestCase):
|
||||
print("Testing OpenCV", cv.__version__)
|
||||
print("Local repo path:", args.repo)
|
||||
NewOpenCVTests.repoPath = args.repo
|
||||
|
||||
try:
|
||||
NewOpenCVTests.extraTestDataPath = os.environ['OPENCV_TEST_DATA_PATH']
|
||||
except KeyError:
|
||||
print('Missing opencv extra repository. Some of tests may fail.')
|
||||
|
||||
try:
|
||||
NewOpenCVTests.extraDnnTestDataPath = os.environ['OPENCV_DNN_TEST_DATA_PATH']
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
random.seed(0)
|
||||
unit_argv = [sys.argv[0]] + other
|
||||
unittest.main(argv=unit_argv)
|
||||
|
@ -46,6 +46,9 @@
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#ifdef HAVE_OPENCV_DNN
|
||||
# include "opencv2/dnn.hpp"
|
||||
#endif
|
||||
|
||||
namespace cv
|
||||
{
|
||||
@ -826,6 +829,13 @@ public:
|
||||
static CV_WRAP
|
||||
Ptr<TrackerGOTURN> create(const TrackerGOTURN::Params& parameters = TrackerGOTURN::Params());
|
||||
|
||||
#ifdef HAVE_OPENCV_DNN
|
||||
/** @brief Constructor
|
||||
@param model pre-loaded GOTURN model
|
||||
*/
|
||||
static CV_WRAP Ptr<TrackerGOTURN> create(const dnn::Net& model);
|
||||
#endif
|
||||
|
||||
//void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
|
||||
//bool update(InputArray image, CV_OUT Rect& boundingBox) CV_OVERRIDE;
|
||||
};
|
||||
@ -853,6 +863,16 @@ public:
|
||||
static CV_WRAP
|
||||
Ptr<TrackerDaSiamRPN> create(const TrackerDaSiamRPN::Params& parameters = TrackerDaSiamRPN::Params());
|
||||
|
||||
#ifdef HAVE_OPENCV_DNN
|
||||
/** @brief Constructor
|
||||
* @param siam_rpn pre-loaded SiamRPN model
|
||||
* @param kernel_cls1 pre-loaded CLS model
|
||||
* @param kernel_r1 pre-loaded R1 model
|
||||
*/
|
||||
static CV_WRAP
|
||||
Ptr<TrackerDaSiamRPN> create(const dnn::Net& siam_rpn, const dnn::Net& kernel_cls1, const dnn::Net& kernel_r1);
|
||||
#endif
|
||||
|
||||
/** @brief Return tracking score
|
||||
*/
|
||||
CV_WRAP virtual float getTrackingScore() = 0;
|
||||
@ -891,6 +911,15 @@ public:
|
||||
static CV_WRAP
|
||||
Ptr<TrackerNano> create(const TrackerNano::Params& parameters = TrackerNano::Params());
|
||||
|
||||
#ifdef HAVE_OPENCV_DNN
|
||||
/** @brief Constructor
|
||||
* @param backbone pre-loaded backbone model
|
||||
* @param neckhead pre-loaded neckhead model
|
||||
*/
|
||||
static CV_WRAP
|
||||
Ptr<TrackerNano> create(const dnn::Net& backbone, const dnn::Net& neckhead);
|
||||
#endif
|
||||
|
||||
/** @brief Return tracking score
|
||||
*/
|
||||
CV_WRAP virtual float getTrackingScore() = 0;
|
||||
@ -929,6 +958,18 @@ public:
|
||||
static CV_WRAP
|
||||
Ptr<TrackerVit> create(const TrackerVit::Params& parameters = TrackerVit::Params());
|
||||
|
||||
#ifdef HAVE_OPENCV_DNN
|
||||
/** @brief Constructor
|
||||
* @param model pre-loaded DNN model
|
||||
* @param meanvalue mean value for image preprocessing
|
||||
* @param stdvalue std value for image preprocessing
|
||||
* @param tracking_score_threshold threshold for tracking score
|
||||
*/
|
||||
static CV_WRAP
|
||||
Ptr<TrackerVit> create(const dnn::Net& model, Scalar meanvalue = Scalar(0.485, 0.456, 0.406),
|
||||
Scalar stdvalue = Scalar(0.229, 0.224, 0.225), float tracking_score_threshold = 0.20f);
|
||||
#endif
|
||||
|
||||
/** @brief Return tracking score
|
||||
*/
|
||||
CV_WRAP virtual float getTrackingScore() = 0;
|
||||
|
@ -1,31 +1,98 @@
|
||||
package org.opencv.test.video;
|
||||
|
||||
import java.io.File;
|
||||
import org.opencv.core.Core;
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.CvException;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Rect;
|
||||
import org.opencv.dnn.Dnn;
|
||||
import org.opencv.dnn.Net;
|
||||
import org.opencv.test.OpenCVTestCase;
|
||||
|
||||
import org.opencv.video.Tracker;
|
||||
import org.opencv.video.TrackerGOTURN;
|
||||
import org.opencv.video.TrackerGOTURN_Params;
|
||||
import org.opencv.video.TrackerNano;
|
||||
import org.opencv.video.TrackerNano_Params;
|
||||
import org.opencv.video.TrackerVit;
|
||||
import org.opencv.video.TrackerVit_Params;
|
||||
import org.opencv.video.TrackerMIL;
|
||||
|
||||
public class TrackerCreateTest extends OpenCVTestCase {
|
||||
|
||||
private final static String ENV_OPENCV_DNN_TEST_DATA_PATH = "OPENCV_DNN_TEST_DATA_PATH";
|
||||
private final static String ENV_OPENCV_TEST_DATA_PATH = "OPENCV_TEST_DATA_PATH";
|
||||
private String testDataPath;
|
||||
private String modelsDataPath;
|
||||
|
||||
@Override
|
||||
protected void setUp() throws Exception {
|
||||
super.setUp();
|
||||
|
||||
// relys on https://developer.android.com/reference/java/lang/System
|
||||
isTestCaseEnabled = System.getProperties().getProperty("java.vm.name") != "Dalvik";
|
||||
if (!isTestCaseEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
testDataPath = System.getenv(ENV_OPENCV_TEST_DATA_PATH);
|
||||
if (testDataPath == null) {
|
||||
throw new Exception(ENV_OPENCV_TEST_DATA_PATH + " has to be defined!");
|
||||
}
|
||||
|
||||
modelsDataPath = System.getenv(ENV_OPENCV_DNN_TEST_DATA_PATH);
|
||||
if (modelsDataPath == null) {
|
||||
modelsDataPath = testDataPath;
|
||||
}
|
||||
|
||||
if (isTestCaseEnabled) {
|
||||
testDataPath = System.getenv(ENV_OPENCV_DNN_TEST_DATA_PATH);
|
||||
if (testDataPath == null)
|
||||
testDataPath = System.getenv(ENV_OPENCV_TEST_DATA_PATH);
|
||||
if (testDataPath == null)
|
||||
throw new Exception(ENV_OPENCV_TEST_DATA_PATH + " has to be defined!");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public void testCreateTrackerGOTURN() {
|
||||
Net net;
|
||||
try {
|
||||
Tracker tracker = TrackerGOTURN.create();
|
||||
assert(tracker != null);
|
||||
String protoFile = new File(testDataPath, "dnn/gsoc2016-goturn/goturn.prototxt").toString();
|
||||
String weightsFile = new File(modelsDataPath, "dnn/gsoc2016-goturn/goturn.caffemodel").toString();
|
||||
net = Dnn.readNetFromCaffe(protoFile, weightsFile);
|
||||
} catch (CvException e) {
|
||||
// expected, model files may be missing
|
||||
return;
|
||||
}
|
||||
Tracker tracker = TrackerGOTURN.create(net);
|
||||
assert(tracker != null);
|
||||
}
|
||||
|
||||
public void testCreateTrackerNano() {
|
||||
Net backbone;
|
||||
Net neckhead;
|
||||
try {
|
||||
String backboneFile = new File(modelsDataPath, "dnn/onnx/models/nanotrack_backbone_sim_v2.onnx").toString();
|
||||
String neckheadFile = new File(modelsDataPath, "dnn/onnx/models/nanotrack_head_sim_v2.onnx").toString();
|
||||
backbone = Dnn.readNet(backboneFile);
|
||||
neckhead = Dnn.readNet(neckheadFile);
|
||||
} catch (CvException e) {
|
||||
return;
|
||||
}
|
||||
Tracker tracker = TrackerNano.create(backbone, neckhead);
|
||||
assert(tracker != null);
|
||||
}
|
||||
|
||||
public void testCreateTrackerVit() {
|
||||
Net net;
|
||||
try {
|
||||
String backboneFile = new File(modelsDataPath, "dnn/onnx/models/vitTracker.onnx").toString();
|
||||
net = Dnn.readNet(backboneFile);
|
||||
} catch (CvException e) {
|
||||
return;
|
||||
}
|
||||
Tracker tracker = TrackerVit.create(net);
|
||||
assert(tracker != null);
|
||||
}
|
||||
|
||||
public void testCreateTrackerMIL() {
|
||||
@ -35,5 +102,4 @@ public class TrackerCreateTest extends OpenCVTestCase {
|
||||
Rect rect = new Rect(10, 10, 30, 30);
|
||||
tracker.init(mat, rect); // should not crash (https://github.com/opencv/opencv/issues/19915)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -7,12 +7,30 @@ from tests_common import NewOpenCVTests, unittest
|
||||
|
||||
class tracking_test(NewOpenCVTests):
|
||||
|
||||
def test_createTracker(self):
|
||||
t = cv.TrackerMIL_create()
|
||||
try:
|
||||
t = cv.TrackerGOTURN_create()
|
||||
except cv.error as e:
|
||||
pass # may fail due to missing DL model files
|
||||
def test_createMILTracker(self):
|
||||
t = cv.TrackerMIL.create()
|
||||
self.assertTrue(t is not None)
|
||||
|
||||
def test_createGoturnTracker(self):
|
||||
proto = self.find_file("dnn/gsoc2016-goturn/goturn.prototxt", required=False);
|
||||
weights = self.find_file("dnn/gsoc2016-goturn/goturn.caffemodel", required=False);
|
||||
net = cv.dnn.readNet(proto, weights)
|
||||
t = cv.TrackerGOTURN.create(net)
|
||||
self.assertTrue(t is not None)
|
||||
|
||||
def test_createNanoTracker(self):
|
||||
backbone_path = self.find_file("dnn/onnx/models/nanotrack_backbone_sim_v2.onnx", required=False);
|
||||
neckhead_path = self.find_file("dnn/onnx/models/nanotrack_head_sim_v2.onnx", required=False);
|
||||
backbone = cv.dnn.readNet(backbone_path)
|
||||
neckhead = cv.dnn.readNet(neckhead_path)
|
||||
t = cv.TrackerNano.create(backbone, neckhead)
|
||||
self.assertTrue(t is not None)
|
||||
|
||||
def test_createVitTracker(self):
|
||||
model_path = self.find_file("dnn/onnx/models/vitTracker.onnx", required=False);
|
||||
model = cv.dnn.readNet(model_path)
|
||||
t = cv.TrackerVit.create(model)
|
||||
self.assertTrue(t is not None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -57,10 +57,8 @@ Mat sizeCal(const Mat& w, const Mat& h)
|
||||
class TrackerDaSiamRPNImpl : public TrackerDaSiamRPN
|
||||
{
|
||||
public:
|
||||
TrackerDaSiamRPNImpl(const TrackerDaSiamRPN::Params& parameters)
|
||||
: params(parameters)
|
||||
TrackerDaSiamRPNImpl(const TrackerDaSiamRPN::Params& params)
|
||||
{
|
||||
|
||||
siamRPN = dnn::readNet(params.model);
|
||||
siamKernelCL1 = dnn::readNet(params.kernel_cls1);
|
||||
siamKernelR1 = dnn::readNet(params.kernel_r1);
|
||||
@ -77,12 +75,21 @@ public:
|
||||
siamKernelCL1.setPreferableTarget(params.target);
|
||||
}
|
||||
|
||||
TrackerDaSiamRPNImpl(const dnn::Net& siam_rpn, const dnn::Net& kernel_cls1, const dnn::Net& kernel_r1)
|
||||
{
|
||||
CV_Assert(!siam_rpn.empty());
|
||||
CV_Assert(!kernel_cls1.empty());
|
||||
CV_Assert(!kernel_r1.empty());
|
||||
|
||||
siamRPN = siam_rpn;
|
||||
siamKernelCL1 = kernel_cls1;
|
||||
siamKernelR1 = kernel_r1;
|
||||
}
|
||||
|
||||
void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
|
||||
bool update(InputArray image, Rect& boundingBox) CV_OVERRIDE;
|
||||
float getTrackingScore() CV_OVERRIDE;
|
||||
|
||||
TrackerDaSiamRPN::Params params;
|
||||
|
||||
protected:
|
||||
dnn::Net siamRPN, siamKernelR1, siamKernelCL1;
|
||||
Rect boundingBox_;
|
||||
@ -425,16 +432,22 @@ Mat TrackerDaSiamRPNImpl::getSubwindow(Mat& img, const Rect2f& targetBox, float
|
||||
|
||||
return zCrop;
|
||||
}
|
||||
|
||||
Ptr<TrackerDaSiamRPN> TrackerDaSiamRPN::create(const TrackerDaSiamRPN::Params& parameters)
|
||||
{
|
||||
return makePtr<TrackerDaSiamRPNImpl>(parameters);
|
||||
}
|
||||
|
||||
Ptr<TrackerDaSiamRPN> TrackerDaSiamRPN::create(const dnn::Net& siam_rpn, const dnn::Net& kernel_cls1, const dnn::Net& kernel_r1)
|
||||
{
|
||||
return makePtr<TrackerDaSiamRPNImpl>(siam_rpn, kernel_cls1, kernel_r1);
|
||||
}
|
||||
|
||||
#else // OPENCV_HAVE_DNN
|
||||
Ptr<TrackerDaSiamRPN> TrackerDaSiamRPN::create(const TrackerDaSiamRPN::Params& parameters)
|
||||
{
|
||||
(void)(parameters);
|
||||
CV_Error(cv::Error::StsNotImplemented, "to use GOTURN, the tracking module needs to be built with opencv_dnn !");
|
||||
CV_Error(cv::Error::StsNotImplemented, "to use DaSiamRPN, the tracking module needs to be built with opencv_dnn !");
|
||||
}
|
||||
#endif // OPENCV_HAVE_DNN
|
||||
}
|
||||
|
@ -31,11 +31,16 @@ TrackerGOTURN::Params::Params()
|
||||
class TrackerGOTURNImpl : public TrackerGOTURN
|
||||
{
|
||||
public:
|
||||
TrackerGOTURNImpl(const dnn::Net& model)
|
||||
{
|
||||
CV_Assert(!model.empty());
|
||||
net = model;
|
||||
}
|
||||
|
||||
TrackerGOTURNImpl(const TrackerGOTURN::Params& parameters)
|
||||
: params(parameters)
|
||||
{
|
||||
// Load GOTURN architecture from *.prototxt and pretrained weights from *.caffemodel
|
||||
net = dnn::readNetFromCaffe(params.modelTxt, params.modelBin);
|
||||
net = dnn::readNetFromCaffe(parameters.modelTxt, parameters.modelBin);
|
||||
CV_Assert(!net.empty());
|
||||
}
|
||||
|
||||
@ -49,8 +54,6 @@ public:
|
||||
boundingBox_ = boundingBox & Rect(Point(0, 0), image_.size());
|
||||
}
|
||||
|
||||
TrackerGOTURN::Params params;
|
||||
|
||||
dnn::Net net;
|
||||
Rect boundingBox_;
|
||||
Mat image_;
|
||||
@ -129,6 +132,11 @@ Ptr<TrackerGOTURN> TrackerGOTURN::create(const TrackerGOTURN::Params& parameters
|
||||
return makePtr<TrackerGOTURNImpl>(parameters);
|
||||
}
|
||||
|
||||
Ptr<TrackerGOTURN> TrackerGOTURN::create(const dnn::Net& model)
|
||||
{
|
||||
return makePtr<TrackerGOTURNImpl>(model);
|
||||
}
|
||||
|
||||
#else // OPENCV_HAVE_DNN
|
||||
Ptr<TrackerGOTURN> TrackerGOTURN::create(const TrackerGOTURN::Params& parameters)
|
||||
{
|
||||
|
@ -87,18 +87,26 @@ class TrackerNanoImpl : public TrackerNano
|
||||
{
|
||||
public:
|
||||
TrackerNanoImpl(const TrackerNano::Params& parameters)
|
||||
: params(parameters)
|
||||
{
|
||||
backbone = dnn::readNet(params.backbone);
|
||||
neckhead = dnn::readNet(params.neckhead);
|
||||
backbone = dnn::readNet(parameters.backbone);
|
||||
neckhead = dnn::readNet(parameters.neckhead);
|
||||
|
||||
CV_Assert(!backbone.empty());
|
||||
CV_Assert(!neckhead.empty());
|
||||
|
||||
backbone.setPreferableBackend(params.backend);
|
||||
backbone.setPreferableTarget(params.target);
|
||||
neckhead.setPreferableBackend(params.backend);
|
||||
neckhead.setPreferableTarget(params.target);
|
||||
backbone.setPreferableBackend(parameters.backend);
|
||||
backbone.setPreferableTarget(parameters.target);
|
||||
neckhead.setPreferableBackend(parameters.backend);
|
||||
neckhead.setPreferableTarget(parameters.target);
|
||||
}
|
||||
|
||||
TrackerNanoImpl(const dnn::Net& _backbone, const dnn::Net& _neckhead)
|
||||
{
|
||||
CV_Assert(!_backbone.empty());
|
||||
CV_Assert(!_neckhead.empty());
|
||||
|
||||
backbone = _backbone;
|
||||
neckhead = _neckhead;
|
||||
}
|
||||
|
||||
void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
|
||||
@ -110,8 +118,6 @@ public:
|
||||
std::vector<float> targetPos = {0, 0}; // center point of bounding box (x, y)
|
||||
float tracking_score;
|
||||
|
||||
TrackerNano::Params params;
|
||||
|
||||
struct trackerConfig
|
||||
{
|
||||
float windowInfluence = 0.455f;
|
||||
@ -349,6 +355,11 @@ Ptr<TrackerNano> TrackerNano::create(const TrackerNano::Params& parameters)
|
||||
return makePtr<TrackerNanoImpl>(parameters);
|
||||
}
|
||||
|
||||
Ptr<TrackerNano> TrackerNano::create(const dnn::Net& backbone, const dnn::Net& neckhead)
|
||||
{
|
||||
return makePtr<TrackerNanoImpl>(backbone, neckhead);
|
||||
}
|
||||
|
||||
#else // OPENCV_HAVE_DNN
|
||||
Ptr<TrackerNano> TrackerNano::create(const TrackerNano::Params& parameters)
|
||||
{
|
||||
|
@ -42,16 +42,26 @@ class TrackerVitImpl : public TrackerVit
|
||||
{
|
||||
public:
|
||||
TrackerVitImpl(const TrackerVit::Params& parameters)
|
||||
: params(parameters)
|
||||
{
|
||||
net = dnn::readNet(params.net);
|
||||
net = dnn::readNet(parameters.net);
|
||||
CV_Assert(!net.empty());
|
||||
|
||||
net.setPreferableBackend(params.backend);
|
||||
net.setPreferableTarget(params.target);
|
||||
net.setPreferableBackend(parameters.backend);
|
||||
net.setPreferableTarget(parameters.target);
|
||||
|
||||
i2bp.mean = params.meanvalue * 255.0;
|
||||
i2bp.scalefactor = (1.0 / params.stdvalue) * (1 / 255.0);
|
||||
i2bp.mean = parameters.meanvalue * 255.0;
|
||||
i2bp.scalefactor = (1.0 / parameters.stdvalue) * (1 / 255.0);
|
||||
tracking_score_threshold = parameters.tracking_score_threshold;
|
||||
}
|
||||
|
||||
TrackerVitImpl(const dnn::Net& model, Scalar meanvalue, Scalar stdvalue, float _tracking_score_threshold)
|
||||
{
|
||||
CV_Assert(!model.empty());
|
||||
|
||||
net = model;
|
||||
i2bp.mean = meanvalue * 255.0;
|
||||
i2bp.scalefactor = (1.0 / stdvalue) * (1 / 255.0);
|
||||
tracking_score_threshold = _tracking_score_threshold;
|
||||
}
|
||||
|
||||
void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
|
||||
@ -61,7 +71,7 @@ public:
|
||||
Rect rect_last;
|
||||
float tracking_score;
|
||||
|
||||
TrackerVit::Params params;
|
||||
float tracking_score_threshold;
|
||||
dnn::Image2BlobParams i2bp;
|
||||
|
||||
|
||||
@ -189,7 +199,7 @@ bool TrackerVitImpl::update(InputArray image_, Rect &boundingBoxRes)
|
||||
minMaxLoc(conf_map, nullptr, &maxVal, nullptr, &maxLoc);
|
||||
tracking_score = static_cast<float>(maxVal);
|
||||
|
||||
if (tracking_score >= params.tracking_score_threshold) {
|
||||
if (tracking_score >= tracking_score_threshold) {
|
||||
float cx = (maxLoc.x + offset_map.at<float>(0, maxLoc.y, maxLoc.x)) / 16;
|
||||
float cy = (maxLoc.y + offset_map.at<float>(1, maxLoc.y, maxLoc.x)) / 16;
|
||||
float w = size_map.at<float>(0, maxLoc.y, maxLoc.x);
|
||||
@ -213,6 +223,11 @@ Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters)
|
||||
return makePtr<TrackerVitImpl>(parameters);
|
||||
}
|
||||
|
||||
Ptr<TrackerVit> TrackerVit::create(const dnn::Net& model, Scalar meanvalue, Scalar stdvalue, float tracking_score_threshold)
|
||||
{
|
||||
return makePtr<TrackerVitImpl>(model, meanvalue, stdvalue, tracking_score_threshold);
|
||||
}
|
||||
|
||||
#else // OPENCV_HAVE_DNN
|
||||
Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user