2018-10-12 20:00:14 +08:00
|
|
|
// (C) Copyright 2017, Google Inc.
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
2018-08-24 21:07:48 +08:00
|
|
|
|
|
|
|
// Although this is a trivial-looking test, it exercises a lot of code:
|
|
|
|
// SampleIterator has to correctly iterate over the correct characters, or
|
|
|
|
// it will fail.
|
|
|
|
// The canonical and cloud features computed by TrainingSampleSet need to
|
|
|
|
// be correct, along with the distance caches, organizing samples by font
|
|
|
|
// and class, indexing of features, distance calculations.
|
|
|
|
// IntFeatureDist has to work, or the canonical samples won't work.
|
|
|
|
// MasterTrainer's ability to read tr files and set itself up is tested.
|
|
|
|
// Finally the serialize/deserialize test ensures that MasterTrainer,
|
|
|
|
// TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
|
|
|
|
// enough to reproduce the same results.
|
|
|
|
|
2018-10-12 20:00:14 +08:00
|
|
|
#include "include_gunit.h"
|
|
|
|
|
2021-03-13 05:06:34 +08:00
|
|
|
#include "commontraining.h"
|
2018-10-12 20:00:14 +08:00
|
|
|
#include "errorcounter.h"
|
2021-03-13 05:06:34 +08:00
|
|
|
#include "log.h" // for LOG
|
2018-10-12 20:00:14 +08:00
|
|
|
#include "mastertrainer.h"
|
|
|
|
#include "shapeclassifier.h"
|
|
|
|
#include "shapetable.h"
|
|
|
|
#include "trainingsample.h"
|
2021-03-13 05:06:34 +08:00
|
|
|
#include "unicharset.h"
|
2018-08-24 21:07:48 +08:00
|
|
|
|
2021-03-13 05:06:34 +08:00
|
|
|
#include "absl/strings/numbers.h" // for safe_strto32
|
|
|
|
#include "absl/strings/str_split.h" // for absl::StrSplit
|
2021-01-05 21:46:24 +08:00
|
|
|
|
|
|
|
#include <cctype>
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
|
|
|
|
using namespace tesseract;
|
|
|
|
|
2018-08-24 21:07:48 +08:00
|
|
|
// Specs of the MockClassifier. The error counts are cumulative thresholds on
// the (1-based) classification index: indices up to kNumTopNErrs are top-n
// errors, then up to kNumTop2Errs are top-2 errors, then up to kNumTop1Errs
// are top-1 errors, then up to kNumTopTopErrs are "top is wrong but close"
// results.
static constexpr int kNumTopNErrs = 10;
static constexpr int kNumTop2Errs = kNumTopNErrs + 20;
static constexpr int kNumTop1Errs = kNumTop2Errs + 30;
static constexpr int kNumTopTopErrs = kNumTop1Errs + 25;
// Every classification after the first kNumNonReject is a reject.
static constexpr int kNumNonReject = 1000;
static constexpr int kNumCorrect = kNumNonReject - kNumTop1Errs;
// Total answers = one per non-reject, plus the extras from multi-answer
// results: top-2 errors emit 3 answers (2 extra), while top-1 and top-top
// errors emit 2 answers (1 extra each).
static constexpr int kNumAnswers = kNumNonReject + 2 * (kNumTop2Errs - kNumTopNErrs) +
                                   (kNumTop1Errs - kNumTop2Errs) +
                                   (kNumTopTopErrs - kNumTop1Errs);
|
2018-08-24 21:07:48 +08:00
|
|
|
|
2019-05-17 00:07:32 +08:00
|
|
|
#ifndef DISABLED_LEGACY_ENGINE
|
2021-03-13 05:06:34 +08:00
|
|
|
// Parses `str` as a base-10 32-bit int.
// Leading and trailing whitespace is tolerated, as in absl's safe_strto32.
// This matters because the last tab-separated field of a report line may
// carry a trailing newline. Any other trailing characters, an empty or
// non-numeric string, or a value outside [INT_MIN, INT_MAX] is an error.
// Returns true and stores the value in *pResult on success. On failure it
// returns false and stores 0, so callers never read an indeterminate value.
static bool safe_strto32(const std::string &str, int *pResult) {
  *pResult = 0;
  const char *const begin = str.c_str();
  const char *const stop = begin + str.size();
  char *end = nullptr;
  errno = 0;
  const long n = strtol(begin, &end, 10);
  // No digits consumed, or the value does not fit in a long/int.
  if (end == begin || errno == ERANGE || n < INT_MIN || n > INT_MAX) {
    return false;
  }
  // Only whitespace may follow the number. Iterating to str.size() (not to
  // the first NUL) also rejects strings with embedded NUL bytes.
  for (const char *p = end; p != stop; ++p) {
    if (!isspace(static_cast<unsigned char>(*p))) {
      return false;
    }
  }
  *pResult = static_cast<int>(n);
  return true;
}
|
2019-05-17 00:07:32 +08:00
|
|
|
#endif
|
2018-10-12 20:00:14 +08:00
|
|
|
|
2018-08-24 21:07:48 +08:00
|
|
|
// Mock ShapeClassifier that cheats by looking at the correct answer, and
|
|
|
|
// creates a specific pattern of errors that can be tested.
|
|
|
|
class MockClassifier : public ShapeClassifier {
|
2021-03-13 05:06:34 +08:00
|
|
|
public:
|
|
|
|
explicit MockClassifier(ShapeTable *shape_table)
|
2018-09-29 15:19:13 +08:00
|
|
|
: shape_table_(shape_table), num_done_(0), done_bad_font_(false) {
|
2018-08-24 21:07:48 +08:00
|
|
|
// Add a false font answer to the shape table. We pick a random unichar_id,
|
|
|
|
// add a new shape for it with a false font. Font must actually exist in
|
|
|
|
// the font table, but not match anything in the first 1000 samples.
|
|
|
|
false_unichar_id_ = 67;
|
|
|
|
false_shape_ = shape_table_->AddShape(false_unichar_id_, 25);
|
|
|
|
}
|
|
|
|
virtual ~MockClassifier() {}
|
|
|
|
|
|
|
|
// Classifies the given [training] sample, writing to results.
|
|
|
|
// If debug is non-zero, then various degrees of classifier dependent debug
|
|
|
|
// information is provided.
|
|
|
|
// If keep_this (a shape index) is >= 0, then the results should always
|
|
|
|
// contain keep_this, and (if possible) anything of intermediate confidence.
|
|
|
|
// The return value is the number of classes saved in results.
|
2021-03-13 05:06:34 +08:00
|
|
|
int ClassifySample(const TrainingSample &sample, Pix *page_pix, int debug, UNICHAR_ID keep_this,
|
|
|
|
std::vector<ShapeRating> *results) override {
|
2018-08-24 21:07:48 +08:00
|
|
|
results->clear();
|
|
|
|
// Everything except the first kNumNonReject is a reject.
|
2021-03-13 05:06:34 +08:00
|
|
|
if (++num_done_ > kNumNonReject)
|
|
|
|
return 0;
|
2018-08-24 21:07:48 +08:00
|
|
|
|
|
|
|
int class_id = sample.class_id();
|
|
|
|
int font_id = sample.font_id();
|
|
|
|
int shape_id = shape_table_->FindShape(class_id, font_id);
|
|
|
|
// Get ids of some wrong answers.
|
|
|
|
int wrong_id1 = shape_id > 10 ? shape_id - 1 : shape_id + 1;
|
|
|
|
int wrong_id2 = shape_id > 10 ? shape_id - 2 : shape_id + 2;
|
|
|
|
if (num_done_ <= kNumTopNErrs) {
|
|
|
|
// The first kNumTopNErrs are top-n errors.
|
|
|
|
results->push_back(ShapeRating(wrong_id1, 1.0f));
|
|
|
|
} else if (num_done_ <= kNumTop2Errs) {
|
|
|
|
// The next kNumTop2Errs - kNumTopNErrs are top-2 errors.
|
|
|
|
results->push_back(ShapeRating(wrong_id1, 1.0f));
|
|
|
|
results->push_back(ShapeRating(wrong_id2, 0.875f));
|
|
|
|
results->push_back(ShapeRating(shape_id, 0.75f));
|
|
|
|
} else if (num_done_ <= kNumTop1Errs) {
|
|
|
|
// The next kNumTop1Errs - kNumTop2Errs are top-1 errors.
|
|
|
|
results->push_back(ShapeRating(wrong_id1, 1.0f));
|
|
|
|
results->push_back(ShapeRating(shape_id, 0.8f));
|
|
|
|
} else if (num_done_ <= kNumTopTopErrs) {
|
|
|
|
// The next kNumTopTopErrs - kNumTop1Errs are cases where the actual top
|
|
|
|
// is not correct, but do not count as a top-1 error because the rating
|
|
|
|
// is close enough to the top answer.
|
|
|
|
results->push_back(ShapeRating(wrong_id1, 1.0f));
|
|
|
|
results->push_back(ShapeRating(shape_id, 0.99f));
|
|
|
|
} else if (!done_bad_font_ && class_id == false_unichar_id_) {
|
|
|
|
// There is a single character with a bad font.
|
|
|
|
results->push_back(ShapeRating(false_shape_, 1.0f));
|
|
|
|
done_bad_font_ = true;
|
|
|
|
} else {
|
|
|
|
// Everything else is correct.
|
|
|
|
results->push_back(ShapeRating(shape_id, 1.0f));
|
|
|
|
}
|
|
|
|
return results->size();
|
|
|
|
}
|
|
|
|
// Provides access to the ShapeTable that this classifier works with.
|
2021-03-13 05:06:34 +08:00
|
|
|
const ShapeTable *GetShapeTable() const override {
|
|
|
|
return shape_table_;
|
|
|
|
}
|
2018-08-24 21:07:48 +08:00
|
|
|
|
2021-03-13 05:06:34 +08:00
|
|
|
private:
|
2018-08-24 21:07:48 +08:00
|
|
|
// Borrowed pointer to the ShapeTable.
|
2021-03-13 05:06:34 +08:00
|
|
|
ShapeTable *shape_table_;
|
2018-08-24 21:07:48 +08:00
|
|
|
// Unichar_id of a random character that occurs after the first 60 samples.
|
|
|
|
int false_unichar_id_;
|
|
|
|
// Shape index of prepared false answer for false_unichar_id.
|
|
|
|
int false_shape_;
|
|
|
|
// The number of classifications we have processed.
|
|
|
|
int num_done_;
|
|
|
|
// True after the false font has been emitted.
|
|
|
|
bool done_bad_font_;
|
|
|
|
};
|
|
|
|
|
|
|
|
// Lower bound that VerifyIl1 requires for the Arial shape distance between
// "1" and each of "I" and "l", in both directions.
static constexpr double kMin1lDistance = 0.25;
|
|
|
|
|
|
|
|
// The fixture for testing Tesseract.
|
|
|
|
class MasterTrainerTest : public testing::Test {
|
2019-05-17 00:07:32 +08:00
|
|
|
#ifndef DISABLED_LEGACY_ENGINE
|
2021-03-13 05:06:34 +08:00
|
|
|
protected:
|
2019-05-17 00:07:32 +08:00
|
|
|
void SetUp() {
|
|
|
|
std::locale::global(std::locale(""));
|
2020-12-31 01:17:58 +08:00
|
|
|
file::MakeTmpdir();
|
2019-05-17 00:07:32 +08:00
|
|
|
}
|
|
|
|
|
2021-03-13 05:06:34 +08:00
|
|
|
std::string TestDataNameToPath(const std::string &name) {
|
2018-10-12 20:00:14 +08:00
|
|
|
return file::JoinPath(TESTING_DIR, name);
|
2018-08-24 21:07:48 +08:00
|
|
|
}
|
2021-03-13 05:06:34 +08:00
|
|
|
std::string TmpNameToPath(const std::string &name) {
|
2018-08-24 21:07:48 +08:00
|
|
|
return file::JoinPath(FLAGS_test_tmpdir, name);
|
|
|
|
}
|
|
|
|
|
2021-01-23 21:54:38 +08:00
|
|
|
MasterTrainerTest() {
|
|
|
|
shape_table_ = nullptr;
|
|
|
|
master_trainer_ = nullptr;
|
|
|
|
}
|
|
|
|
~MasterTrainerTest() {
|
|
|
|
delete shape_table_;
|
|
|
|
}
|
|
|
|
|
2018-08-24 21:07:48 +08:00
|
|
|
// Initializes the master_trainer_ and shape_table_.
|
|
|
|
// if load_from_tmp, then reloads a master trainer that was saved by a
|
|
|
|
// previous call in which it was false.
|
|
|
|
void LoadMasterTrainer() {
|
2018-10-12 20:00:14 +08:00
|
|
|
FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
|
|
|
|
FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
|
|
|
|
FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
|
2019-01-21 23:41:05 +08:00
|
|
|
FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
|
2018-10-12 20:00:14 +08:00
|
|
|
std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
|
2021-03-13 05:06:34 +08:00
|
|
|
const char *argv[] = {tr_file_name.c_str()};
|
2018-08-24 21:07:48 +08:00
|
|
|
int argc = 1;
|
|
|
|
STRING file_prefix;
|
2021-01-23 21:54:38 +08:00
|
|
|
delete shape_table_;
|
|
|
|
shape_table_ = nullptr;
|
2021-03-13 05:06:34 +08:00
|
|
|
master_trainer_ = LoadTrainingData(argc, argv, false, &shape_table_, &file_prefix);
|
2018-09-29 15:27:12 +08:00
|
|
|
EXPECT_TRUE(master_trainer_ != nullptr);
|
|
|
|
EXPECT_TRUE(shape_table_ != nullptr);
|
2018-08-24 21:07:48 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// EXPECTs that the distance between I and l in Arial is 0 and that the
|
|
|
|
// distance to 1 is significantly not 0.
|
|
|
|
void VerifyIl1() {
|
|
|
|
// Find the font id for Arial.
|
|
|
|
int font_id = master_trainer_->GetFontInfoId("Arial");
|
|
|
|
EXPECT_GE(font_id, 0);
|
|
|
|
// Track down the characters we are interested in.
|
|
|
|
int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
|
|
|
|
EXPECT_GT(unichar_I, 0);
|
|
|
|
int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
|
|
|
|
EXPECT_GT(unichar_l, 0);
|
|
|
|
int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
|
|
|
|
EXPECT_GT(unichar_1, 0);
|
|
|
|
// Now get the shape ids.
|
|
|
|
int shape_I = shape_table_->FindShape(unichar_I, font_id);
|
|
|
|
EXPECT_GE(shape_I, 0);
|
|
|
|
int shape_l = shape_table_->FindShape(unichar_l, font_id);
|
|
|
|
EXPECT_GE(shape_l, 0);
|
|
|
|
int shape_1 = shape_table_->FindShape(unichar_1, font_id);
|
|
|
|
EXPECT_GE(shape_1, 0);
|
|
|
|
|
2021-03-13 05:06:34 +08:00
|
|
|
float dist_I_l = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
|
2018-08-24 21:07:48 +08:00
|
|
|
// No tolerance here. We expect that I and l should match exactly.
|
|
|
|
EXPECT_EQ(0.0f, dist_I_l);
|
2021-03-13 05:06:34 +08:00
|
|
|
float dist_l_I = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
|
2018-08-24 21:07:48 +08:00
|
|
|
// BOTH ways.
|
|
|
|
EXPECT_EQ(0.0f, dist_l_I);
|
|
|
|
|
|
|
|
// l/1 on the other hand should be distinct.
|
2021-03-13 05:06:34 +08:00
|
|
|
float dist_l_1 = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
|
2018-08-24 21:07:48 +08:00
|
|
|
EXPECT_GT(dist_l_1, kMin1lDistance);
|
2021-03-13 05:06:34 +08:00
|
|
|
float dist_1_l = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
|
2018-08-24 21:07:48 +08:00
|
|
|
EXPECT_GT(dist_1_l, kMin1lDistance);
|
|
|
|
|
|
|
|
// So should I/1.
|
2021-03-13 05:06:34 +08:00
|
|
|
float dist_I_1 = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
|
2018-08-24 21:07:48 +08:00
|
|
|
EXPECT_GT(dist_I_1, kMin1lDistance);
|
2021-03-13 05:06:34 +08:00
|
|
|
float dist_1_I = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
|
2018-08-24 21:07:48 +08:00
|
|
|
EXPECT_GT(dist_1_I, kMin1lDistance);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Objects declared here can be used by all tests in the test case for Foo.
|
2021-03-13 05:06:34 +08:00
|
|
|
ShapeTable *shape_table_;
|
2021-01-05 22:03:26 +08:00
|
|
|
std::unique_ptr<MasterTrainer> master_trainer_;
|
2019-05-17 00:07:32 +08:00
|
|
|
#endif
|
2018-08-24 21:07:48 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// Tests that the MasterTrainer correctly loads its data and reaches the correct
|
|
|
|
// conclusion over the distance between Arial I l and 1.
|
|
|
|
TEST_F(MasterTrainerTest, Il1Test) {
|
2019-05-17 00:07:32 +08:00
|
|
|
#ifdef DISABLED_LEGACY_ENGINE
|
|
|
|
// Skip test because LoadTrainingData is missing.
|
|
|
|
GTEST_SKIP();
|
|
|
|
#else
|
2018-08-24 21:07:48 +08:00
|
|
|
// Initialize the master_trainer_ and load the Arial tr file.
|
|
|
|
LoadMasterTrainer();
|
|
|
|
VerifyIl1();
|
2019-05-17 00:07:32 +08:00
|
|
|
#endif
|
2018-08-24 21:07:48 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Tests the ErrorCounter using a MockClassifier to check that it counts
|
|
|
|
// error categories correctly.
|
|
|
|
TEST_F(MasterTrainerTest, ErrorCounterTest) {
|
2019-05-17 00:07:32 +08:00
|
|
|
#ifdef DISABLED_LEGACY_ENGINE
|
|
|
|
// Skip test because LoadTrainingData is missing.
|
|
|
|
GTEST_SKIP();
|
|
|
|
#else
|
2018-08-24 21:07:48 +08:00
|
|
|
// Initialize the master_trainer_ from the saved tmp file.
|
|
|
|
LoadMasterTrainer();
|
|
|
|
// Add the space character to the shape_table_ if not already present to
|
|
|
|
// count junk.
|
2021-03-13 05:06:34 +08:00
|
|
|
if (shape_table_->FindShape(0, -1) < 0)
|
|
|
|
shape_table_->AddShape(0, 0);
|
2018-08-24 21:07:48 +08:00
|
|
|
// Make a mock classifier.
|
2021-01-23 21:54:38 +08:00
|
|
|
auto shape_classifier = std::make_unique<MockClassifier>(shape_table_);
|
2018-08-24 21:07:48 +08:00
|
|
|
// Get the accuracy report.
|
|
|
|
STRING accuracy_report;
|
2021-03-13 05:06:34 +08:00
|
|
|
master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0, false,
|
|
|
|
shape_classifier.get(), &accuracy_report);
|
2019-09-25 16:07:51 +08:00
|
|
|
LOG(INFO) << accuracy_report.c_str();
|
|
|
|
std::string result_string = accuracy_report.c_str();
|
2021-03-13 05:06:34 +08:00
|
|
|
std::vector<std::string> results = absl::StrSplit(result_string, '\t', absl::SkipEmpty());
|
2018-08-24 21:07:48 +08:00
|
|
|
EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
|
|
|
|
int result_values[tesseract::CT_SIZE];
|
|
|
|
for (int i = 0; i < tesseract::CT_SIZE; ++i) {
|
|
|
|
EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
|
|
|
|
}
|
|
|
|
// These tests are more-or-less immune to additions to the number of
|
|
|
|
// categories or changes in the training data.
|
|
|
|
int num_samples = master_trainer_->GetSamples()->num_raw_samples();
|
|
|
|
EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
|
|
|
|
EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
|
|
|
|
EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
|
|
|
|
EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
|
|
|
|
EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
|
|
|
|
EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
|
|
|
|
// Each of the TOPTOP errs also counts as a multi-unichar.
|
2021-03-13 05:06:34 +08:00
|
|
|
EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs, result_values[tesseract::CT_OK_MULTI_UNICHAR]);
|
2018-08-24 21:07:48 +08:00
|
|
|
EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
|
|
|
|
EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
|
2019-05-17 00:07:32 +08:00
|
|
|
#endif
|
2018-08-24 21:07:48 +08:00
|
|
|
}
|