lstmtrainer: Modernize code

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2021-01-21 23:17:14 +01:00
parent 0cdaab5ac9
commit e3fd938bca

View File

@ -699,11 +699,11 @@ bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
&err_index)) {
bool success = true;
for (int i = 0; i < internal_labels.size(); ++i) {
for (auto internal_label : internal_labels) {
if (recoder != nullptr) {
// Re-encode labels via recoder.
RecodedCharID code;
int len = recoder->EncodeUnichar(internal_labels[i], &code);
int len = recoder->EncodeUnichar(internal_label, &code);
if (len > 0) {
for (int j = 0; j < len; ++j) {
labels->push_back(code(j));
@ -715,7 +715,7 @@ bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
break;
}
} else {
labels->push_back(internal_labels[i]);
labels->push_back(internal_label);
if (!simple_text) labels->push_back(null_char);
}
}
@ -791,9 +791,10 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
// Apart from space and null, increment the label. This changes the
// script-id to the same script-id but upside-down.
// The labels need to be reversed in order, as the first is now the last.
for (int c = 0; c < truth_labels.size(); ++c) {
if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
++truth_labels[c];
// Iterate by reference: the increment has to modify the label stored in
// truth_labels. Binding by value would only bump a per-iteration copy and
// leave the vector unchanged.
for (auto& truth_label : truth_labels) {
  if (truth_label != UNICHAR_SPACE && truth_label != null_char_) {
    ++truth_label;
  }
}
std::reverse(truth_labels.begin(), truth_labels.end());
}
@ -1088,10 +1089,12 @@ bool LSTMTrainer::ComputeTextTargets(const NetworkIO& outputs,
DecodeLabels(truth_labels).c_str(), targets->Width());
return false;
}
for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
targets->SetActivations(i, truth_labels[i], 1.0);
size_t i = 0;
for (auto truth_label : truth_labels) {
targets->SetActivations(i, truth_label, 1.0);
++i;
}
for (int i = truth_labels.size(); i < targets->Width(); ++i) {
for (i = truth_labels.size(); i < targets->Width(); ++i) {
targets->SetActivations(i, null_char_, 1.0);
}
return true;
@ -1173,20 +1176,20 @@ double LSTMTrainer::ComputeCharError(const std::vector<int>& truth_str,
std::vector<int> label_counts;
label_counts.resize(NumOutputs(), 0);
int truth_size = 0;
for (int i = 0; i < truth_str.size(); ++i) {
if (truth_str[i] != null_char_) {
++label_counts[truth_str[i]];
for (auto ch : truth_str) {
if (ch != null_char_) {
++label_counts[ch];
++truth_size;
}
}
for (int i = 0; i < ocr_str.size(); ++i) {
if (ocr_str[i] != null_char_) {
--label_counts[ocr_str[i]];
for (auto ch : ocr_str) {
if (ch != null_char_) {
--label_counts[ch];
}
}
int char_errors = 0;
for (int i = 0; i < label_counts.size(); ++i) {
char_errors += abs(label_counts[i]);
for (auto label_count : label_counts) {
char_errors += abs(label_count);
}
if (truth_size == 0) {
return (char_errors == 0) ? 0.0 : 1.0;
@ -1203,19 +1206,19 @@ double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) {
if (truth_words.empty()) return 0.0;
ocr_str->split(' ', &ocr_words);
StrMap word_counts;
for (int i = 0; i < truth_words.size(); ++i) {
std::string truth_word(truth_words[i].c_str());
auto it = word_counts.find(truth_word);
for (auto truth_word : truth_words) {
std::string truth_word_string(truth_word.c_str());
auto it = word_counts.find(truth_word_string);
if (it == word_counts.end())
word_counts.insert(std::make_pair(truth_word, 1));
word_counts.insert(std::make_pair(truth_word_string, 1));
else
++it->second;
}
for (int i = 0; i < ocr_words.size(); ++i) {
std::string ocr_word(ocr_words[i].c_str());
auto it = word_counts.find(ocr_word);
for (auto ocr_word : ocr_words) {
std::string ocr_word_string(ocr_word.c_str());
auto it = word_counts.find(ocr_word_string);
if (it == word_counts.end())
word_counts.insert(std::make_pair(ocr_word, -1));
word_counts.insert(std::make_pair(ocr_word_string, -1));
else
--it->second;
}