tesseract/unittest/lstm_test.h
Stefan Weil 563a1717d4 Simplify class LSTMTrainer
The function pointers and callbacks file_reader_, file_writer_,
checkpointer_reader_ and checkpoint_writer_ are always set to
the same values. Replacing them by direct function calls
simplifies the code and allows removing more code from tesscallback.h.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
2019-06-23 08:51:44 +02:00

192 lines
7.9 KiB
C++

// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TESSERACT_UNITTEST_LSTM_TEST_H_
#define TESSERACT_UNITTEST_LSTM_TEST_H_
#include <memory>
#include <string>
#include <utility>
#include "include_gunit.h"
#include "absl/strings/str_cat.h"
#include "tprintf.h"
#include "helpers.h"
#include "functions.h"
#include "lang_model_helpers.h"
#include "log.h" // for LOG
#include "lstmtrainer.h"
#include "unicharset.h"
namespace tesseract {
#if DEBUG_DETAIL != 0
// DEBUG_DETAIL is non-zero: run every trainer for just 2 iterations and check
// accuracy after each single iteration.
constexpr int kTrainerIterations = 2;
constexpr int kBatchIterations = 1;
#else
// DEBUG_DETAIL is zero: run every trainer for 600 iterations and check
// accuracy after each batch of 100 iterations.
constexpr int kTrainerIterations = 600;
constexpr int kBatchIterations = 100;
#endif
// The fixture for testing LSTMTrainer.
class LSTMTrainerTest : public testing::Test {
protected:
// Test set-up hook: installs the user-preferred locale (the one selected by
// the empty locale name) as the global C++ locale.
void SetUp() {
  const std::locale user_preferred_locale("");
  std::locale::global(user_preferred_locale);
}
// Nothing to initialise here: trainer_ starts out empty and is only created
// by SetupTrainer.
LSTMTrainerTest() {}
// Maps a test-data file name to a path by joining it onto TESTDATA_DIR.
std::string TestDataNameToPath(const std::string& name) {
  const std::string relative_name(name);
  return file::JoinPath(TESTDATA_DIR, relative_name);
}
// Maps a tessdata file name to a path by joining it onto TESSDATA_DIR.
std::string TessDataNameToPath(const std::string& name) {
  const std::string relative_name(name);
  return file::JoinPath(TESSDATA_DIR, relative_name);
}
// Maps a file name to a path by joining it onto TESTING_DIR.
std::string TestingNameToPath(const std::string& name) {
  const std::string relative_name(name);
  return file::JoinPath(TESTING_DIR, relative_name);
}
// Convenience wrapper around SetupTrainer for the English test data: it
// supplies the eng unicharset, the eng.Arial.exp0 training file, a base
// learning rate of 5e-4, no layer-specific learning rates and language "eng".
void SetupTrainerEng(const std::string& network_spec, const std::string& model_name,
                     bool recode, bool adam) {
  const std::string eng_unicharset = "eng/eng.unicharset";
  const std::string eng_lstmf = "eng.Arial.exp0.lstmf";
  const std::string eng_lang = "eng";
  const double base_learning_rate = 5e-4;
  const bool use_layer_specific_rates = false;
  SetupTrainer(network_spec, model_name, eng_unicharset, eng_lstmf, recode,
               adam, base_learning_rate, use_layer_specific_rates, eng_lang);
}
// Creates a fresh trainer_ and prepares it for training.
//   network_spec    - network description passed to InitNetwork.
//   model_name      - file name under FLAGS_test_tmpdir for the model; the
//                     checkpoint path is the same with "_checkpoint" appended.
//   unicharset_file - unicharset to load, resolved via TestDataNameToPath.
//   lstmf_file      - training data file, resolved via TestDataNameToPath.
//   recode          - passed inverted (!recode) to CombineLangModel.
//   adam            - selects NF_ADAM and multiplies learning_rate by 20.
//   learning_rate   - base learning rate passed to InitNetwork.
//   layer_specific  - adds NF_LAYER_SPECIFIC_LR to the network flags.
//   kLang           - language name used for the generated traineddata.
// A unicharset load failure aborts the set-up (ASSERT); the other checks only
// record a test failure and carry on (EXPECT).
void SetupTrainer(const std::string& network_spec, const std::string& model_name,
                  const std::string& unicharset_file, const std::string& lstmf_file,
                  bool recode, bool adam, double learning_rate,
                  bool layer_specific, const std::string& kLang) {
  // Step 1: load the unicharset.
  const std::string unicharset_name = TestDataNameToPath(unicharset_file);
  UNICHARSET unicharset;
  ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));

  // Step 2: build the language model in the temp dir. The same empty vector
  // is supplied for all three list arguments.
  const std::string script_dir = file::JoinPath(LANGDATA_DIR, "");
  GenericVector<STRING> words;
  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir,
                                kLang, !recode, words, words, words, false,
                                nullptr, nullptr));

  // Step 3: create the trainer and initialise its character set from the
  // traineddata located at <tmpdir>/<kLang>/<kLang>.traineddata.
  const std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
  const std::string checkpoint_path = model_path + "_checkpoint";
  trainer_.reset(
      new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(), 0, 0));
  trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang,
                                       absl::StrCat(kLang, ".traineddata")));

  // Step 4: choose the network flags and build the network.
  int net_mode = 0;
  if (adam) {
    net_mode = NF_ADAM;
    // Adam needs a higher learning rate, due to not multiplying the effective
    // rate by 1/(1-momentum).
    learning_rate *= 20.0;
  }
  if (layer_specific) {
    net_mode |= NF_LAYER_SPECIFIC_LR;
  }
  EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
                                    learning_rate, 0.9, 0.999));

  // Step 5: load the single training data file.
  const std::string lstmf_path = TestDataNameToPath(lstmf_file);
  GenericVector<STRING> filenames;
  filenames.push_back(STRING(lstmf_path.c_str()));
  EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
  LOG(INFO) << "Setup network:" << model_name << "\n" ;
}
// Trains until the trainer's training iteration has advanced by
// max_iterations, working in batches of at most kBatchIterations iterations
// with a checkpoint-maintenance call after each batch.
// Returns the lowest batch mean char error seen, scaled by 100. If no sample
// was trained at all (e.g. max_iterations <= 0) the initial value 100.0 is
// returned rather than a misleading 0.
double TrainIterations(int max_iterations) {
  int iteration = trainer_->training_iteration();
  const int iteration_limit = iteration + max_iterations;
  double best_error = 100.0;
  do {
    STRING log_str;
    const int target_iteration = iteration + kBatchIterations;
    // Train a few, summing the per-sample char error as we go.
    double error_sum = 0.0;
    int sample_count = 0;
    while (iteration < target_iteration && iteration < iteration_limit) {
      trainer_->TrainOnLine(trainer_.get(), false);
      iteration = trainer_->training_iteration();
      error_sum += trainer_->LastSingleError(ET_CHAR_ERROR);
      ++sample_count;
    }
    trainer_->MaintainCheckpoints(nullptr, &log_str);
    iteration = trainer_->training_iteration();
    LOG(INFO) << log_str.string();
    LOG(INFO) << "Best error = " << best_error << "\n" ;
    // Average over the samples actually accumulated, not over
    // kBatchIterations: the final batch may be cut short by iteration_limit,
    // and dividing a short batch by the full batch size would understate the
    // error and could be recorded as a bogus best. An empty batch carries no
    // information, so it is not allowed to update best_error.
    if (sample_count > 0) {
      const double mean_error = error_sum * (100.0 / sample_count);
      LOG(INFO) << "Mean error = " << mean_error << "\n" ;
      if (mean_error < best_error) best_error = mean_error;
    }
  } while (iteration < iteration_limit);
  LOG(INFO) << "Trainer error rate = " << best_error << "\n";
  return best_error;
}
// Tests for a given number of iterations and returns the char error rate.
double TestIterations(int max_iterations) {
CHECK_GT(max_iterations, 0);
int iteration = trainer_->sample_iteration();
double mean_error = 0.0;
int error_count = 0;
while (error_count < max_iterations) {
const ImageData& trainingdata =
*trainer_->mutable_training_data()->GetPageBySerial(iteration);
NetworkIO fwd_outputs, targets;
if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) !=
UNENCODABLE) {
mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
++error_count;
}
trainer_->SetIteration(++iteration);
}
mean_error *= 100.0 / max_iterations;
LOG(INFO) << "Tester error rate = " << mean_error << "\n" ;
return mean_error;
}
// Checks that converting the current trainer_ to int mode costs less than
// 1.0 in error rate (as reported by TestIterations) compared with float mode.
// The trainer state is dumped first and restored before the conversion.
// Returns the int-mode error minus the float-mode error.
double TestIntMode(int test_iterations) {
  // Snapshot the trainer so it can be restored after the float-mode
  // measurement.
  GenericVector<char> trainer_data;
  EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, trainer_.get(),
                                         &trainer_data));

  // Float-mode baseline over the next few iterations.
  const double float_err = TestIterations(test_iterations);

  // Restore the snapshot, switch to int mode and measure again.
  EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, trainer_.get()));
  trainer_->ConvertToInt();
  const double int_err = TestIterations(test_iterations);

  EXPECT_LT(int_err, float_err + 1.0);
  const double error_increase = int_err - float_err;
  return error_increase;
}
// Sets up a trainer with the given language and given recode+ctc condition.
// It then verifies that the given str encodes and decodes back to the same
// string.
void TestEncodeDecode(const std::string& lang, const std::string& str, bool recode) {
std::string unicharset_name = lang + "/" + lang + ".unicharset";
std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
lstmf_name, recode, true, 5e-4, true, lang);
GenericVector<int> labels;
EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
STRING decoded = trainer_->DecodeLabels(labels);
std::string decoded_str(&decoded[0], decoded.length());
EXPECT_EQ(str, decoded_str);
}
// Runs TestEncodeDecode twice: first with recode off, then with recode on.
void TestEncodeDecodeBoth(const std::string& lang, const std::string& str) {
  const bool recode_settings[] = {false, true};
  for (const bool recode : recode_settings) {
    TestEncodeDecode(lang, str, recode);
  }
}
// The trainer under test. Empty until SetupTrainer creates it; every
// SetupTrainer call replaces it with a new instance.
std::unique_ptr<LSTMTrainer> trainer_;
};
} // namespace tesseract.
#endif // TESSERACT_UNITTEST_LSTM_TEST_H_