Mirror of https://github.com/tesseract-ocr/tesseract.git (synced 2024-11-27 20:59:36 +08:00)

Commit 9710bc0465 (parent 154ea6bab8): More std::vector.
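This commit continues the migration from Tesseract's own GenericVector container to std::vector. A minimal sketch of the recurring call mappings the diff applies (illustrative only, not code from the commit):

// Illustrative only: how the GenericVector calls touched by this commit
// map onto std::vector.
#include <algorithm>
#include <vector>

void migration_examples() {
  std::vector<int> v{1, 2, 3};
  v.resize(0);     // was v.truncate(0): drop all elements
  v.resize(5, 0);  // was v.init_to_size(5, 0): size 5, new slots filled with 0

  std::vector<int> other{7, 8};
  v.insert(v.end(), other.begin(), other.end());  // was v += other
  std::reverse(v.begin(), v.end());               // was v.reverse()
}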
@@ -221,7 +221,7 @@ void Tesseract::ambigs_classify_and_output(const char* label,
ASSERT_HOST(best_choice != nullptr);

// Compute the number of unichars in the label.
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
if (!unicharset.encode_string(label, true, &encoding, nullptr, nullptr)) {
tprintf("Not outputting illegal unichar %s\n", label);
return;

@@ -78,8 +78,8 @@ void BlamerBundle::SetWordTruth(const UNICHARSET& unicharset,
truth_word_.InsertBox(0, word_box);
truth_has_char_boxes_ = false;
// Encode the string as UNICHAR_IDs.
- GenericVector<UNICHAR_ID> encoding;
- GenericVector<char> lengths;
+ std::vector<UNICHAR_ID> encoding;
+ std::vector<char> lengths;
unicharset.encode_string(truth_str, false, &encoding, &lengths, nullptr);
int total_length = 0;
for (int i = 0; i < encoding.size(); total_length += lengths[i++]) {

@@ -217,8 +217,8 @@ const char *ScriptPosToString(enum ScriptPos script_pos) {
WERD_CHOICE::WERD_CHOICE(const char *src_string,
const UNICHARSET &unicharset)
: unicharset_(&unicharset){
- GenericVector<UNICHAR_ID> encoding;
- GenericVector<char> lengths;
+ std::vector<UNICHAR_ID> encoding;
+ std::vector<char> lengths;
std::string cleaned = unicharset.CleanupString(src_string);
if (unicharset.encode_string(cleaned.c_str(), true, &encoding, &lengths,
nullptr)) {

@@ -130,7 +130,7 @@ void UnicharAmbigs::LoadUnicharAmbigs(const UNICHARSET& encoder_set,
}
// Update ambigs_for_adaption_.
if (use_ambigs_for_adaption) {
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
// Silently ignore invalid strings, as before, so it is safe to use a
// universal ambigs file.
if (unicharset->encode_string(replacement_string, true, &encoding,

@@ -235,7 +235,7 @@ bool UnicharAmbigs::ParseAmbiguityLine(
return false;
}
// Encode wrong-string.
- GenericVector<UNICHAR_ID> unichars;
+ std::vector<UNICHAR_ID> unichars;
if (!unicharset.encode_string(fields[0].c_str(), true, &unichars, nullptr,
nullptr)) {
return false;
@@ -98,6 +98,34 @@ class TESS_API TFile {
return FWrite(data, sizeof(T), count) == count;
}

+ template <typename T>
+ bool Serialize(const std::vector<T>& data) {
+ auto size_used_ = data.size();
+ if (FWrite(&size_used_, sizeof(size_used_), 1) != 1) {
+ return false;
+ }
+ if (FWrite(data.data(), sizeof(T), size_used_) != size_used_) {
+ return false;
+ }
+ return true;
+ }
+
+ template <typename T>
+ bool DeSerialize(std::vector<T>& data) {
+ uint32_t reserved;
+ if (FReadEndian(&reserved, sizeof(reserved), 1) != 1) {
+ return false;
+ }
+ // Arbitrarily limit the number of elements to protect against bad data.
+ const uint32_t limit = 50000000;
+ //assert(reserved <= limit);
+ if (reserved > limit) {
+ return false;
+ }
+ data.reserve(reserved);
+ return FReadEndian(data.data(), sizeof(T), reserved) == reserved;
+ }
+
// Skip data.
bool Skip(size_t count);

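The TFile additions above write an element count followed by the raw element bytes, and read them back in the same order. A standalone sketch of that layout using plain stdio (assumptions of this sketch: a fixed 32-bit count, no endian conversion, and trivially copyable element types; it is not TFile's implementation):

// Sketch only: count-then-elements layout, analogous to the new
// TFile::Serialize / DeSerialize templates above.
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
bool WriteVector(std::FILE* fp, const std::vector<T>& data) {
  uint32_t count = static_cast<uint32_t>(data.size());
  if (std::fwrite(&count, sizeof(count), 1, fp) != 1) return false;
  return std::fwrite(data.data(), sizeof(T), count, fp) == count;
}

template <typename T>
bool ReadVector(std::FILE* fp, std::vector<T>* data) {
  uint32_t count = 0;
  if (std::fread(&count, sizeof(count), 1, fp) != 1) return false;
  data->resize(count);  // allocate before reading into data()
  return std::fread(data->data(), sizeof(T), count, fp) == count;
}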
@@ -212,8 +212,8 @@ UNICHAR_ID UNICHARSET::unichar_to_id(const char* const unichar_repr,
// WARNING: this function now encodes the whole string for precision.
// Use encode_string in preference to repeatedly calling step.
int UNICHARSET::step(const char* str) const {
- GenericVector<UNICHAR_ID> encoding;
- GenericVector<char> lengths;
+ std::vector<UNICHAR_ID> encoding;
+ std::vector<char> lengths;
encode_string(str, true, &encoding, &lengths, nullptr);
if (encoding.empty() || encoding[0] == INVALID_UNICHAR_ID) return 0;
return lengths[0];

@@ -224,7 +224,7 @@ int UNICHARSET::step(const char* str) const {
// into the second (return) argument.
bool UNICHARSET::encodable_string(const char *str,
int *first_bad_position) const {
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
return encode_string(str, true, &encoding, nullptr, first_bad_position);
}

@@ -238,13 +238,13 @@ bool UNICHARSET::encodable_string(const char *str,
// that do not belong in the unicharset, or encoding may fail.
// Use CleanupString to perform the cleaning.
bool UNICHARSET::encode_string(const char* str, bool give_up_on_failure,
- GenericVector<UNICHAR_ID>* encoding,
- GenericVector<char>* lengths,
+ std::vector<UNICHAR_ID>* encoding,
+ std::vector<char>* lengths,
int* encoded_length) const {
- GenericVector<UNICHAR_ID> working_encoding;
- GenericVector<char> working_lengths;
- GenericVector<char> best_lengths;
- encoding->truncate(0); // Just in case str is empty.
+ std::vector<UNICHAR_ID> working_encoding;
+ std::vector<char> working_lengths;
+ std::vector<char> best_lengths;
+ encoding->resize(0); // Just in case str is empty.
int str_length = strlen(str);
int str_pos = 0;
bool perfect = true;

@@ -352,13 +352,13 @@ STRING UNICHARSET::debug_str(UNICHAR_ID id) const {
// Sets the normed_ids vector from the normed string. normed_ids is not
// stored in the file, and needs to be set when the UNICHARSET is loaded.
void UNICHARSET::set_normed_ids(UNICHAR_ID unichar_id) {
- unichars[unichar_id].properties.normed_ids.truncate(0);
+ unichars[unichar_id].properties.normed_ids.resize(0);
if (unichar_id == UNICHAR_SPACE && id_to_unichar(unichar_id)[0] == ' ') {
unichars[unichar_id].properties.normed_ids.push_back(UNICHAR_SPACE);
} else if (!encode_string(unichars[unichar_id].properties.normed.c_str(),
true, &unichars[unichar_id].properties.normed_ids,
nullptr, nullptr)) {
- unichars[unichar_id].properties.normed_ids.truncate(0);
+ unichars[unichar_id].properties.normed_ids.resize(0);
unichars[unichar_id].properties.normed_ids.push_back(unichar_id);
}
}

@@ -481,11 +481,11 @@ bool UNICHARSET::SizesDistinct(UNICHAR_ID id1, UNICHAR_ID id2) const {
// the overall process of encoding a partially failed string more efficient.
// See unicharset.h for definition of the args.
void UNICHARSET::encode_string(const char* str, int str_index, int str_length,
- GenericVector<UNICHAR_ID>* encoding,
- GenericVector<char>* lengths,
+ std::vector<UNICHAR_ID>* encoding,
+ std::vector<char>* lengths,
int* best_total_length,
- GenericVector<UNICHAR_ID>* best_encoding,
- GenericVector<char>* best_lengths) const {
+ std::vector<UNICHAR_ID>* best_encoding,
+ std::vector<char>* best_lengths) const {
if (str_index > *best_total_length) {
// This is the best result so far.
*best_total_length = str_index;

@@ -509,8 +509,8 @@ void UNICHARSET::encode_string(const char* str, int str_index, int str_length,
if (*best_total_length == str_length)
return; // Tail recursion success!
// Failed with that length, truncate back and try again.
- encoding->truncate(encoding_index);
- lengths->truncate(encoding_index);
+ encoding->resize(encoding_index);
+ lengths->resize(encoding_index);
}
int step = UNICHAR::utf8_step(str + str_index + length);
if (step == 0) step = 1;

@@ -528,7 +528,7 @@ bool UNICHARSET::GetStrProperties(const char* utf8_str,
props->Init();
props->SetRangesEmpty();
int total_unicodes = 0;
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
if (!encode_string(utf8_str, true, &encoding, nullptr, nullptr))
return false; // Some part was invalid.
for (int i = 0; i < encoding.size(); ++i) {

@@ -611,7 +611,7 @@ void UNICHARSET::unichar_insert(const char* const unichar_repr,
old_style_included_ ? unichar_repr : CleanupString(unichar_repr);
if (!cleaned.empty() && !ids.contains(cleaned.data(), cleaned.size())) {
const char* str = cleaned.c_str();
- GenericVector<int> encoding;
+ std::vector<int> encoding;
if (!old_style_included_ &&
encode_string(str, true, &encoding, nullptr, nullptr))
return;

@@ -950,7 +950,7 @@ void UNICHARSET::set_black_and_whitelist(const char* blacklist,
unichars[ch].properties.enabled = def_enabled;
if (!def_enabled) {
// Enable the whitelist.
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
encode_string(whitelist, false, &encoding, nullptr, nullptr);
for (int i = 0; i < encoding.size(); ++i) {
if (encoding[i] != INVALID_UNICHAR_ID)

@@ -959,7 +959,7 @@ void UNICHARSET::set_black_and_whitelist(const char* blacklist,
}
if (blacklist != nullptr && blacklist[0] != '\0') {
// Disable the blacklist.
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
encode_string(blacklist, false, &encoding, nullptr, nullptr);
for (int i = 0; i < encoding.size(); ++i) {
if (encoding[i] != INVALID_UNICHAR_ID)

@@ -968,7 +968,7 @@ void UNICHARSET::set_black_and_whitelist(const char* blacklist,
}
if (unblacklist != nullptr && unblacklist[0] != '\0') {
// Re-enable the unblacklist.
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
encode_string(unblacklist, false, &encoding, nullptr, nullptr);
for (int i = 0; i < encoding.size(); ++i) {
if (encoding[i] != INVALID_UNICHAR_ID)
@@ -227,8 +227,8 @@ class TESS_API UNICHARSET {
// that do not belong in the unicharset, or encoding may fail.
// Use CleanupString to perform the cleaning.
bool encode_string(const char* str, bool give_up_on_failure,
- GenericVector<UNICHAR_ID>* encoding,
- GenericVector<char>* lengths,
+ std::vector<UNICHAR_ID>* encoding,
+ std::vector<char>* lengths,
int* encoded_length) const;

// Return the unichar representation corresponding to the given UNICHAR_ID

@@ -467,7 +467,7 @@ class TESS_API UNICHARSET {
// Record normalized version of unichar with the given unichar_id.
void set_normed(UNICHAR_ID unichar_id, const char* normed) {
unichars[unichar_id].properties.normed = normed;
- unichars[unichar_id].properties.normed_ids.truncate(0);
+ unichars[unichar_id].properties.normed_ids.resize(0);
}
// Sets the normed_ids vector from the normed string. normed_ids is not
// stored in the file, and needs to be set when the UNICHARSET is loaded.

@@ -818,7 +818,7 @@ class TESS_API UNICHARSET {
// Returns a vector of UNICHAR_IDs that represent the ids of the normalized
// version of the given id. There may be more than one UNICHAR_ID in the
// vector if unichar_id represents a ligature.
- const GenericVector<UNICHAR_ID>& normed_ids(UNICHAR_ID unichar_id) const {
+ const std::vector<UNICHAR_ID>& normed_ids(UNICHAR_ID unichar_id) const {
return unichars[unichar_id].properties.normed_ids;
}

@@ -946,7 +946,7 @@ class TESS_API UNICHARSET {
// A string of unichar_ids that represent the corresponding normed string.
// For awkward characters like em-dash, this gives hyphen.
// For ligatures, this gives the string of normal unichars.
- GenericVector<UNICHAR_ID> normed_ids;
+ std::vector<UNICHAR_ID> normed_ids;
STRING normed; // normalized version of this unichar
// Contains meta information about the fragment if a unichar represents
// a fragment of a character, otherwise should be set to nullptr.

@@ -972,11 +972,11 @@ class TESS_API UNICHARSET {
// best_encoding contains the encoding that used the longest part of str.
// best_lengths (may be null) contains the lengths of best_encoding.
void encode_string(const char* str, int str_index, int str_length,
- GenericVector<UNICHAR_ID>* encoding,
- GenericVector<char>* lengths,
+ std::vector<UNICHAR_ID>* encoding,
+ std::vector<char>* lengths,
int* best_total_length,
- GenericVector<UNICHAR_ID>* best_encoding,
- GenericVector<char>* best_lengths) const;
+ std::vector<UNICHAR_ID>* best_encoding,
+ std::vector<char>* best_lengths) const;

// Gets the properties for a grapheme string, combining properties for
// multiple characters in a meaningful way where possible.
@@ -824,24 +824,24 @@ bool Dict::valid_bigram(const WERD_CHOICE& word1,
if (w2start >= w2end) return word2.length() < 3;

const UNICHARSET& uchset = getUnicharset();
- GenericVector<UNICHAR_ID> bigram_string;
+ std::vector<UNICHAR_ID> bigram_string;
bigram_string.reserve(w1end + w2end + 1);
for (int i = w1start; i < w1end; i++) {
- const GenericVector<UNICHAR_ID>& normed_ids =
+ const auto &normed_ids =
getUnicharset().normed_ids(word1.unichar_id(i));
if (normed_ids.size() == 1 && uchset.get_isdigit(normed_ids[0]))
bigram_string.push_back(question_unichar_id_);
else
- bigram_string += normed_ids;
+ bigram_string.insert(bigram_string.end(), normed_ids.begin(), normed_ids.end());
}
bigram_string.push_back(UNICHAR_SPACE);
for (int i = w2start; i < w2end; i++) {
- const GenericVector<UNICHAR_ID>& normed_ids =
+ const auto &normed_ids =
getUnicharset().normed_ids(word2.unichar_id(i));
if (normed_ids.size() == 1 && uchset.get_isdigit(normed_ids[0]))
bigram_string.push_back(question_unichar_id_);
else
- bigram_string += normed_ids;
+ bigram_string.insert(bigram_string.end(), normed_ids.begin(), normed_ids.end());
}
WERD_CHOICE normalized_word(&uchset, bigram_string.size());
for (int i = 0; i < bigram_string.size(); ++i) {

@@ -116,7 +116,7 @@ class TESS_API Dict {
inline bool compound_marker(UNICHAR_ID unichar_id) {
const UNICHARSET& unicharset = getUnicharset();
ASSERT_HOST(unicharset.contains_unichar_id(unichar_id));
- const GenericVector<UNICHAR_ID>& normed_ids =
+ const auto &normed_ids =
unicharset.normed_ids(unichar_id);
return normed_ids.size() == 1 &&
(normed_ids[0] == hyphen_unichar_id_ ||

@@ -127,7 +127,7 @@ class TESS_API Dict {
inline bool is_apostrophe(UNICHAR_ID unichar_id) {
const UNICHARSET& unicharset = getUnicharset();
ASSERT_HOST(unicharset.contains_unichar_id(unichar_id));
- const GenericVector<UNICHAR_ID>& normed_ids =
+ const auto &normed_ids =
unicharset.normed_ids(unichar_id);
return normed_ids.size() == 1 && normed_ids[0] == apostrophe_unichar_id_;
}

@@ -157,7 +157,7 @@ class TESS_API Dict {
if (!last_word_on_line_ || first_pos)
return false;
ASSERT_HOST(unicharset->contains_unichar_id(unichar_id));
- const GenericVector<UNICHAR_ID>& normed_ids =
+ const auto &normed_ids =
unicharset->normed_ids(unichar_id);
return normed_ids.size() == 1 && normed_ids[0] == hyphen_unichar_id_;
}
@@ -62,7 +62,7 @@ void Dict::go_deeper_dawg_fxn(
}
int num_unigrams = 0;
word->remove_last_unichar_id();
- GenericVector<UNICHAR_ID> encoding;
+ std::vector<UNICHAR_ID> encoding;
const char *ngram_str = getUnicharset().id_to_unichar(orig_uch_id);
// Since the string came out of the unicharset, failure is impossible.
ASSERT_HOST(getUnicharset().encode_string(ngram_str, true, &encoding, nullptr,

@@ -319,7 +319,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
}
pixDestroy(&pix);
if (debug) {
- GenericVector<int> labels, coords;
+ std::vector<int> labels, coords;
LabelsFromOutputs(*outputs, &labels, &coords);
#ifndef GRAPHICS_DISABLED
DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);

@@ -331,7 +331,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,

// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
- STRING LSTMRecognizer::DecodeLabels(const GenericVector<int>& labels) {
+ STRING LSTMRecognizer::DecodeLabels(const std::vector<int>& labels) {
STRING result;
int end = 1;
for (int start = 0; start < labels.size(); start = end) {

@@ -349,8 +349,8 @@ STRING LSTMRecognizer::DecodeLabels(const GenericVector<int>& labels) {
// Displays the forward results in a window with the characters and
// boundaries as determined by the labels and label_coords.
void LSTMRecognizer::DisplayForward(const NetworkIO& inputs,
- const GenericVector<int>& labels,
- const GenericVector<int>& label_coords,
+ const std::vector<int>& labels,
+ const std::vector<int>& label_coords,
const char* window_name,
ScrollView** window) {
Pix* input_pix = inputs.ToPix();

@@ -362,8 +362,8 @@ void LSTMRecognizer::DisplayForward(const NetworkIO& inputs,

// Displays the labels and cuts at the corresponding xcoords.
// Size of labels should match xcoords.
- void LSTMRecognizer::DisplayLSTMOutput(const GenericVector<int>& labels,
- const GenericVector<int>& xcoords,
+ void LSTMRecognizer::DisplayLSTMOutput(const std::vector<int>& labels,
+ const std::vector<int>& xcoords,
int height, ScrollView* window) {
int x_scale = network_->XScaleFactor();
window->TextAttributes("Arial", height / 4, false, false, false);

@@ -390,8 +390,8 @@ void LSTMRecognizer::DisplayLSTMOutput(const GenericVector<int>& labels,
// Prints debug output detailing the activation path that is implied by the
// label_coords.
void LSTMRecognizer::DebugActivationPath(const NetworkIO& outputs,
- const GenericVector<int>& labels,
- const GenericVector<int>& xcoords) {
+ const std::vector<int>& labels,
+ const std::vector<int>& xcoords) {
if (xcoords[0] > 0)
DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
int end = 1;

@@ -460,8 +460,8 @@ static bool NullIsBest(const NetworkIO& output, float null_thr,
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs,
- GenericVector<int>* labels,
- GenericVector<int>* xcoords) {
+ std::vector<int>* labels,
+ std::vector<int>* xcoords) {
if (SimpleTextOutput()) {
LabelsViaSimpleText(outputs, labels, xcoords);
} else {

@@ -472,8 +472,8 @@ void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs,
// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for CJK.
void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output,
- GenericVector<int>* labels,
- GenericVector<int>* xcoords) {
+ std::vector<int>* labels,
+ std::vector<int>* xcoords) {
if (search_ == nullptr) {
search_ =
new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);

@@ -486,10 +486,10 @@ void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output,
// the simple character model (each position is a char, and the null_char_ is
// mainly intended for tail padding.)
void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output,
- GenericVector<int>* labels,
- GenericVector<int>* xcoords) {
- labels->truncate(0);
- xcoords->truncate(0);
+ std::vector<int>* labels,
+ std::vector<int>* xcoords) {
+ labels->resize(0);
+ xcoords->resize(0);
const int width = output.Width();
for (int t = 0; t < width; ++t) {
float score = 0.0f;

@@ -504,7 +504,7 @@ void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output,

// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
- const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels,
+ const char* LSTMRecognizer::DecodeLabel(const std::vector<int>& labels,
int start, int* end, int* decoded) {
*end = start + 1;
if (IsRecoding()) {
@@ -76,10 +76,10 @@ class TESS_API LSTMRecognizer {
bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
// Returns a vector of layer ids that can be passed to other layer functions
// to access a specific layer.
- GenericVector<STRING> EnumerateLayers() const {
+ std::vector<STRING> EnumerateLayers() const {
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
auto* series = static_cast<Series*>(network_);
- GenericVector<STRING> layers;
+ std::vector<STRING> layers;
series->EnumerateLayers(nullptr, &layers);
return layers;
}

@@ -106,7 +106,7 @@ class TESS_API LSTMRecognizer {
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
learning_rate_ *= factor;
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
- GenericVector<STRING> layers = EnumerateLayers();
+ std::vector<STRING> layers = EnumerateLayers();
for (int i = 0; i < layers.size(); ++i) {
ScaleLayerLearningRate(layers[i], factor);
}

@@ -198,19 +198,19 @@ class TESS_API LSTMRecognizer {

// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
- STRING DecodeLabels(const GenericVector<int>& labels);
+ STRING DecodeLabels(const std::vector<int>& labels);

// Displays the forward results in a window with the characters and
// boundaries as determined by the labels and label_coords.
- void DisplayForward(const NetworkIO& inputs, const GenericVector<int>& labels,
- const GenericVector<int>& label_coords,
+ void DisplayForward(const NetworkIO& inputs, const std::vector<int>& labels,
+ const std::vector<int>& label_coords,
const char* window_name, ScrollView** window);
// Converts the network output to a sequence of labels. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
- void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
- GenericVector<int>* xcoords);
+ void LabelsFromOutputs(const NetworkIO& outputs, std::vector<int>* labels,
+ std::vector<int>* xcoords);

protected:
// Sets the random seed from the sample_iteration_;

@@ -222,15 +222,15 @@ class TESS_API LSTMRecognizer {

// Displays the labels and cuts at the corresponding xcoords.
// Size of labels should match xcoords.
- void DisplayLSTMOutput(const GenericVector<int>& labels,
- const GenericVector<int>& xcoords, int height,
+ void DisplayLSTMOutput(const std::vector<int>& labels,
+ const std::vector<int>& xcoords, int height,
ScrollView* window);

// Prints debug output detailing the activation path that is implied by the
// xcoords.
void DebugActivationPath(const NetworkIO& outputs,
- const GenericVector<int>& labels,
- const GenericVector<int>& xcoords);
+ const std::vector<int>& labels,
+ const std::vector<int>& xcoords);

// Prints debug output detailing activations and 2nd choice over a range
// of positions.

@@ -239,17 +239,17 @@ class TESS_API LSTMRecognizer {

// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for recoder_.
- void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
- GenericVector<int>* xcoords);
+ void LabelsViaReEncode(const NetworkIO& output, std::vector<int>* labels,
+ std::vector<int>* xcoords);
// Converts the network output to a sequence of labels, with scores, using
// the simple character model (each position is a char, and the null_char_ is
// mainly intended for tail padding.)
- void LabelsViaSimpleText(const NetworkIO& output, GenericVector<int>* labels,
- GenericVector<int>* xcoords);
+ void LabelsViaSimpleText(const NetworkIO& output, std::vector<int>* labels,
+ std::vector<int>* xcoords);

// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
- const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
+ const char* DecodeLabel(const std::vector<int>& labels, int start, int* end,
int* decoded);

// Returns a string corresponding to a given single label id, falling back to
@@ -137,7 +137,7 @@ void Plumbing::DebugWeights() {

// Returns a set of strings representing the layer-ids of all layers below.
void Plumbing::EnumerateLayers(const STRING* prefix,
- GenericVector<STRING>* layers) const {
+ std::vector<STRING>* layers) const {
for (int i = 0; i < stack_.size(); ++i) {
STRING layer_name;
if (prefix) layer_name = *prefix;

@@ -98,7 +98,7 @@ class Plumbing : public Network {
// Returns a set of strings representing the layer-ids of all layers below.
TESS_API
void EnumerateLayers(const STRING* prefix,
- GenericVector<STRING>* layers) const;
+ std::vector<STRING>* layers) const;
// Returns a pointer to the network layer corresponding to the given id.
TESS_API
Network* GetLayer(const char* id) const;
@@ -191,9 +191,9 @@ void RecodeBeamSearch::calculateCharBoundaries(std::vector<int>* starts,

// Returns the best path as labels/scores/xcoords similar to simple CTC.
void RecodeBeamSearch::ExtractBestPathAsLabels(
- GenericVector<int>* labels, GenericVector<int>* xcoords) const {
- labels->truncate(0);
- xcoords->truncate(0);
+ std::vector<int>* labels, std::vector<int>* xcoords) const {
+ labels->resize(0);
+ xcoords->resize(0);
GenericVector<const RecodeNode*> best_nodes;
ExtractBestPaths(&best_nodes, nullptr);
// Now just run CTC on the best nodes.

@@ -214,9 +214,9 @@ void RecodeBeamSearch::ExtractBestPathAsLabels(
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
- bool debug, const UNICHARSET* unicharset, GenericVector<int>* unichar_ids,
- GenericVector<float>* certs, GenericVector<float>* ratings,
- GenericVector<int>* xcoords) const {
+ bool debug, const UNICHARSET* unicharset, std::vector<int>* unichar_ids,
+ std::vector<float>* certs, std::vector<float>* ratings,
+ std::vector<int>* xcoords) const {
GenericVector<const RecodeNode*> best_nodes;
ExtractBestPaths(&best_nodes, nullptr);
ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);

@@ -234,10 +234,10 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
PointerVector<WERD_RES>* words,
int lstm_choice_mode) {
words->truncate(0);
- GenericVector<int> unichar_ids;
- GenericVector<float> certs;
- GenericVector<float> ratings;
- GenericVector<int> xcoords;
+ std::vector<int> unichar_ids;
+ std::vector<float> certs;
+ std::vector<float> ratings;
+ std::vector<int> xcoords;
GenericVector<const RecodeNode*> best_nodes;
GenericVector<const RecodeNode*> second_nodes;
character_boundaries_.clear();

@@ -406,10 +406,10 @@ void RecodeBeamSearch::extractSymbolChoices(const UNICHARSET* unicharset) {
}
character_boundaries_[0] = 0;
for (int j = 1; j < character_boundaries_.size(); ++j) {
- GenericVector<int> unichar_ids;
- GenericVector<float> certs;
- GenericVector<float> ratings;
- GenericVector<int> xcoords;
+ std::vector<int> unichar_ids;
+ std::vector<float> certs;
+ std::vector<float> ratings;
+ std::vector<int> xcoords;
int backpath = character_boundaries_[j] - character_boundaries_[j - 1];
heaps = currentBeam->get(character_boundaries_[j] - 1)->beams_->heap();
GenericVector<const RecodeNode*> best_nodes;

@@ -544,13 +544,13 @@ void RecodeBeamSearch::DebugBeamPos(const UNICHARSET& unicharset,
/* static */
void RecodeBeamSearch::ExtractPathAsUnicharIds(
const GenericVector<const RecodeNode*>& best_nodes,
- GenericVector<int>* unichar_ids, GenericVector<float>* certs,
- GenericVector<float>* ratings, GenericVector<int>* xcoords,
+ std::vector<int>* unichar_ids, std::vector<float>* certs,
+ std::vector<float>* ratings, std::vector<int>* xcoords,
std::vector<int>* character_boundaries) {
- unichar_ids->truncate(0);
- certs->truncate(0);
- ratings->truncate(0);
- xcoords->truncate(0);
+ unichar_ids->resize(0);
+ certs->resize(0);
+ ratings->resize(0);
+ xcoords->resize(0);
std::vector<int> starts;
std::vector<int> ends;
// Backtrack extracting only valid, non-duplicate unichar-ids.

@@ -609,7 +609,7 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
const TBOX& line_box, int word_start,
int word_end, float space_certainty,
const UNICHARSET* unicharset,
- const GenericVector<int>& xcoords,
+ const std::vector<int>& xcoords,
float scale_factor) {
// Make a fake blob for each non-zero label.
C_BLOB_LIST blobs;

@@ -641,7 +641,7 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
// is one of the top_n.
void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
int top_n) {
- top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN);
+ top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
top_code_ = -1;
second_code_ = -1;
top_heap_.clear();

@@ -671,7 +671,7 @@ void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int>* exList,
const float* outputs, int num_outputs,
int top_n) {
- top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN);
+ top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
top_code_ = -1;
second_code_ = -1;
top_heap_.clear();

@@ -1295,9 +1295,9 @@ void RecodeBeamSearch::DebugPath(
// Helper prints debug information on the given unichar path.
void RecodeBeamSearch::DebugUnicharPath(
const UNICHARSET* unicharset, const GenericVector<const RecodeNode*>& path,
- const GenericVector<int>& unichar_ids, const GenericVector<float>& certs,
- const GenericVector<float>& ratings,
- const GenericVector<int>& xcoords) const {
+ const std::vector<int>& unichar_ids, const std::vector<float>& certs,
+ const std::vector<float>& ratings,
+ const std::vector<int>& xcoords) const {
int num_ids = unichar_ids.size();
double total_rating = 0.0;
for (int c = 0; c < num_ids; ++c) {
@@ -198,15 +198,15 @@ class TESS_API RecodeBeamSearch {
int lstm_choice_mode = 0);

// Returns the best path as labels/scores/xcoords similar to simple CTC.
- void ExtractBestPathAsLabels(GenericVector<int>* labels,
- GenericVector<int>* xcoords) const;
+ void ExtractBestPathAsLabels(std::vector<int>* labels,
+ std::vector<int>* xcoords) const;
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET* unicharset,
- GenericVector<int>* unichar_ids,
- GenericVector<float>* certs,
- GenericVector<float>* ratings,
- GenericVector<int>* xcoords) const;
+ std::vector<int>* unichar_ids,
+ std::vector<float>* certs,
+ std::vector<float>* ratings,
+ std::vector<int>* xcoords) const;

// Returns the best path as a set of WERD_RES.
void ExtractBestPathAsWords(const TBOX& line_box, float scale_factor,

@@ -310,8 +310,8 @@ class TESS_API RecodeBeamSearch {
// duplicates, nulls and intermediate parts.
static void ExtractPathAsUnicharIds(
const GenericVector<const RecodeNode*>& best_nodes,
- GenericVector<int>* unichar_ids, GenericVector<float>* certs,
- GenericVector<float>* ratings, GenericVector<int>* xcoords,
+ std::vector<int>* unichar_ids, std::vector<float>* certs,
+ std::vector<float>* ratings, std::vector<int>* xcoords,
std::vector<int>* character_boundaries = nullptr);

// Sets up a word with the ratings matrix and fake blobs with boxes in the

@@ -319,7 +319,7 @@ class TESS_API RecodeBeamSearch {
WERD_RES* InitializeWord(bool leading_space, const TBOX& line_box,
int word_start, int word_end, float space_certainty,
const UNICHARSET* unicharset,
- const GenericVector<int>& xcoords,
+ const std::vector<int>& xcoords,
float scale_factor);

// Fills top_n_flags_ with bools that are true iff the corresponding output

@@ -415,10 +415,10 @@ class TESS_API RecodeBeamSearch {
// Helper prints debug information on the given unichar path.
void DebugUnicharPath(const UNICHARSET* unicharset,
const GenericVector<const RecodeNode*>& path,
- const GenericVector<int>& unichar_ids,
- const GenericVector<float>& certs,
- const GenericVector<float>& ratings,
- const GenericVector<int>& xcoords) const;
+ const std::vector<int>& unichar_ids,
+ const std::vector<float>& certs,
+ const std::vector<float>& ratings,
+ const std::vector<int>& xcoords) const;

static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1];

@@ -432,7 +432,7 @@ class TESS_API RecodeBeamSearch {
int beam_size_;
// A flag to indicate which outputs are the top-n choices. Current timestep
// only.
- GenericVector<TopNState> top_n_flags_;
+ std::vector<TopNState> top_n_flags_;
// A record of the highest and second scoring codes.
int top_code_;
int second_code_;
@@ -15,19 +15,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
- #include "ctc.h"
-
+ #include <algorithm>
+ #include <cfloat> // for FLT_MAX
+ #include <memory>
+
+ #include "ctc.h"
+
#include "genericvector.h"
#include "matrix.h"
#include "networkio.h"

#include "network.h"
#include "scrollview.h"

- #include <algorithm>
- #include <cfloat> // for FLT_MAX
- #include <memory>
-
namespace tesseract {

// Magic constants that keep CTC stable.
@@ -51,7 +51,7 @@ const double CTC::kMinTotalFinalProb_ = 1e-6;
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
/* static */
- bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
+ bool CTC::ComputeCTCTargets(const std::vector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs,
NetworkIO* targets) {
std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs));

@@ -80,7 +80,7 @@ bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
return true;
}

- CTC::CTC(const GenericVector<int>& labels, int null_char,
+ CTC::CTC(const std::vector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs)
: labels_(labels), outputs_(outputs), null_char_(null_char) {
num_timesteps_ = outputs.dim1();

@@ -91,8 +91,8 @@ CTC::CTC(const GenericVector<int>& labels, int null_char,
// Computes vectors of min and max label index for each timestep, based on
// whether skippability of nulls makes it possible to complete a valid path.
bool CTC::ComputeLabelLimits() {
- min_labels_.init_to_size(num_timesteps_, 0);
- max_labels_.init_to_size(num_timesteps_, 0);
+ min_labels_.resize(num_timesteps_, 0);
+ max_labels_.resize(num_timesteps_, 0);
int min_u = num_labels_ - 1;
if (labels_[min_u] == null_char_) --min_u;
for (int t = num_timesteps_ - 1; t >= 0; --t) {

@@ -125,8 +125,8 @@ bool CTC::ComputeLabelLimits() {
void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
// Initialize all targets to zero.
targets->Resize(num_timesteps_, num_classes_, 0.0f);
- GenericVector<float> half_widths;
- GenericVector<int> means;
+ std::vector<float> half_widths;
+ std::vector<int> means;
ComputeWidthsAndMeans(&half_widths, &means);
for (int l = 0; l < num_labels_; ++l) {
int label = labels_[l];

@@ -166,8 +166,8 @@ void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {

// Computes mean positions and half widths of the simple targets by spreading
// the labels evenly over the available timesteps.
- void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
- GenericVector<int>* means) const {
+ void CTC::ComputeWidthsAndMeans(std::vector<float>* half_widths,
+ std::vector<int>* means) const {
// Count the number of labels of each type, in regexp terms, counts plus
// (non-null or necessary null, which must occur at least once) and star
// (optional null).

@@ -50,7 +50,7 @@ class TESS_COMMON_TRAINING_API CTC {
// normalized with NormalizeProbs.
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
- static bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
+ static bool ComputeCTCTargets(const std::vector<int>& truth_labels,
int null_char,
const GENERIC_2D_ARRAY<float>& outputs,
NetworkIO* targets);

@@ -58,7 +58,7 @@ class TESS_COMMON_TRAINING_API CTC {
private:
// Constructor is private as the instance only holds information specific to
// the current labels, outputs etc, and is built by the static function.
- CTC(const GenericVector<int>& labels, int null_char,
+ CTC(const std::vector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs);

// Computes vectors of min and max label index for each timestep, based on

@@ -69,8 +69,8 @@ class TESS_COMMON_TRAINING_API CTC {
void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const;
// Computes mean positions and half widths of the simple targets by spreading
// the labels even over the available timesteps.
- void ComputeWidthsAndMeans(GenericVector<float>* half_widths,
- GenericVector<int>* means) const;
+ void ComputeWidthsAndMeans(std::vector<float>* half_widths,
+ std::vector<int>* means) const;
// Calculates and returns a suitable fraction of the simple targets to add
// to the network outputs.
float CalculateBiasFraction();

@@ -110,7 +110,7 @@ class TESS_COMMON_TRAINING_API CTC {
static const double kMinTotalFinalProb_;

// The truth label indices that are to be matched to outputs_.
- const GenericVector<int>& labels_;
+ const std::vector<int>& labels_;
// The network outputs.
GENERIC_2D_ARRAY<float> outputs_;
// The null or "blank" label.

@@ -122,8 +122,8 @@ class TESS_COMMON_TRAINING_API CTC {
// Number of labels in labels_.
int num_labels_;
// Min and max valid label indices for each timestep.
- GenericVector<int> min_labels_;
- GenericVector<int> max_labels_;
+ std::vector<int> min_labels_;
+ std::vector<int> max_labels_;
};

} // namespace tesseract
@@ -107,8 +107,8 @@ STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors,
++error_count;
if (verbosity > 1 || (verbosity > 0 && result != PERFECT)) {
tprintf("Truth:%s\n", trainingdata->transcription().c_str());
- GenericVector<int> ocr_labels;
- GenericVector<int> xcoords;
+ std::vector<int> ocr_labels;
+ std::vector<int> xcoords;
trainer.LabelsFromOutputs(fwd_outputs, &ocr_labels, &xcoords);
STRING ocr_text = trainer.DecodeLabels(ocr_labels);
tprintf("OCR :%s\n", ocr_text.c_str());

@@ -201,7 +201,7 @@ void LSTMTrainer::InitIterations() {
for (int i = 0; i < ET_COUNT; ++i) {
best_error_rates_[i] = 100.0;
worst_error_rates_[i] = 0.0;
- error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0);
+ error_buffers_[i].resize(kRollingBufferSize_, 0.0);
error_rates_[i] = 100.0;
}
error_rate_of_last_saved_best_ = kMinStartedErrorRate;

@@ -222,7 +222,7 @@ Trainability LSTMTrainer::GridSearchDictParams(
return result;

// Encode/decode the truth to get the normalization.
- GenericVector<int> truth_labels, ocr_labels, xcoords;
+ std::vector<int> truth_labels, ocr_labels, xcoords;
ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
// NO-dict error.
RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), nullptr);

@@ -406,7 +406,7 @@ bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
if (!fp->Serialize(&perfect_delay_)) return false;
if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
for (const auto & error_buffer : error_buffers_) {
- if (!error_buffer.Serialize(fp)) return false;
+ if (!fp->Serialize(error_buffer)) return false;
}
if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
if (!fp->Serialize(&training_stage_)) return false;

@@ -428,8 +428,8 @@ bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
if (sub_trainer_ != nullptr && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
return false;
if (!fp->Serialize(sub_data)) return false;
- if (!best_error_history_.Serialize(fp)) return false;
- if (!best_error_iterations_.Serialize(fp)) return false;
+ if (!fp->Serialize(best_error_history_)) return false;
+ if (!fp->Serialize(best_error_iterations_)) return false;
return fp->Serialize(&improvement_steps_);
}

@@ -450,7 +450,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
if (!fp->DeSerialize(&perfect_delay_)) return false;
if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
for (auto & error_buffer : error_buffers_) {
- if (!error_buffer.DeSerialize(fp)) return false;
+ if (!fp->DeSerialize(error_buffer)) return false;
}
if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
if (!fp->DeSerialize(&training_stage_)) return false;

@@ -476,8 +476,8 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
sub_trainer_ = new LSTMTrainer();
if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
}
- if (!best_error_history_.DeSerialize(fp)) return false;
- if (!best_error_iterations_.DeSerialize(fp)) return false;
+ if (!fp->DeSerialize(best_error_history_)) return false;
+ if (!fp->DeSerialize(best_error_iterations_)) return false;
return fp->DeSerialize(&improvement_steps_);
}
@@ -583,15 +583,15 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
LR_SAME, // Learning rate will stay the same.
LR_COUNT // Size of arrays.
};
- GenericVector<STRING> layers = EnumerateLayers();
+ std::vector<STRING> layers = EnumerateLayers();
int num_layers = layers.size();
- GenericVector<int> num_weights;
- num_weights.init_to_size(num_layers, 0);
- GenericVector<double> bad_sums[LR_COUNT];
- GenericVector<double> ok_sums[LR_COUNT];
+ std::vector<int> num_weights;
+ num_weights.resize(num_layers, 0);
+ std::vector<double> bad_sums[LR_COUNT];
+ std::vector<double> ok_sums[LR_COUNT];
for (int i = 0; i < LR_COUNT; ++i) {
- bad_sums[i].init_to_size(num_layers, 0.0);
- ok_sums[i].init_to_size(num_layers, 0.0);
+ bad_sums[i].resize(num_layers, 0.0);
+ ok_sums[i].resize(num_layers, 0.0);
}
double momentum_factor = 1.0 / (1.0 - momentum_);
std::vector<char> orig_trainer;

@@ -687,14 +687,14 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
/* static */
bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
const UnicharCompress* recoder, bool simple_text,
- int null_char, GenericVector<int>* labels) {
+ int null_char, std::vector<int>* labels) {
if (str.c_str() == nullptr || str.length() <= 0) {
tprintf("Empty truth string!\n");
return false;
}
int err_index;
- GenericVector<int> internal_labels;
- labels->truncate(0);
+ std::vector<int> internal_labels;
+ labels->resize(0);
if (!simple_text) labels->push_back(null_char);
std::string cleaned = unicharset.CleanupString(str.c_str());
if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,

@@ -775,7 +775,7 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
// Ensure repeatability of random elements even across checkpoints.
bool debug = debug_interval_ > 0 &&
training_iteration() % debug_interval_ == 0;
- GenericVector<int> truth_labels;
+ std::vector<int> truth_labels;
if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
tprintf("Can't encode transcription: '%s' in language '%s'\n",
trainingdata->transcription().c_str(),

@@ -796,7 +796,7 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
++truth_labels[c];
}
- truth_labels.reverse();
+ std::reverse(truth_labels.begin(), truth_labels.end());
}
}
int w = 0;

@@ -832,8 +832,8 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
tprintf("Logistic outputs not implemented yet!\n");
return UNENCODABLE;
}
- GenericVector<int> ocr_labels;
- GenericVector<int> xcoords;
+ std::vector<int> ocr_labels;
+ std::vector<int> xcoords;
LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
// CTC does not produce correct target labels to begin with.
if (loss_type != LT_CTC) {

@@ -1003,7 +1003,7 @@ void LSTMTrainer::EmptyConstructor() {
bool LSTMTrainer::DebugLSTMTraining(const NetworkIO& inputs,
const ImageData& trainingdata,
const NetworkIO& fwd_outputs,
- const GenericVector<int>& truth_labels,
+ const std::vector<int>& truth_labels,
const NetworkIO& outputs) {
const STRING& truth_text = DecodeLabels(truth_labels);
if (truth_text.c_str() == nullptr || truth_text.length() <= 0) {

@@ -1012,8 +1012,8 @@ bool LSTMTrainer::DebugLSTMTraining(const NetworkIO& inputs,
}
if (debug_interval_ != 0) {
// Get class labels, xcoords and string.
- GenericVector<int> labels;
- GenericVector<int> xcoords;
+ std::vector<int> labels;
+ std::vector<int> xcoords;
LabelsFromOutputs(outputs, &labels, &xcoords);
STRING text = DecodeLabels(labels);
tprintf("Iteration %d: GROUND TRUTH : %s\n",

@@ -1079,7 +1079,7 @@ void LSTMTrainer::DisplayTargets(const NetworkIO& targets,
// Builds a no-compromises target where the first positions should be the
// truth labels and the rest is padded with the null_char_.
bool LSTMTrainer::ComputeTextTargets(const NetworkIO& outputs,
- const GenericVector<int>& truth_labels,
+ const std::vector<int>& truth_labels,
NetworkIO* targets) {
if (truth_labels.size() > targets->Width()) {
tprintf("Error: transcription %s too long to fit into target of width %d\n",

@@ -1098,7 +1098,7 @@ bool LSTMTrainer::ComputeTextTargets(const NetworkIO& outputs,
// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
- bool LSTMTrainer::ComputeCTCTargets(const GenericVector<int>& truth_labels,
+ bool LSTMTrainer::ComputeCTCTargets(const std::vector<int>& truth_labels,
NetworkIO* outputs, NetworkIO* targets) {
// Bottom-clip outputs to a minimum probability.
CTC::NormalizeProbs(outputs);

@@ -1166,10 +1166,10 @@ double LSTMTrainer::ComputeWinnerError(const NetworkIO& deltas) {
}

// Computes a very simple bag of chars char error rate.
- double LSTMTrainer::ComputeCharError(const GenericVector<int>& truth_str,
- const GenericVector<int>& ocr_str) {
- GenericVector<int> label_counts;
- label_counts.init_to_size(NumOutputs(), 0);
+ double LSTMTrainer::ComputeCharError(const std::vector<int>& truth_str,
+ const std::vector<int>& ocr_str) {
+ std::vector<int> label_counts;
+ label_counts.resize(NumOutputs(), 0);
int truth_size = 0;
for (int i = 0; i < truth_str.size(); ++i) {
if (truth_str[i] != null_char_) {

@@ -1231,7 +1231,7 @@ void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
int index = training_iteration_ % kRollingBufferSize_;
error_buffers_[type][index] = new_error;
// Compute the mean error.
- int mean_count = std::min(training_iteration_ + 1, error_buffers_[type].size());
+ int mean_count = std::min<int>(training_iteration_ + 1, error_buffers_[type].size());
double buffer_sum = 0.0;
for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
double mean = buffer_sum / mean_count;
@@ -183,8 +183,8 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
// returns a log message to indicate progress. Returns false if nothing
// interesting happened.
bool MaintainCheckpointsSpecific(int iteration,
- const GenericVector<char>* train_model,
- const GenericVector<char>* rec_model,
+ const std::vector<char>* train_model,
+ const std::vector<char>* rec_model,
TestCallback tester, STRING* log_msg);
// Builds a string containing a progress message with current error rates.
void PrepareLogMsg(STRING* log_msg) const;

@@ -232,14 +232,14 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {

// Converts the string to integer class labels, with appropriate null_char_s
// in between if not in SimpleTextOutput mode. Returns false on failure.
- bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
+ bool EncodeString(const STRING& str, std::vector<int>* labels) const {
return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : nullptr,
SimpleTextOutput(), null_char_, labels);
}
// Static version operates on supplied unicharset, encoder, simple_text.
static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
const UnicharCompress* recoder, bool simple_text,
- int null_char, GenericVector<int>* labels);
+ int null_char, std::vector<int>* labels);

// Performs forward-backward on the given trainingdata.
// Returns the sample that was used or nullptr if the next sample was deemed

@@ -327,7 +327,7 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
bool DebugLSTMTraining(const NetworkIO& inputs,
const ImageData& trainingdata,
const NetworkIO& fwd_outputs,
- const GenericVector<int>& truth_labels,
+ const std::vector<int>& truth_labels,
const NetworkIO& outputs);
// Displays the network targets as line a line graph.
void DisplayTargets(const NetworkIO& targets, const char* window_name,

@@ -336,13 +336,13 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
// Builds a no-compromises target where the first positions should be the
// truth labels and the rest is padded with the null_char_.
bool ComputeTextTargets(const NetworkIO& outputs,
- const GenericVector<int>& truth_labels,
+ const std::vector<int>& truth_labels,
NetworkIO* targets);

// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
- bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
+ bool ComputeCTCTargets(const std::vector<int>& truth_labels,
NetworkIO* outputs, NetworkIO* targets);

// Computes network errors, and stores the results in the rolling buffers,

@@ -362,8 +362,8 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
double ComputeWinnerError(const NetworkIO& deltas);

// Computes a very simple bag of chars char error rate.
- double ComputeCharError(const GenericVector<int>& truth_str,
- const GenericVector<int>& ocr_str);
+ double ComputeCharError(const std::vector<int>& truth_str,
+ const std::vector<int>& ocr_str);
// Computes a very simple bag of words word recall error rate.
// NOTE that this is destructive on both input strings.
double ComputeWordError(STRING* truth_str, STRING* ocr_str);

@@ -436,8 +436,8 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
int training_stage_;
// History of best error rate against iteration. Used for computing the
// number of steps to each 2% improvement.
- GenericVector<double> best_error_history_;
- GenericVector<int> best_error_iterations_;
+ std::vector<double> best_error_history_;
+ std::vector<int> best_error_iterations_;
// Number of iterations since the best_error_rate_ was 2% more than it is now.
int32_t improvement_steps_;
// Number of iterations that yielded a non-zero delta error and thus provided

@@ -458,7 +458,7 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
// Rolling buffers storing recent training errors are indexed by
// training_iteration % kRollingBufferSize_.
static const int kRollingBufferSize_ = 1000;
- GenericVector<double> error_buffers_[ET_COUNT];
+ std::vector<double> error_buffers_[ET_COUNT];
// Rounded mean percent trailing training errors in the buffers.
double error_rates_[ET_COUNT]; // RMS training error.
// Traineddata file with optional dawgs + UNICHARSET and recoder.
@@ -843,7 +843,7 @@ LanguageModelDawgInfo *LanguageModel::GenerateDawgInfo(
// Call LetterIsOkay().
// Use the normalized IDs so that all shapes of ' can be allowed in words
// like don't.
- const GenericVector<UNICHAR_ID>& normed_ids =
+ const auto &normed_ids =
dict_->getUnicharset().normed_ids(b.unichar_id());
DawgPositionVector tmp_active_dawgs;
for (int i = 0; i < normed_ids.size(); ++i) {
@ -66,7 +66,7 @@ TEST(LangModelTest, AddACharacter) {
file::JoinPath(output_dir, lang1, absl::StrCat(lang1, ".traineddata"));
LSTMTrainer trainer1;
trainer1.InitCharSet(traineddata1);
GenericVector<int> labels1;
std::vector<int> labels1;
EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
STRING test1_decoded = trainer1.DecodeLabels(labels1);
std::string test1_str(&test1_decoded[0], test1_decoded.length());
@ -89,13 +89,13 @@ TEST(LangModelTest, AddACharacter) {
file::JoinPath(output_dir, lang2, absl::StrCat(lang2, ".traineddata"));
LSTMTrainer trainer2;
trainer2.InitCharSet(traineddata2);
GenericVector<int> labels2;
std::vector<int> labels2;
EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
STRING test2_decoded = trainer2.DecodeLabels(labels2);
std::string test2_str(&test2_decoded[0], test2_decoded.length());
LOG(INFO) << "Labels2=" << test2_str << "\n";
// encode kTestStringRupees.
GenericVector<int> labels3;
std::vector<int> labels3;
EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
STRING test3_decoded = trainer2.DecodeLabels(labels3);
std::string test3_str(&test3_decoded[0], test3_decoded.length());
@ -158,7 +158,7 @@ TEST(LangModelTest, AddACharacterHindi) {
file::JoinPath(output_dir, lang1, absl::StrCat(lang1, ".traineddata"));
LSTMTrainer trainer1;
trainer1.InitCharSet(traineddata1);
GenericVector<int> labels1;
std::vector<int> labels1;
EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
STRING test1_decoded = trainer1.DecodeLabels(labels1);
std::string test1_str(&test1_decoded[0], test1_decoded.length());
@ -181,13 +181,13 @@ TEST(LangModelTest, AddACharacterHindi) {
file::JoinPath(output_dir, lang2, absl::StrCat(lang2, ".traineddata"));
LSTMTrainer trainer2;
trainer2.InitCharSet(traineddata2);
GenericVector<int> labels2;
std::vector<int> labels2;
EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
STRING test2_decoded = trainer2.DecodeLabels(labels2);
std::string test2_str(&test2_decoded[0], test2_decoded.length());
LOG(INFO) << "Labels2=" << test2_str << "\n";
// encode kTestStringRupees.
GenericVector<int> labels3;
std::vector<int> labels3;
EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
STRING test3_decoded = trainer2.DecodeLabels(labels3);
std::string test3_str(&test3_decoded[0], test3_decoded.length());
@ -208,7 +208,7 @@ TEST_F(LSTMTrainerTest, TestLayerAccess) {
128 * (4 * (128 + 32 + 1)),
112 * (2 * 128 + 1)};

GenericVector<STRING> layers = trainer_->EnumerateLayers();
auto layers = trainer_->EnumerateLayers();
EXPECT_EQ(kNumLayers, layers.size());
for (int i = 0; i < kNumLayers && i < layers.size(); ++i) {
EXPECT_STREQ(kLayerIds[i], layers[i].c_str());

@ -172,7 +172,7 @@ class LSTMTrainerTest : public testing::Test {
std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
lstmf_name, recode, true, 5e-4, true, lang);
GenericVector<int> labels;
std::vector<int> labels;
EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
STRING decoded = trainer_->DecodeLabels(labels);
std::string decoded_str(&decoded[0], decoded.length());
@ -36,7 +36,7 @@ TEST_F(LSTMTrainerTest, MapCoder) {
deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
// A string that uses characters common to French and German.
std::string kTestStr = "The quick brown 'fox' jumps over: the lazy dog!";
GenericVector<int> deu_labels;
std::vector<int> deu_labels;
EXPECT_TRUE(deu_trainer.EncodeString(kTestStr.c_str(), &deu_labels));
// The french trainer cannot decode them correctly.
STRING badly_decoded = fra_trainer.DecodeLabels(deu_labels);
@ -44,12 +44,12 @@ TEST_F(LSTMTrainerTest, MapCoder) {
LOG(INFO) << "bad_str fra=" << bad_str << "\n";
EXPECT_NE(kTestStr, bad_str);
// Encode the string as fra.
GenericVector<int> fra_labels;
std::vector<int> fra_labels;
EXPECT_TRUE(fra_trainer.EncodeString(kTestStr.c_str(), &fra_labels));
// Use the mapper to compute what the labels are as deu.
std::vector<int> mapping = fra_trainer.MapRecoder(deu_trainer.GetUnicharset(),
deu_trainer.GetRecoder());
GenericVector<int> mapped_fra_labels(fra_labels.size(), -1);
std::vector<int> mapped_fra_labels(fra_labels.size(), -1);
for (int i = 0; i < fra_labels.size(); ++i) {
mapped_fra_labels[i] = mapping[fra_labels[i]];
EXPECT_NE(-1, mapped_fra_labels[i]) << "i=" << i << ", ch=" << kTestStr[i];
@ -129,7 +129,7 @@ class RecodeBeamTest : public ::testing::Test {
beam_search.Decode(output, 3.5, -0.125, -25.0, nullptr);
// Uncomment and/or change nullptr above to &ccutil_.unicharset to debug:
// beam_search.DebugBeams(ccutil_.unicharset);
GenericVector<int> labels, xcoords;
std::vector<int> labels, xcoords;
beam_search.ExtractBestPathAsLabels(&labels, &xcoords);
LOG(INFO) << "Labels size = " << labels.size() << " coords "
<< xcoords.size() << "\n";
@ -159,8 +159,8 @@ class RecodeBeamTest : public ::testing::Test {
EXPECT_EQ(truth_utf8, decoded);

// Check that ExtractBestPathAsUnicharIds does the same thing.
GenericVector<int> unichar_ids;
GenericVector<float> certainties, ratings;
std::vector<int> unichar_ids;
std::vector<float> certainties, ratings;
beam_search.ExtractBestPathAsUnicharIds(false, &ccutil_.unicharset,
&unichar_ids, &certainties,
&ratings, &xcoords);
@ -253,7 +253,7 @@ class RecodeBeamTest : public ::testing::Test {
int EncodeUTF8(const char* utf8_str, float score, int start_t, TRand* random,
GENERIC_2D_ARRAY<float>* outputs) {
int t = start_t;
GenericVector<int> unichar_ids;
std::vector<int> unichar_ids;
EXPECT_TRUE(ccutil_.unicharset.encode_string(utf8_str, true, &unichar_ids,
nullptr, nullptr));
if (unichar_ids.empty() || utf8_str[0] == '\0') {
@ -49,7 +49,7 @@ TEST(UnicharsetTest, Basics) {
EXPECT_EQ(u.unichar_to_id("\ufb01"), INVALID_UNICHAR_ID);
// The fi pair has no valid id.
EXPECT_EQ(u.unichar_to_id("fi"), INVALID_UNICHAR_ID);
GenericVector<int> labels;
std::vector<int> labels;
EXPECT_TRUE(u.encode_string("affine", true, &labels, nullptr, nullptr));
std::vector<int> v(&labels[0], &labels[0] + labels.size());
EXPECT_THAT(v, ElementsAreArray({3, 4, 4, 5, 7, 6}));
@ -93,13 +93,13 @@ TEST(UnicharsetTest, Multibyte) {
EXPECT_EQ(u.unichar_to_id("fi"), 6);
// The fi ligature is findable.
EXPECT_EQ(u.unichar_to_id("\ufb01"), 6);
GenericVector<int> labels;
std::vector<int> labels;
EXPECT_TRUE(u.encode_string("\u0627\u062c\u062c\u062f\u0635\u062b", true,
&labels, nullptr, nullptr));
std::vector<int> v(&labels[0], &labels[0] + labels.size());
EXPECT_THAT(v, ElementsAreArray({3, 4, 4, 5, 8, 7}));
// With the fi ligature the fi is picked out.
GenericVector<char> lengths;
std::vector<char> lengths;
int encoded_length;
std::string src_str = "\u0627\u062c\ufb01\u0635\u062b";
// src_str has to be pre-cleaned for lengths to be correct.