add scale factor to DB demo.

zihaomu 2023-04-30 22:03:21 +08:00
parent e3e1f704a4
commit 8be93a6de7
4 changed files with 70 additions and 18 deletions

View File

@@ -1422,7 +1422,7 @@ CV__DNN_INLINE_NS_BEGIN
/** @brief Set scalefactor value for frame.
* @param[in] scale Multiplier for frame values.
*/
CV_WRAP Model& setInputScale(double scale);
CV_WRAP Model& setInputScale(const Scalar& scale);
/** @brief Set flag crop for frame.
* @param[in] crop Flag which indicates whether image will be cropped after resize or not.
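With the new signature, Model::setInputScale takes a per-channel multiplier instead of a single double. A minimal usage sketch (the model file name and the concrete scale/mean values are illustrative assumptions; the method calls are the existing dnn::Model API). A plain double still compiles because Scalar converts implicitly from a single value:

    #include <opencv2/dnn.hpp>
    using namespace cv;
    using namespace cv::dnn;

    void configureModel()
    {
        Model model("some_model.onnx");                  // hypothetical model file
        model.setInputSize(Size(736, 736));
        model.setInputMean(Scalar(122.68, 116.67, 104.01));
        // Scalar(1/255.0) is (1/255, 0, 0, 0); it is expanded to all channels internally
        // (see broadcastRealScalar in the next file).
        model.setInputScale(1.0 / 255);
        // Alternatively, true per-channel scaling, e.g. 1 / (255 * std)
        // with std = (0.229, 0.224, 0.225):
        model.setInputScale(Scalar(1 / (255 * 0.229), 1 / (255 * 0.224), 1 / (255 * 0.225)));
    }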

View File

@@ -154,6 +154,21 @@ static inline std::string toString(const Mat& blob, const std::string& name = st
return ss.str();
}
// Scalefactor is a common parameter used for data scaling; in OpenCV it is usually represented as a Scalar.
// A value of 0 is meaningless as a scalefactor, so if the scalefactor is (x, 0, 0, 0) we broadcast it to
// (x, x, x, x). The following helper performs this conversion.
static inline Scalar_<double> broadcastRealScalar(const Scalar_<double>& _scale)
{
Scalar_<double> scale = _scale;
if (scale[1] == 0 && scale[2] == 0 && scale[3] == 0)
{
CV_Assert(scale[0] != 0 && "Scalefactor of 0 is meaningless.");
scale = Scalar_<double>::all(scale[0]);
}
return scale;
}
CV__DNN_INLINE_NS_END
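In effect this helper keeps the old double-based calls working after the switch to Scalar: the implicit single-value Scalar constructor yields (x, 0, 0, 0), which is expanded to all four channels, while a genuinely per-channel scale passes through untouched. A small illustration (not part of the patch):

    Scalar_<double> a = broadcastRealScalar(Scalar_<double>(1.0 / 255));         // -> (1/255, 1/255, 1/255, 1/255)
    Scalar_<double> b = broadcastRealScalar(Scalar_<double>(0.5, 0.25, 0.125));  // unchanged: (0.5, 0.25, 0.125, 0)
    // broadcastRealScalar(Scalar_<double>(0.0)) would fail the CV_Assert above.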

View File

@@ -21,7 +21,7 @@ struct Model::Impl
Size size;
Scalar mean;
double scale = 1.0;
Scalar scale = Scalar::all(1.0);
bool swapRB = false;
bool crop = false;
Mat blob;
@@ -60,7 +60,7 @@ public:
{
size = size_;
mean = mean_;
scale = scale_;
scale = Scalar::all(scale_);
crop = crop_;
swapRB = swapRB_;
}
@@ -75,7 +75,7 @@ public:
mean = mean_;
}
/*virtual*/
void setInputScale(double scale_)
void setInputScale(const Scalar& scale_)
{
scale = scale_;
}
@@ -97,7 +97,17 @@ public:
if (size.empty())
CV_Error(Error::StsBadSize, "Input size not specified");
blob = blobFromImage(frame, scale, size, mean, swapRB, crop);
Image2BlobParams param;
param.scalefactor = scale;
param.size = size;
param.mean = mean;
param.swapRB = swapRB;
if (crop)
{
param.paddingmode = DNN_PMODE_CROP_CENTER;
}
blob = blobFromImageWithParams(frame, param);
net.setInput(blob);
// Faster-RCNN or R-FCN
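The hard-wired blobFromImage call is replaced by the Image2BlobParams / blobFromImageWithParams API so the per-channel Scalar scalefactor can be forwarded. A standalone sketch of the same conversion outside the Model class (image contents and preprocessing values are arbitrary assumptions); for a uniform factor the two paths are expected to produce the same blob:

    Mat frame(480, 640, CV_8UC3, Scalar(10, 20, 30));            // arbitrary test image
    Mat ref = blobFromImage(frame, 1.0 / 255, Size(736, 736),
                            Scalar(122.68, 116.67, 104.01), /*swapRB=*/false, /*crop=*/false);

    Image2BlobParams p;
    p.scalefactor = Scalar::all(1.0 / 255);                      // per-channel values are now possible
    p.size        = Size(736, 736);
    p.mean        = Scalar(122.68, 116.67, 104.01);
    p.swapRB      = false;                                       // default paddingmode: plain resize, no crop
    Mat blob = blobFromImageWithParams(frame, p);
    CV_Assert(norm(ref, blob, NORM_INF) < 1e-5);                 // same preprocessing, same result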
@@ -162,9 +172,11 @@ Model& Model::setInputMean(const Scalar& mean)
return *this;
}
Model& Model::setInputScale(double scale)
Model& Model::setInputScale(const Scalar& scale_)
{
CV_DbgAssert(impl);
Scalar scale = broadcastRealScalar(scale_);
impl->setInputScale(scale);
return *this;
}
@@ -1358,7 +1370,7 @@ struct TextDetectionModel_DB_Impl : public TextDetectionModel_Impl
{
CV_TRACE_FUNCTION();
std::vector< std::vector<Point2f> > results;
confidences.clear();
std::vector<Mat> outs;
processFrame(frame, outs);
CV_Assert(outs.size() == 1);
@@ -1385,7 +1397,8 @@ struct TextDetectionModel_DB_Impl : public TextDetectionModel_Impl
std::vector<Point>& contour = contours[i];
// Calculate text contour score
if (contourScore(binary, contour) < polygonThreshold)
float score = contourScore(binary, contour);
if (score < polygonThreshold)
continue;
// Rescale
@@ -1398,6 +1411,11 @@ struct TextDetectionModel_DB_Impl : public TextDetectionModel_Impl
// Unclip
RotatedRect box = minAreaRect(contourScaled);
float minLen = std::min(box.size.height/scaleWidth, box.size.width/scaleHeight);
// Filter very small boxes
if (minLen < 3)
continue;
// minArea() rect is not normalized, it may return rectangles with angle=-90 or height < width
const float angle_threshold = 60; // do not expect vertical text, TODO detection algo property
@@ -1422,10 +1440,12 @@ struct TextDetectionModel_DB_Impl : public TextDetectionModel_Impl
approx.emplace_back(vertex[j]);
std::vector<Point2f> polygon;
unclip(approx, polygon, unclipRatio);
if (polygon.empty())
continue;
results.push_back(polygon);
confidences.push_back(score);
}
confidences = std::vector<float>(contours.size(), 1.0f);
return results;
}
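With these changes, detections whose shorter side is under 3 pixels and contours whose unclipped polygon is empty are dropped, and confidences carries the actual contourScore of each kept polygon instead of a constant 1.0. Callers read the scores through the existing confidence-aware detect() overload; a sketch using the model and thresholds from the test file below (the image and model paths are assumptions):

    Mat image = imread("text_det_test1.png");                    // assumed test image
    TextDetectionModel_DB db("DB_TD500_resnet50.onnx");          // assumed model path
    db.setBinaryThreshold(0.3f)
      .setPolygonThreshold(0.5f)
      .setMaxCandidates(200)
      .setUnclipRatio(2.0);
    db.setInputSize(Size(736, 736));
    db.setInputMean(Scalar(122.68, 116.67, 104.01));
    db.setInputScale(Scalar::all(1.0 / 255));

    std::vector<std::vector<Point>> boxes;
    std::vector<float> confidences;
    db.detect(image, boxes, confidences);                        // confidences[i] scores boxes[i]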
@@ -1458,7 +1478,10 @@ struct TextDetectionModel_DB_Impl : public TextDetectionModel_Impl
{
double area = contourArea(inPoly);
double length = arcLength(inPoly, true);
CV_Assert(length > FLT_EPSILON);
if(length == 0.)
return;
double distance = area * unclipRatio / length;
size_t numPoints = inPoly.size();
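For reference, the offset computed here follows the DB unclip formula D = A * r / L (contour area times unclip ratio divided by perimeter): a 100x20 box with r = 2.0 gives D = 2000 * 2 / 240 ≈ 16.7 px of outward expansion. The change above replaces the hard CV_Assert on the perimeter with an early return, so a degenerate contour with zero length is skipped instead of throwing.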

View File

@@ -153,8 +153,8 @@ public:
const std::string& imgPath, const std::vector<std::vector<Point>>& gt,
float binThresh, float polyThresh,
uint maxCandidates, double unclipRatio,
const Size& size = {-1, -1}, Scalar mean = Scalar(),
double scale = 1.0, bool swapRB = false, bool crop = false)
const Size& size = {-1, -1}, Scalar mean = Scalar(), Scalar scale = Scalar::all(1.0),
double boxes_iou_diff = 0.05, bool swapRB = false, bool crop = false)
{
checkBackend();
@@ -197,7 +197,7 @@ public:
imshow("result", result); // imwrite("result.png", result);
waitKey(0);
#endif
normAssertTextDetections(gt, contours, "", 0.05f);
normAssertTextDetections(gt, contours, "", boxes_iou_diff);
// 2. Check quadrangle-based API
// std::vector< std::vector<Point> > contours;
@@ -209,7 +209,7 @@ public:
imshow("result_contours", result); // imwrite("result_contours.png", result);
waitKey(0);
#endif
normAssertTextDetections(gt, contours, "", 0.05f);
normAssertTextDetections(gt, contours, "", boxes_iou_diff);
}
void testTextDetectionModelByEAST(
@@ -743,7 +743,8 @@ TEST_P(Test_Model, TextDetectionByDB)
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
std::string imgPath = _tf("text_det_test1.png");
std::string weightPath = _tf("onnx/models/DB_TD500_resnet50.onnx", false);
std::string weightPathDB = _tf("onnx/models/DB_TD500_resnet50.onnx", false);
std::string weightPathPPDB = _tf("onnx/models/PP_OCRv3_DB_text_det.onnx", false);
// GroundTruth
std::vector<std::vector<Point>> gt = {
@@ -752,15 +753,28 @@ TEST_P(Test_Model, TextDetectionByDB)
};
Size size{736, 736};
double scale = 1.0 / 255.0;
Scalar mean = Scalar(122.67891434, 116.66876762, 104.00698793);
Scalar scaleDB = Scalar::all(1.0 / 255.0);
Scalar meanDB = Scalar(122.67891434, 116.66876762, 104.00698793);
// mean and stddev used by the PP-OCRv3 DB model
Scalar meanPPDB = Scalar(123.675, 116.28, 103.53);
Scalar stddevPPDB = Scalar(0.229, 0.224, 0.225);
Scalar scalePPDB = scaleDB / stddevPPDB;
float binThresh = 0.3;
float polyThresh = 0.5;
uint maxCandidates = 200;
double unclipRatio = 2.0;
testTextDetectionModelByDB(weightPath, "", imgPath, gt, binThresh, polyThresh, maxCandidates, unclipRatio, size, mean, scale);
{
SCOPED_TRACE("Original DB");
testTextDetectionModelByDB(weightPathDB, "", imgPath, gt, binThresh, polyThresh, maxCandidates, unclipRatio, size, meanDB, scaleDB, 0.05f);
}
{
SCOPED_TRACE("PP-OCRDBv3");
testTextDetectionModelByDB(weightPathPPDB, "", imgPath, gt, binThresh, polyThresh, maxCandidates, unclipRatio, size, meanPPDB, scalePPDB, 0.21f);
}
}
TEST_P(Test_Model, TextDetectionByEAST)
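A note on the PP-OCRv3 values in the test above: PP-OCR's detector preprocessing normalizes inputs as (img / 255 - m) / s with the ImageNet constants m = (0.485, 0.456, 0.406) and s = (0.229, 0.224, 0.225). OpenCV's blob conversion subtracts the mean first and then multiplies by the scalefactor, so the equivalent settings are mean = 255 * m = (123.675, 116.28, 103.53) and scalefactor = (1 / 255) / s, which is exactly meanPPDB and scalePPDB = scaleDB / stddevPPDB in the test:

    (v - 255 * m_c) * (1 / (255 * s_c)) == (v / 255 - m_c) / s_c    for each channel c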