From e3fd938bca0e9bdf80590a775fe1d9ee4dff6bb7 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Thu, 21 Jan 2021 23:17:14 +0100 Subject: [PATCH] lstmtrainer: Modernize code Signed-off-by: Stefan Weil --- src/training/unicharset/lstmtrainer.cpp | 53 +++++++++++++------------ 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/training/unicharset/lstmtrainer.cpp b/src/training/unicharset/lstmtrainer.cpp index 23801358e..a78093196 100644 --- a/src/training/unicharset/lstmtrainer.cpp +++ b/src/training/unicharset/lstmtrainer.cpp @@ -699,11 +699,11 @@ bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset, if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr, &err_index)) { bool success = true; - for (int i = 0; i < internal_labels.size(); ++i) { + for (auto internal_label : internal_labels) { if (recoder != nullptr) { // Re-encode labels via recoder. RecodedCharID code; - int len = recoder->EncodeUnichar(internal_labels[i], &code); + int len = recoder->EncodeUnichar(internal_label, &code); if (len > 0) { for (int j = 0; j < len; ++j) { labels->push_back(code(j)); @@ -715,7 +715,7 @@ bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset, break; } } else { - labels->push_back(internal_labels[i]); + labels->push_back(internal_label); if (!simple_text) labels->push_back(null_char); } } @@ -791,9 +791,10 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata, // Apart from space and null, increment the label. This is changes the // script-id to the same script-id but upside-down. // The labels need to be reversed in order, as the first is now the last. - for (int c = 0; c < truth_labels.size(); ++c) { - if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_) - ++truth_labels[c]; + for (auto truth_label : truth_labels) { + if (truth_label != UNICHAR_SPACE && truth_label != null_char_) { + ++truth_label; + } } std::reverse(truth_labels.begin(), truth_labels.end()); } @@ -1088,10 +1089,12 @@ bool LSTMTrainer::ComputeTextTargets(const NetworkIO& outputs, DecodeLabels(truth_labels).c_str(), targets->Width()); return false; } - for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) { - targets->SetActivations(i, truth_labels[i], 1.0); + size_t i = 0; + for (auto truth_label : truth_labels) { + targets->SetActivations(i, truth_label, 1.0); + ++i; } - for (int i = truth_labels.size(); i < targets->Width(); ++i) { + for (i = truth_labels.size(); i < targets->Width(); ++i) { targets->SetActivations(i, null_char_, 1.0); } return true; @@ -1173,20 +1176,20 @@ double LSTMTrainer::ComputeCharError(const std::vector& truth_str, std::vector label_counts; label_counts.resize(NumOutputs(), 0); int truth_size = 0; - for (int i = 0; i < truth_str.size(); ++i) { - if (truth_str[i] != null_char_) { - ++label_counts[truth_str[i]]; + for (auto ch : truth_str) { + if (ch != null_char_) { + ++label_counts[ch]; ++truth_size; } } - for (int i = 0; i < ocr_str.size(); ++i) { - if (ocr_str[i] != null_char_) { - --label_counts[ocr_str[i]]; + for (auto ch : ocr_str) { + if (ch != null_char_) { + --label_counts[ch]; } } int char_errors = 0; - for (int i = 0; i < label_counts.size(); ++i) { - char_errors += abs(label_counts[i]); + for (auto label_count : label_counts) { + char_errors += abs(label_count); } if (truth_size == 0) { return (char_errors == 0) ? 0.0 : 1.0; @@ -1203,19 +1206,19 @@ double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) { if (truth_words.empty()) return 0.0; ocr_str->split(' ', &ocr_words); StrMap word_counts; - for (int i = 0; i < truth_words.size(); ++i) { - std::string truth_word(truth_words[i].c_str()); - auto it = word_counts.find(truth_word); + for (auto truth_word : truth_words) { + std::string truth_word_string(truth_word.c_str()); + auto it = word_counts.find(truth_word_string); if (it == word_counts.end()) - word_counts.insert(std::make_pair(truth_word, 1)); + word_counts.insert(std::make_pair(truth_word_string, 1)); else ++it->second; } - for (int i = 0; i < ocr_words.size(); ++i) { - std::string ocr_word(ocr_words[i].c_str()); - auto it = word_counts.find(ocr_word); + for (auto ocr_word : ocr_words) { + std::string ocr_word_string(ocr_word.c_str()); + auto it = word_counts.find(ocr_word_string); if (it == word_counts.end()) - word_counts.insert(std::make_pair(ocr_word, -1)); + word_counts.insert(std::make_pair(ocr_word_string, -1)); else --it->second; }