[training] More unique ptrs.

This commit is contained in:
Egor Pugin 2021-01-05 17:03:26 +03:00
parent 4415209fd6
commit 6e94564152
6 changed files with 33 additions and 55 deletions

View File

@ -110,8 +110,7 @@ int main(int argc, char **argv) {
tesseract::CheckSharedLibraryVersion();
ParseArguments(&argc, &argv);
STRING file_prefix;
auto trainer =
tesseract::LoadTrainingData(argc, argv, false, nullptr, &file_prefix);
auto [trainer,_] = tesseract::LoadTrainingData(argc, argv, false, false, &file_prefix);
tesseract::TessBaseAPI* api;
// Decode the classifier string.
tesseract::ShapeClassifier* shape_classifier = InitializeClassifier(

View File

@ -149,16 +149,14 @@ void ParseArguments(int* argc, char ***argv) {
namespace tesseract {
// Helper loads shape table from the given file.
ShapeTable* LoadShapeTable(const STRING& file_prefix) {
ShapeTable* shape_table = nullptr;
std::unique_ptr<ShapeTable> LoadShapeTable(const STRING& file_prefix) {
std::unique_ptr<ShapeTable> shape_table;
STRING shape_table_file = file_prefix;
shape_table_file += kShapeTableFileSuffix;
TFile shape_fp;
if (shape_fp.Open(shape_table_file.c_str(), nullptr)) {
shape_table = new ShapeTable;
shape_table = std::make_unique<ShapeTable>();
if (!shape_table->DeSerialize(&shape_fp)) {
delete shape_table;
shape_table = nullptr;
tprintf("Error: Failed to read shape table %s\n",
shape_table_file.c_str());
} else {
@ -206,9 +204,10 @@ void WriteShapeTable(const STRING& file_prefix, const ShapeTable& shape_table) {
* If shape_analysis is true but the shape table failed to load, make a fake
* flat one, as shape clustering was not run.
*/
std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * argv,
std::pair<std::unique_ptr<MasterTrainer>, std::unique_ptr<ShapeTable>>
LoadTrainingData(int argc, const char* const * argv,
bool replication,
ShapeTable** shape_table,
bool shape_analysis,
STRING* file_prefix) {
InitFeatureDefs(&feature_defs);
InitIntegerFX();
@ -221,12 +220,9 @@ std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * ar
// a shape_table written by a previous shape clustering, then
// shape_analysis will be true, meaning that the MasterTrainer will replace
// some members of the unicharset with their fragments.
bool shape_analysis = false;
if (shape_table != nullptr) {
*shape_table = LoadShapeTable(*file_prefix);
if (*shape_table != nullptr) shape_analysis = true;
} else {
shape_analysis = true;
std::unique_ptr<ShapeTable> shape_table;
if (shape_analysis) {
shape_table = LoadShapeTable(*file_prefix);
}
auto trainer = std::make_unique<MasterTrainer>(NM_CHAR_ANISOTROPIC,
shape_analysis,
@ -289,18 +285,18 @@ std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * ar
fprintf(stderr, "Failed to save unicharset to file %s\n", FLAGS_O.c_str());
return {};
}
if (shape_table != nullptr) {
if (shape_analysis) {
// If we previously failed to load a shapetable, then shape clustering
// wasn't run so make a flat one now.
if (*shape_table == nullptr) {
*shape_table = new ShapeTable;
trainer->SetupFlatShapeTable(*shape_table);
if (!shape_table) {
shape_table = std::make_unique<ShapeTable>();
trainer->SetupFlatShapeTable(shape_table.get());
tprintf("Flat shape table summary: %s\n",
(*shape_table)->SummaryStr().c_str());
shape_table->SummaryStr().c_str());
}
(*shape_table)->set_unicharset(trainer->unicharset());
shape_table->set_unicharset(trainer->unicharset());
}
return trainer;
return { std::move(trainer), std::move(shape_table) };
}
} // namespace tesseract.

View File

@ -100,7 +100,7 @@ using MERGE_CLASS = MERGE_CLASS_NODE*;
namespace tesseract {
// Helper loads shape table from the given file.
ShapeTable* LoadShapeTable(const STRING& file_prefix);
std::unique_ptr<ShapeTable> LoadShapeTable(const STRING& file_prefix);
// Helper to write the shape_table.
TESS_COMMON_TRAINING_API
void WriteShapeTable(const STRING& file_prefix, const ShapeTable& shape_table);
@ -119,9 +119,10 @@ void WriteShapeTable(const STRING& file_prefix, const ShapeTable& shape_table);
// If shape_analysis is true but the shape table failed to load, make a fake
// flat one, as shape clustering was not run.
TESS_COMMON_TRAINING_API
std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * argv,
std::pair<std::unique_ptr<MasterTrainer>, std::unique_ptr<ShapeTable>>
LoadTrainingData(int argc, const char* const * argv,
bool replication,
ShapeTable** shape_table,
bool shape_analysis,
STRING* file_prefix);
} // namespace tesseract.

View File

@ -202,12 +202,11 @@ int main (int argc, char **argv) {
ParseArguments(&argc, &argv);
ShapeTable* shape_table = nullptr;
STRING file_prefix;
// Load the training data.
auto trainer = tesseract::LoadTrainingData(argc, argv,
auto [trainer,shape_table] = tesseract::LoadTrainingData(argc, argv,
false,
&shape_table,
true,
&file_prefix);
if (trainer == nullptr) return 1; // Failed.
@ -216,7 +215,7 @@ int main (int argc, char **argv) {
// with the same list of unichars becomes a different class and the configs
// represent the different combinations of fonts.
IndexMapBiDi config_map;
SetupConfigMap(shape_table, &config_map);
SetupConfigMap(shape_table.get(), &config_map);
WriteShapeTable(file_prefix, *shape_table);
// If the shape_table is flat, then either we didn't run shape clustering, or
@ -270,7 +269,6 @@ int main (int argc, char **argv) {
}
delete [] float_classes;
FreeLabeledClassList(mf_classes);
delete shape_table;
printf("Done!\n");
if (!FLAGS_test_ch.empty()) {
// If we are displaying debug window(s), wait for the user to look at them.

View File

@ -49,9 +49,7 @@ int main(int argc, char **argv) {
ParseArguments(&argc, &argv);
STRING file_prefix;
auto trainer =
tesseract::LoadTrainingData(argc, argv, false, nullptr, &file_prefix);
auto [trainer,_] = tesseract::LoadTrainingData(argc, argv, false, false, &file_prefix);
if (!trainer)
return 1;

View File

@ -126,7 +126,7 @@ class MockClassifier : public ShapeClassifier {
return results->size();
}
// Provides access to the ShapeTable that this classifier works with.
virtual const ShapeTable* GetShapeTable() const { return shape_table_; }
const ShapeTable* GetShapeTable() const override { return shape_table_; }
private:
// Borrowed pointer to the ShapeTable.
@ -159,15 +159,6 @@ class MasterTrainerTest : public testing::Test {
return file::JoinPath(FLAGS_test_tmpdir, name);
}
MasterTrainerTest() {
shape_table_ = nullptr;
master_trainer_ = nullptr;
}
~MasterTrainerTest() {
delete master_trainer_;
delete shape_table_;
}
// Initializes the master_trainer_ and shape_table_.
// if load_from_tmp, then reloads a master trainer that was saved by a
// previous call in which it was false.
@ -180,11 +171,9 @@ class MasterTrainerTest : public testing::Test {
const char* argv[] = {tr_file_name.c_str()};
int argc = 1;
STRING file_prefix;
delete master_trainer_;
delete shape_table_;
shape_table_ = nullptr;
master_trainer_ =
LoadTrainingData(argc, argv, false, &shape_table_, &file_prefix);
auto [m,s] = LoadTrainingData(argc, argv, false, true, &file_prefix);
master_trainer_ = std::move(m);
shape_table_ = std::move(s);
EXPECT_TRUE(master_trainer_ != nullptr);
EXPECT_TRUE(shape_table_ != nullptr);
}
@ -237,8 +226,8 @@ class MasterTrainerTest : public testing::Test {
}
// Objects declared here can be used by all tests in the MasterTrainerTest fixture.
ShapeTable* shape_table_;
MasterTrainer* master_trainer_;
std::unique_ptr<ShapeTable> shape_table_;
std::unique_ptr<MasterTrainer> master_trainer_;
#endif
};
@ -268,12 +257,11 @@ TEST_F(MasterTrainerTest, ErrorCounterTest) {
// count junk.
if (shape_table_->FindShape(0, -1) < 0) shape_table_->AddShape(0, 0);
// Make a mock classifier.
tesseract::ShapeClassifier* shape_classifier =
new tesseract::MockClassifier(shape_table_);
auto shape_classifier = std::make_unique<MockClassifier>(shape_table_.get());
// Get the accuracy report.
STRING accuracy_report;
master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0,
false, shape_classifier,
false, shape_classifier.get(),
&accuracy_report);
LOG(INFO) << accuracy_report.c_str();
std::string result_string = accuracy_report.c_str();
@ -298,7 +286,5 @@ TEST_F(MasterTrainerTest, ErrorCounterTest) {
result_values[tesseract::CT_OK_MULTI_UNICHAR]);
EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
delete shape_classifier;
#endif
}