Merge pull request #20524 from yichenj:dnn_text_recognition_enhance
This commit is contained in: 05d733e707
@@ -26,6 +26,11 @@ Before recognition, you should `setVocabulary` and `setDecodeType`.
  - `T` is the sequence length
  - `B` is the batch size (only `B=1` is supported in inference)
  - and `Dim` is the length of the vocabulary + 1 (the CTC 'Blank' is at index 0 of `Dim`).
- "CTC-prefix-beam-search", the output of the text recognition model should be the same probability matrix as for "CTC-greedy".
  - The algorithm is proposed in Hannun's [paper](https://arxiv.org/abs/1408.2873).
  - `setDecodeOptsCTCPrefixBeamSearch` can be used to control the beam size in the search step.
  - To further optimize for a big vocabulary, a new option `vocPruneSize` is introduced to avoid iterating over the whole vocabulary
    and instead consider only the `vocPruneSize` tokens with the highest probability.

@ref cv::dnn::TextRecognitionModel::recognize() is the main function for text recognition.
- The input image should be a cropped text image or an image with `roiRects`
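As an editor's illustration of the `(T, B, Dim)` contract described above (not part of the patch): the sketch below runs a recognizer directly through `cv::dnn::Net` and checks the output shape. The model path, input size, and normalization are assumptions borrowed from the CRNN sample used by the accompanying test.

```cpp
// Hedged sketch: verify a recognizer's raw output matches the (T, B, Dim) contract
// before handing it to TextRecognitionModel. "crnn.onnx" and the image path are placeholders.
#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>
#include <string>
#include <vector>

int main()
{
    // 36-symbol vocabulary (digits + lower-case letters), as in the accompanying test.
    std::vector<std::string> vocabulary;
    for (char c = '0'; c <= '9'; ++c) vocabulary.push_back(std::string(1, c));
    for (char c = 'a'; c <= 'z'; ++c) vocabulary.push_back(std::string(1, c));

    cv::dnn::Net net = cv::dnn::readNet("crnn.onnx");   // assumed model path
    cv::Mat img = cv::imread("text_rec_test.png", cv::IMREAD_GRAYSCALE);
    cv::Mat blob = cv::dnn::blobFromImage(img, 1.0 / 127.5, cv::Size(100, 32), cv::Scalar(127.5));
    net.setInput(blob);
    cv::Mat prob = net.forward();                        // expected shape: (T, B, Dim)

    CV_Assert(prob.dims == 3);
    CV_Assert(prob.size[1] == 1);                        // only B = 1 is supported
    CV_Assert(prob.size[2] == (int)vocabulary.size() + 1); // +1 for the CTC 'Blank' at index 0
    return 0;
}
```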
@@ -1373,7 +1373,9 @@ public:

    /**
     * @brief Set the decoding method of translating the network output into string
     * @param[in] decodeType The decoding method of translating the network output into string: {'CTC-greedy': greedy decoding for the output of CTC-based methods}
     * @param[in] decodeType The decoding method of translating the network output into string, currently supported types:
     *    - `"CTC-greedy"` greedy decoding for the output of CTC-based methods
     *    - `"CTC-prefix-beam-search"` Prefix beam search decoding for the output of CTC-based methods
     */
    CV_WRAP
    TextRecognitionModel& setDecodeType(const std::string& decodeType);

@@ -1385,6 +1387,15 @@ public:
    CV_WRAP
    const std::string& getDecodeType() const;

    /**
     * @brief Set the decoding method options for `"CTC-prefix-beam-search"` decode usage
     * @param[in] beamSize Beam size for search
     * @param[in] vocPruneSize Parameter to optimize big vocabulary search:
     *    only the top @p vocPruneSize tokens are taken in each search step; @p vocPruneSize <= 0 disables this pruning.
     */
    CV_WRAP
    TextRecognitionModel& setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize = 0);

    /**
     * @brief Set the vocabulary for recognition.
     * @param[in] vocabulary the associated vocabulary of the network.
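A minimal usage sketch of the new setters, under stated assumptions: a CRNN-style ONNX model file, the digits-plus-lowercase vocabulary from the test, and illustrative beam/prune values (this is not part of the patch).

```cpp
// Hedged sketch of configuring the "CTC-prefix-beam-search" decoder on a
// TextRecognitionModel; file names and option values are assumptions.
#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> vocabulary;
    for (char c = '0'; c <= '9'; ++c) vocabulary.push_back(std::string(1, c));
    for (char c = 'a'; c <= 'z'; ++c) vocabulary.push_back(std::string(1, c));

    cv::dnn::TextRecognitionModel model("crnn.onnx");   // assumed model path
    model.setVocabulary(vocabulary);
    model.setDecodeType("CTC-prefix-beam-search");
    // Keep 10 prefixes per step; inspect only the 20 most probable tokens per step
    // (vocPruneSize <= 0 would scan the whole vocabulary).
    model.setDecodeOptsCTCPrefixBeamSearch(10, 20);
    model.setInputParams(1.0 / 127.5, cv::Size(100, 32), cv::Scalar(127.5));

    std::string text = model.recognize(cv::imread("cropped_text.png", cv::IMREAD_GRAYSCALE));
    std::cout << text << std::endl;
    return 0;
}
```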
modules/dnn/src/math_utils.hpp (new file, 83 lines)
@@ -0,0 +1,83 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// Code is borrowed from https://github.com/kaldi-asr/kaldi/blob/master/src/base/kaldi-math.h

// base/kaldi-math.h

// Copyright 2009-2011  Ondrej Glembek;  Microsoft Corporation;  Yanmin Qian;
//                      Jan Silovsky;  Saarland University
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef __OPENCV_DNN_MATH_UTILS_HPP__
#define __OPENCV_DNN_MATH_UTILS_HPP__

#ifdef OS_QNX
#include <math.h>
#else
#include <cmath>
#endif

#include <limits>

#ifndef FLT_EPSILON
#define FLT_EPSILON 1.19209290e-7f
#endif

namespace cv { namespace dnn {

const float kNegativeInfinity = -std::numeric_limits<float>::infinity();

const float kMinLogDiffFloat = std::log(FLT_EPSILON);

#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
inline float Log1p(float x) { return log1pf(x); }
#else
inline float Log1p(float x) {
    const float cutoff = 1.0e-07;
    if (x < cutoff)
        return x - 2 * x * x;
    else
        return logf(1.0f + x);  // log1pf is unavailable on this toolchain; fall back to plain logf
}
#endif

inline float Exp(float x) { return expf(x); }

inline float LogAdd(float x, float y) {
    float diff;
    if (x < y) {
        diff = x - y;
        x = y;
    } else {
        diff = y - x;
    }
    // diff is negative. x is now the larger one.

    if (diff >= kMinLogDiffFloat) {
        float res;
        res = x + Log1p(Exp(diff));
        return res;
    } else {
        return x;  // return the larger one.
    }
}

}}  // namespace

#endif  // __OPENCV_DNN_MATH_UTILS_HPP__
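For intuition (editor's note, not part of the patch): `LogAdd` is the usual log-sum-exp trick, log(e^x + e^y) = max(x, y) + log1p(e^(-|x - y|)); once the gap exceeds |log(FLT_EPSILON)| the correction term no longer affects a float result, which is what the `kMinLogDiffFloat` branch exploits. A standalone check of the identity, assuming only the C++ standard library:

```cpp
// Minimal standalone check of the identity implemented by LogAdd above.
#include <algorithm>
#include <cmath>
#include <cstdio>

static float logAddRef(float x, float y)
{
    float hi = std::max(x, y);
    float lo = std::min(x, y);
    return hi + std::log1p(std::exp(lo - hi));   // log(e^x + e^y), overflow-free
}

int main()
{
    float a = std::log(0.2f), b = std::log(0.3f);
    // Both values should print as log(0.5) ~= -0.693147
    std::printf("%f  %f\n", logAddRef(a, b), std::log(0.5f));
    return 0;
}
```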
@@ -3,8 +3,10 @@
// of this distribution and at http://opencv.org/license.html.

#include "precomp.hpp"
#include "math_utils.hpp"
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <iterator>

#include <opencv2/imgproc.hpp>

@@ -552,6 +554,9 @@ struct TextRecognitionModel_Impl : public Model::Impl
    std::string decodeType;
    std::vector<std::string> vocabulary;

    int beamSize = 10;
    int vocPruneSize = 0;

    TextRecognitionModel_Impl()
    {
        CV_TRACE_FUNCTION();

@@ -575,6 +580,13 @@ struct TextRecognitionModel_Impl : public Model::Impl
        decodeType = type;
    }

    inline
    void setDecodeOptsCTCPrefixBeamSearch(int beam, int vocPrune)
    {
        beamSize = beam;
        vocPruneSize = vocPrune;
    }

    virtual
    std::string decode(const Mat& prediction)
    {

@@ -586,8 +598,23 @@ struct TextRecognitionModel_Impl : public Model::Impl
            CV_Error(Error::StsBadArg, "TextRecognitionModel: vocabulary is not specified");

        std::string decodeSeq;
        if (decodeType == "CTC-greedy")
        if (decodeType == "CTC-greedy") {
            decodeSeq = ctcGreedyDecode(prediction);
        } else if (decodeType == "CTC-prefix-beam-search") {
            decodeSeq = ctcPrefixBeamSearchDecode(prediction);
        } else if (decodeType.length() == 0) {
            CV_Error(Error::StsBadArg, "Please set decodeType");
        } else {
            CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str()));
        }

        return decodeSeq;
    }

    virtual
    std::string ctcGreedyDecode(const Mat& prediction)
    {
        std::string decodeSeq;
        CV_CheckEQ(prediction.dims, 3, "");
        CV_CheckType(prediction.type(), CV_32FC1, "");
        const int vocLength = (int)(vocabulary.size());
@@ -624,12 +651,157 @@ struct TextRecognitionModel_Impl : public Model::Impl
                ctcFlag = true;
            }
        }
        } else if (decodeType.length() == 0) {
            CV_Error(Error::StsBadArg, "Please set decodeType");
        } else {
            CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str()));
        return decodeSeq;
    }

    struct PrefixScore
    {
        // blank ending score
        float pB;
        // none blank ending score
        float pNB;

        PrefixScore() : pB(kNegativeInfinity), pNB(kNegativeInfinity)
        {
        }
        PrefixScore(float pB, float pNB) : pB(pB), pNB(pNB)
        {
        }
    };

    struct PrefixHash
    {
        size_t operator()(const std::vector<int>& prefix) const
        {
            // BKDR hash
            unsigned int seed = 131;
            size_t hash = 0;
            for (size_t i = 0; i < prefix.size(); i++)
            {
                hash = hash * seed + prefix[i];
            }
            return hash;
        }
    };

    static
    std::vector<std::pair<float, int>> TopK(
        const float* predictions, int length, int k)
    {
        std::vector<std::pair<float, int>> results;
        // No prune.
        if (k <= 0)
        {
            for (int i = 0; i < length; ++i)
            {
                results.emplace_back(predictions[i], i);
            }
            return results;
        }

        for (int i = 0; i < k; ++i)
        {
            results.emplace_back(predictions[i], i);
        }
        std::make_heap(results.begin(), results.end(), std::greater<std::pair<float, int>>{});

        for (int i = k; i < length; ++i)
        {
            if (predictions[i] > results.front().first)
            {
                std::pop_heap(results.begin(), results.end(), std::greater<std::pair<float, int>>{});
                results.pop_back();
                results.emplace_back(predictions[i], i);
                std::push_heap(results.begin(), results.end(), std::greater<std::pair<float, int>>{});
            }
        }
        return results;
    }

    static inline
    bool PrefixScoreCompare(
        const std::pair<std::vector<int>, PrefixScore>& a,
        const std::pair<std::vector<int>, PrefixScore>& b)
    {
        float probA = LogAdd(a.second.pB, a.second.pNB);
        float probB = LogAdd(b.second.pB, b.second.pNB);
        return probA > probB;
    }

    virtual
    std::string ctcPrefixBeamSearchDecode(const Mat& prediction) {
        // CTC prefix beam search decode.
        // For more detail, refer to:
        // https://distill.pub/2017/ctc/#inference
        // https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
        using Beam = std::vector<std::pair<std::vector<int>, PrefixScore>>;
        using BeamInDict = std::unordered_map<std::vector<int>, PrefixScore, PrefixHash>;

        CV_CheckType(prediction.type(), CV_32FC1, "");
        CV_CheckEQ(prediction.dims, 3, "");
        CV_CheckEQ(prediction.size[1], 1, "");
        CV_CheckEQ(prediction.size[2], (int)vocabulary.size() + 1, ""); // Length add 1 for ctc blank

        std::string decodeSeq;
        Beam beam = {std::make_pair(std::vector<int>(), PrefixScore(0.0, kNegativeInfinity))};
        for (int i = 0; i < prediction.size[0]; i++)
        {
            // Loop over time
            BeamInDict nextBeam;
            const float* pred = prediction.ptr<float>(i);
            std::vector<std::pair<float, int>> topkPreds =
                TopK(pred, vocabulary.size() + 1, vocPruneSize);
            for (const auto& each : topkPreds)
            {
                // Loop over vocabulary
                float prob = each.first;
                int token = each.second;
                for (const auto& it : beam)
                {
                    const std::vector<int>& prefix = it.first;
                    const PrefixScore& prefixScore = it.second;
                    if (token == 0) // 0 stands for ctc blank
                    {
                        PrefixScore& nextScore = nextBeam[prefix];
                        nextScore.pB = LogAdd(nextScore.pB,
                            LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob));
                        continue;
                    }

                    std::vector<int> nPrefix(prefix);
                    nPrefix.push_back(token);
                    PrefixScore& nextScore = nextBeam[nPrefix];
                    if (prefix.size() > 0 && token == prefix.back())
                    {
                        nextScore.pNB = LogAdd(nextScore.pNB, prefixScore.pB + prob);
                        PrefixScore& mScore = nextBeam[prefix];
                        mScore.pNB = LogAdd(mScore.pNB, prefixScore.pNB + prob);
                    }
                    else
                    {
                        nextScore.pNB = LogAdd(nextScore.pNB,
                            LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob));
                    }
                }
            }
            // Beam prune
            Beam newBeam(nextBeam.begin(), nextBeam.end());
            int newBeamSize = std::min(static_cast<int>(newBeam.size()), beamSize);
            std::nth_element(newBeam.begin(), newBeam.begin() + newBeamSize,
                             newBeam.end(), PrefixScoreCompare);
            newBeam.resize(newBeamSize);
            std::sort(newBeam.begin(), newBeam.end(), PrefixScoreCompare);
            beam = std::move(newBeam);
        }

        CV_Assert(!beam.empty());
        for (int token : beam[0].first)
        {
            CV_Check(token, token > 0 && token <= vocabulary.size(), "");
            decodeSeq += vocabulary.at(token - 1);
        }
        return decodeSeq;
    }
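As a reading aid (editor's restatement of the code above, not part of the patch): all quantities are log-probabilities, $\oplus$ denotes `LogAdd`, $p_b(\ell, t)$ and $p_{nb}(\ell, t)$ are the scores of prefix $\ell$ ending in blank / non-blank at time $t$, and $p_t(c)$ is the network output for token $c$. For each surviving prefix $\ell$ and candidate token $c$, the per-step updates are:

```latex
\begin{align*}
\text{blank } (c=\varnothing):\;
  & p_b(\ell,t) \mathrel{\oplus}= \big(p_b(\ell,t-1) + p_t(\varnothing)\big)
      \oplus \big(p_{nb}(\ell,t-1) + p_t(\varnothing)\big) \\
\text{repeated last symbol } (c=\ell_{\text{end}}):\;
  & p_{nb}(\ell,t) \mathrel{\oplus}= p_{nb}(\ell,t-1) + p_t(c), \quad
    p_{nb}(\ell{+}c,t) \mathrel{\oplus}= p_b(\ell,t-1) + p_t(c) \\
\text{otherwise}:\;
  & p_{nb}(\ell{+}c,t) \mathrel{\oplus}= \big(p_b(\ell,t-1) + p_t(c)\big)
      \oplus \big(p_{nb}(\ell,t-1) + p_t(c)\big)
\end{align*}
```

Prefixes are then ranked by $p_b \oplus p_{nb}$ (see `PrefixScoreCompare`), only the `beamSize` best are kept, and `TopK` restricts the candidate tokens $c$ at each step to the `vocPruneSize` most probable ones.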
@@ -698,6 +870,12 @@ const std::string& TextRecognitionModel::getDecodeType() const
    return TextRecognitionModel_Impl::from(impl).decodeType;
}

TextRecognitionModel& TextRecognitionModel::setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize)
{
    TextRecognitionModel_Impl::from(impl).setDecodeOptsCTCPrefixBeamSearch(beamSize, vocPruneSize);
    return *this;
}

TextRecognitionModel& TextRecognitionModel::setVocabulary(const std::vector<std::string>& inputVoc)
{
    TextRecognitionModel_Impl::from(impl).setVocabulary(inputVoc);
@@ -615,6 +615,25 @@ TEST_P(Test_Model, TextRecognition)
    testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale);
}

TEST_P(Test_Model, TextRecognitionWithCTCPrefixBeamSearch)
{
    if (target == DNN_TARGET_OPENCL_FP16)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);

    std::string imgPath = _tf("text_rec_test.png");
    std::string weightPath = _tf("onnx/models/crnn.onnx", false);
    std::string seq = "welcome";

    Size size{100, 32};
    double scale = 1.0 / 127.5;
    Scalar mean = Scalar(127.5);
    std::string decodeType = "CTC-prefix-beam-search";
    std::vector<std::string> vocabulary = {"0","1","2","3","4","5","6","7","8","9",
        "a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"};

    testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale);
}

TEST_P(Test_Model, TextDetectionByDB)
{
    if (target == DNN_TARGET_OPENCL_FP16)