mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2024-11-27 20:59:36 +08:00
[training] More unique ptrs.
This commit is contained in:
parent
4415209fd6
commit
6e94564152
@ -110,8 +110,7 @@ int main(int argc, char **argv) {
|
||||
tesseract::CheckSharedLibraryVersion();
|
||||
ParseArguments(&argc, &argv);
|
||||
STRING file_prefix;
|
||||
auto trainer =
|
||||
tesseract::LoadTrainingData(argc, argv, false, nullptr, &file_prefix);
|
||||
auto [trainer,_] = tesseract::LoadTrainingData(argc, argv, false, false, &file_prefix);
|
||||
tesseract::TessBaseAPI* api;
|
||||
// Decode the classifier string.
|
||||
tesseract::ShapeClassifier* shape_classifier = InitializeClassifier(
|
||||
|
@ -149,16 +149,14 @@ void ParseArguments(int* argc, char ***argv) {
|
||||
|
||||
namespace tesseract {
|
||||
// Helper loads shape table from the given file.
|
||||
ShapeTable* LoadShapeTable(const STRING& file_prefix) {
|
||||
ShapeTable* shape_table = nullptr;
|
||||
std::unique_ptr<ShapeTable> LoadShapeTable(const STRING& file_prefix) {
|
||||
std::unique_ptr<ShapeTable> shape_table;
|
||||
STRING shape_table_file = file_prefix;
|
||||
shape_table_file += kShapeTableFileSuffix;
|
||||
TFile shape_fp;
|
||||
if (shape_fp.Open(shape_table_file.c_str(), nullptr)) {
|
||||
shape_table = new ShapeTable;
|
||||
shape_table = std::make_unique<ShapeTable>();
|
||||
if (!shape_table->DeSerialize(&shape_fp)) {
|
||||
delete shape_table;
|
||||
shape_table = nullptr;
|
||||
tprintf("Error: Failed to read shape table %s\n",
|
||||
shape_table_file.c_str());
|
||||
} else {
|
||||
@ -206,9 +204,10 @@ void WriteShapeTable(const STRING& file_prefix, const ShapeTable& shape_table) {
|
||||
* If shape_table is not nullptr, but failed to load, make a fake flat one,
|
||||
* as shape clustering was not run.
|
||||
*/
|
||||
std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * argv,
|
||||
std::pair<std::unique_ptr<MasterTrainer>, std::unique_ptr<ShapeTable>>
|
||||
LoadTrainingData(int argc, const char* const * argv,
|
||||
bool replication,
|
||||
ShapeTable** shape_table,
|
||||
bool shape_analysis,
|
||||
STRING* file_prefix) {
|
||||
InitFeatureDefs(&feature_defs);
|
||||
InitIntegerFX();
|
||||
@ -221,12 +220,9 @@ std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * ar
|
||||
// a shape_table written by a previous shape clustering, then
|
||||
// shape_analysis will be true, meaning that the MasterTrainer will replace
|
||||
// some members of the unicharset with their fragments.
|
||||
bool shape_analysis = false;
|
||||
if (shape_table != nullptr) {
|
||||
*shape_table = LoadShapeTable(*file_prefix);
|
||||
if (*shape_table != nullptr) shape_analysis = true;
|
||||
} else {
|
||||
shape_analysis = true;
|
||||
std::unique_ptr<ShapeTable> shape_table;
|
||||
if (shape_analysis) {
|
||||
shape_table = LoadShapeTable(*file_prefix);
|
||||
}
|
||||
auto trainer = std::make_unique<MasterTrainer>(NM_CHAR_ANISOTROPIC,
|
||||
shape_analysis,
|
||||
@ -289,18 +285,18 @@ std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * ar
|
||||
fprintf(stderr, "Failed to save unicharset to file %s\n", FLAGS_O.c_str());
|
||||
return {};
|
||||
}
|
||||
if (shape_table != nullptr) {
|
||||
if (shape_analysis) {
|
||||
// If we previously failed to load a shapetable, then shape clustering
|
||||
// wasn't run so make a flat one now.
|
||||
if (*shape_table == nullptr) {
|
||||
*shape_table = new ShapeTable;
|
||||
trainer->SetupFlatShapeTable(*shape_table);
|
||||
if (!shape_table) {
|
||||
shape_table = std::make_unique<ShapeTable>();
|
||||
trainer->SetupFlatShapeTable(shape_table.get());
|
||||
tprintf("Flat shape table summary: %s\n",
|
||||
(*shape_table)->SummaryStr().c_str());
|
||||
shape_table->SummaryStr().c_str());
|
||||
}
|
||||
(*shape_table)->set_unicharset(trainer->unicharset());
|
||||
shape_table->set_unicharset(trainer->unicharset());
|
||||
}
|
||||
return trainer;
|
||||
return { std::move(trainer), std::move(shape_table) };
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
@ -100,7 +100,7 @@ using MERGE_CLASS = MERGE_CLASS_NODE*;
|
||||
namespace tesseract {
|
||||
|
||||
// Helper loads shape table from the given file.
|
||||
ShapeTable* LoadShapeTable(const STRING& file_prefix);
|
||||
std::unique_ptr<ShapeTable> LoadShapeTable(const STRING& file_prefix);
|
||||
// Helper to write the shape_table.
|
||||
TESS_COMMON_TRAINING_API
|
||||
void WriteShapeTable(const STRING& file_prefix, const ShapeTable& shape_table);
|
||||
@ -119,9 +119,10 @@ void WriteShapeTable(const STRING& file_prefix, const ShapeTable& shape_table);
|
||||
// If shape_table is not nullptr, but failed to load, make a fake flat one,
|
||||
// as shape clustering was not run.
|
||||
TESS_COMMON_TRAINING_API
|
||||
std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char* const * argv,
|
||||
std::pair<std::unique_ptr<MasterTrainer>, std::unique_ptr<ShapeTable>>
|
||||
LoadTrainingData(int argc, const char* const * argv,
|
||||
bool replication,
|
||||
ShapeTable** shape_table,
|
||||
bool shape_analysis,
|
||||
STRING* file_prefix);
|
||||
|
||||
} // namespace tesseract.
|
||||
|
@ -202,12 +202,11 @@ int main (int argc, char **argv) {
|
||||
|
||||
ParseArguments(&argc, &argv);
|
||||
|
||||
ShapeTable* shape_table = nullptr;
|
||||
STRING file_prefix;
|
||||
// Load the training data.
|
||||
auto trainer = tesseract::LoadTrainingData(argc, argv,
|
||||
auto [trainer,shape_table] = tesseract::LoadTrainingData(argc, argv,
|
||||
false,
|
||||
&shape_table,
|
||||
true,
|
||||
&file_prefix);
|
||||
if (trainer == nullptr) return 1; // Failed.
|
||||
|
||||
@ -216,7 +215,7 @@ int main (int argc, char **argv) {
|
||||
// with the same list of unichars becomes a different class and the configs
|
||||
// represent the different combinations of fonts.
|
||||
IndexMapBiDi config_map;
|
||||
SetupConfigMap(shape_table, &config_map);
|
||||
SetupConfigMap(shape_table.get(), &config_map);
|
||||
|
||||
WriteShapeTable(file_prefix, *shape_table);
|
||||
// If the shape_table is flat, then either we didn't run shape clustering, or
|
||||
@ -270,7 +269,6 @@ int main (int argc, char **argv) {
|
||||
}
|
||||
delete [] float_classes;
|
||||
FreeLabeledClassList(mf_classes);
|
||||
delete shape_table;
|
||||
printf("Done!\n");
|
||||
if (!FLAGS_test_ch.empty()) {
|
||||
// If we are displaying debug window(s), wait for the user to look at them.
|
||||
|
@ -49,9 +49,7 @@ int main(int argc, char **argv) {
|
||||
ParseArguments(&argc, &argv);
|
||||
|
||||
STRING file_prefix;
|
||||
auto trainer =
|
||||
tesseract::LoadTrainingData(argc, argv, false, nullptr, &file_prefix);
|
||||
|
||||
auto [trainer,_] = tesseract::LoadTrainingData(argc, argv, false, false, &file_prefix);
|
||||
if (!trainer)
|
||||
return 1;
|
||||
|
||||
|
@ -126,7 +126,7 @@ class MockClassifier : public ShapeClassifier {
|
||||
return results->size();
|
||||
}
|
||||
// Provides access to the ShapeTable that this classifier works with.
|
||||
virtual const ShapeTable* GetShapeTable() const { return shape_table_; }
|
||||
const ShapeTable* GetShapeTable() const override { return shape_table_; }
|
||||
|
||||
private:
|
||||
// Borrowed pointer to the ShapeTable.
|
||||
@ -159,15 +159,6 @@ class MasterTrainerTest : public testing::Test {
|
||||
return file::JoinPath(FLAGS_test_tmpdir, name);
|
||||
}
|
||||
|
||||
MasterTrainerTest() {
|
||||
shape_table_ = nullptr;
|
||||
master_trainer_ = nullptr;
|
||||
}
|
||||
~MasterTrainerTest() {
|
||||
delete master_trainer_;
|
||||
delete shape_table_;
|
||||
}
|
||||
|
||||
// Initializes the master_trainer_ and shape_table_.
|
||||
// if load_from_tmp, then reloads a master trainer that was saved by a
|
||||
// previous call in which it was false.
|
||||
@ -180,11 +171,9 @@ class MasterTrainerTest : public testing::Test {
|
||||
const char* argv[] = {tr_file_name.c_str()};
|
||||
int argc = 1;
|
||||
STRING file_prefix;
|
||||
delete master_trainer_;
|
||||
delete shape_table_;
|
||||
shape_table_ = nullptr;
|
||||
master_trainer_ =
|
||||
LoadTrainingData(argc, argv, false, &shape_table_, &file_prefix);
|
||||
auto [m,s] = LoadTrainingData(argc, argv, false, true, &file_prefix);
|
||||
master_trainer_ = std::move(m);
|
||||
shape_table_ = std::move(s);
|
||||
EXPECT_TRUE(master_trainer_ != nullptr);
|
||||
EXPECT_TRUE(shape_table_ != nullptr);
|
||||
}
|
||||
@ -237,8 +226,8 @@ class MasterTrainerTest : public testing::Test {
|
||||
}
|
||||
|
||||
// Objects declared here can be used by all tests in the test case for Foo.
|
||||
ShapeTable* shape_table_;
|
||||
MasterTrainer* master_trainer_;
|
||||
std::unique_ptr<ShapeTable> shape_table_;
|
||||
std::unique_ptr<MasterTrainer> master_trainer_;
|
||||
#endif
|
||||
};
|
||||
|
||||
@ -268,12 +257,11 @@ TEST_F(MasterTrainerTest, ErrorCounterTest) {
|
||||
// count junk.
|
||||
if (shape_table_->FindShape(0, -1) < 0) shape_table_->AddShape(0, 0);
|
||||
// Make a mock classifier.
|
||||
tesseract::ShapeClassifier* shape_classifier =
|
||||
new tesseract::MockClassifier(shape_table_);
|
||||
auto shape_classifier = std::make_unique<MockClassifier>(shape_table_.get());
|
||||
// Get the accuracy report.
|
||||
STRING accuracy_report;
|
||||
master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0,
|
||||
false, shape_classifier,
|
||||
false, shape_classifier.get(),
|
||||
&accuracy_report);
|
||||
LOG(INFO) << accuracy_report.c_str();
|
||||
std::string result_string = accuracy_report.c_str();
|
||||
@ -298,7 +286,5 @@ TEST_F(MasterTrainerTest, ErrorCounterTest) {
|
||||
result_values[tesseract::CT_OK_MULTI_UNICHAR]);
|
||||
EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
|
||||
EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
|
||||
|
||||
delete shape_classifier;
|
||||
#endif
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user