Mirror of https://github.com/tesseract-ocr/tesseract.git, synced 2024-11-24 02:59:07 +08:00

Commit 2633fef0b6 (parent 61adbdfa4b)
Part 2 of separating out the unicharset from the LSTM model, fixing the command line for training.
dict/dawg.cpp

@@ -339,16 +339,15 @@ bool SquishedDawg::read_squished_dawg(TFile *file) {
   return true;
 }
 
-NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const {
+std::unique_ptr<EDGE_REF[]> SquishedDawg::build_node_map(
+    inT32 *num_nodes) const {
   EDGE_REF edge;
-  NODE_MAP node_map;
+  std::unique_ptr<EDGE_REF[]> node_map(new EDGE_REF[num_edges_]);
   inT32 node_counter;
   inT32 num_edges;
 
-  node_map = (NODE_MAP) malloc(sizeof(EDGE_REF) * num_edges_);
-
   for (edge = 0; edge < num_edges_; edge++)  // init all slots
-    node_map [edge] = -1;
+    node_map[edge] = -1;
 
   node_counter = num_forward_edges(0);

@@ -366,25 +365,25 @@ NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const {
       edge--;
     }
   }
-  return (node_map);
+  return node_map;
 }
 
-void SquishedDawg::write_squished_dawg(FILE *file) {
+bool SquishedDawg::write_squished_dawg(TFile *file) {
   EDGE_REF edge;
   inT32 num_edges;
   inT32 node_count = 0;
-  NODE_MAP node_map;
   EDGE_REF old_index;
   EDGE_RECORD temp_record;
 
   if (debug_level_) tprintf("write_squished_dawg\n");
 
-  node_map = build_node_map(&node_count);
+  std::unique_ptr<EDGE_REF[]> node_map(build_node_map(&node_count));
 
   // Write the magic number to help detecting a change in endianness.
   inT16 magic = kDawgMagicNumber;
-  fwrite(&magic, sizeof(inT16), 1, file);
-  fwrite(&unicharset_size_, sizeof(inT32), 1, file);
+  if (file->FWrite(&magic, sizeof(magic), 1) != 1) return false;
+  if (file->FWrite(&unicharset_size_, sizeof(unicharset_size_), 1) != 1)
+    return false;
 
   // Count the number of edges in this Dawg.
   num_edges = 0;

@@ -392,7 +391,8 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
     if (forward_edge(edge))
      num_edges++;
 
-  fwrite(&num_edges, sizeof(inT32), 1, file);  // write edge count to file
+  // Write edge count to file.
+  if (file->FWrite(&num_edges, sizeof(num_edges), 1) != 1) return false;
 
   if (debug_level_) {
     tprintf("%d nodes in DAWG\n", node_count);

@@ -405,7 +405,8 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
       old_index = next_node_from_edge_rec(edges_[edge]);
       set_next_node(edge, node_map[old_index]);
       temp_record = edges_[edge];
-      fwrite(&(temp_record), sizeof(EDGE_RECORD), 1, file);
+      if (file->FWrite(&temp_record, sizeof(temp_record), 1) != 1)
+        return false;
       set_next_node(edge, old_index);
     } while (!last_edge(edge++));
 

@@ -416,7 +417,7 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
       edge--;
     }
   }
-  free(node_map);
+  return true;
 }
 
 }  // namespace tesseract
dict/dawg.h (26 changed lines)

@@ -31,9 +31,10 @@
               I n c l u d e s
----------------------------------------------------------------------*/
 
+#include <memory>
 #include "elst.h"
-#include "ratngs.h"
 #include "params.h"
+#include "ratngs.h"
 #include "tesscallback.h"
 
 #ifndef __GNUC__

@@ -483,18 +484,22 @@ class SquishedDawg : public Dawg {
   void print_node(NODE_REF node, int max_num_edges) const;
 
   /// Writes the squished/reduced Dawg to a file.
-  void write_squished_dawg(FILE *file);
+  bool write_squished_dawg(TFile *file);
 
   /// Opens the file with the given filename and writes the
   /// squished/reduced Dawg to the file.
-  void write_squished_dawg(const char *filename) {
-    FILE *file = fopen(filename, "wb");
-    if (file == NULL) {
-      tprintf("Error opening %s\n", filename);
-      exit(1);
+  bool write_squished_dawg(const char *filename) {
+    TFile file;
+    file.OpenWrite(nullptr);
+    if (!this->write_squished_dawg(&file)) {
+      tprintf("Error serializing %s\n", filename);
+      return false;
     }
-    this->write_squished_dawg(file);
-    fclose(file);
+    if (!file.CloseWrite(filename, nullptr)) {
+      tprintf("Error writing file %s\n", filename);
+      return false;
+    }
+    return true;
   }
 
  private:

@@ -549,8 +554,7 @@ class SquishedDawg : public Dawg {
     tprintf("__________________________\n");
   }
   /// Constructs a mapping from the memory node indices to disk node indices.
-  NODE_MAP build_node_map(inT32 *num_nodes) const;
-
+  std::unique_ptr<EDGE_REF[]> build_node_map(inT32 *num_nodes) const;
 
   // Member variables.
   EDGE_ARRAY edges_;
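For orientation: the FILE*-based convenience overload above becomes a TFile-backed one that can serialize to memory or disk. A minimal sketch of the resulting flow, assuming only the TFile API visible in this diff (OpenWrite(nullptr) buffers writes in memory, CloseWrite flushes them to the named file); SaveDawg is a hypothetical helper, essentially what the new filename overload already does:

    // Hypothetical free function; not part of the commit.
    bool SaveDawg(tesseract::SquishedDawg* dawg, const char* filename) {
      tesseract::TFile file;
      file.OpenWrite(nullptr);                    // buffer the bytes in memory
      if (!dawg->write_squished_dawg(&file)) return false;  // serialize
      return file.CloseWrite(filename, nullptr);  // flush the buffer to disk
    }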
dict/trie.cpp

@@ -290,40 +290,27 @@ bool Trie::read_and_add_word_list(const char *filename,
                                   const UNICHARSET &unicharset,
                                   Trie::RTLReversePolicy reverse_policy) {
   GenericVector<STRING> word_list;
-  if (!read_word_list(filename, unicharset, reverse_policy, &word_list))
-    return false;
+  if (!read_word_list(filename, &word_list)) return false;
   word_list.sort(sort_strings_by_dec_length);
-  return add_word_list(word_list, unicharset);
+  return add_word_list(word_list, unicharset, reverse_policy);
 }
 
 bool Trie::read_word_list(const char *filename,
-                          const UNICHARSET &unicharset,
-                          Trie::RTLReversePolicy reverse_policy,
                           GenericVector<STRING>* words) {
   FILE *word_file;
-  char string[CHARS_PER_LINE];
+  char line_str[CHARS_PER_LINE];
   int word_count = 0;
 
   word_file = fopen(filename, "rb");
   if (word_file == NULL) return false;
 
-  while (fgets(string, CHARS_PER_LINE, word_file) != NULL) {
-    chomp_string(string);  // remove newline
-    WERD_CHOICE word(string, unicharset);
-    if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL &&
-         word.has_rtl_unichar_id()) ||
-        reverse_policy == RRP_FORCE_REVERSE) {
-      word.reverse_and_mirror_unichar_ids();
-    }
+  while (fgets(line_str, sizeof(line_str), word_file) != NULL) {
+    chomp_string(line_str);  // remove newline
+    STRING word_str(line_str);
     ++word_count;
     if (debug_level_ && word_count % 10000 == 0)
       tprintf("Read %d words so far\n", word_count);
-    if (word.length() != 0 && !word.contains_unichar_id(INVALID_UNICHAR_ID)) {
-      words->push_back(word.unichar_string());
-    } else if (debug_level_) {
-      tprintf("Skipping invalid word %s\n", string);
-      if (debug_level_ >= 3) word.print();
-    }
+    words->push_back(word_str);
   }
   if (debug_level_)
     tprintf("Read %d words total.\n", word_count);

@@ -331,10 +318,18 @@ bool Trie::read_word_list(const char *filename,
   return true;
 }
 
-bool Trie::add_word_list(const GenericVector<STRING>& words,
-                         const UNICHARSET &unicharset) {
+bool Trie::add_word_list(const GenericVector<STRING> &words,
+                         const UNICHARSET &unicharset,
+                         Trie::RTLReversePolicy reverse_policy) {
   for (int i = 0; i < words.size(); ++i) {
     WERD_CHOICE word(words[i].string(), unicharset);
+    if (word.length() == 0 || word.contains_unichar_id(INVALID_UNICHAR_ID))
+      continue;
+    if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL &&
+         word.has_rtl_unichar_id()) ||
+        reverse_policy == RRP_FORCE_REVERSE) {
+      word.reverse_and_mirror_unichar_ids();
+    }
     if (!word_in_dawg(word)) {
       add_word_to_dawg(word);
dict/trie.h (12 changed lines)

@@ -177,18 +177,16 @@ class Trie : public Dawg {
                              const UNICHARSET &unicharset,
                              Trie::RTLReversePolicy reverse);
 
-  // Reads a list of words from the given file, applying the reverse_policy,
-  // according to information in the unicharset.
+  // Reads a list of words from the given file.
   // Returns false on error.
   bool read_word_list(const char *filename,
-                      const UNICHARSET &unicharset,
-                      Trie::RTLReversePolicy reverse_policy,
                       GenericVector<STRING>* words);
   // Adds a list of words previously read using read_word_list to the trie
-  // using the given unicharset to convert to unichar-ids.
+  // using the given unicharset and reverse_policy to convert to unichar-ids.
   // Returns false on error.
-  bool add_word_list(const GenericVector<STRING>& words,
-                     const UNICHARSET &unicharset);
+  bool add_word_list(const GenericVector<STRING> &words,
+                     const UNICHARSET &unicharset,
+                     Trie::RTLReversePolicy reverse_policy);
 
   // Inserts the list of patterns from the given file into the Trie.
   // The pattern list file should contain one pattern per line in UTF-8 format.
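Net effect of the Trie changes: read_word_list now returns raw UTF-8 strings with no unicharset dependency, and encoding plus RTL reversal happen when the words are added. A sketch of the resulting call pattern (file name and policy are illustrative):

    GenericVector<STRING> words;
    if (trie.read_word_list("eng.wordlist", &words)) {  // raw UTF-8 lines
      words.sort(sort_strings_by_dec_length);           // longest words first
      trie.add_word_list(words, unicharset,             // encode + reverse here
                         tesseract::Trie::RRP_REVERSE_IF_HAS_RTL);
    }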
lstm/lstmtrainer.cpp

@@ -130,22 +130,6 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
   return checkpoint_reader_->Run(data, this);
 }
 
-// Initializes the character set encode/decode mechanism.
-// train_flags control training behavior according to the TrainingFlags
-// enum, including character set encoding.
-// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
-// fully initializes the unicharset from the universal unicharsets.
-// Note: Call before InitNetwork!
-void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
-                              const STRING& script_dir, int train_flags) {
-  EmptyConstructor();
-  training_flags_ = train_flags;
-  ccutil_.unicharset.CopyFrom(unicharset);
-  null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
-                                                   : GetUnicharset().size();
-  SetUnicharsetProperties(script_dir);
-}
-
 // Initializes the trainer with a network_spec in the network description
 // net_flags control network behavior according to the NetworkFlags enum.
 // There isn't really much difference between them - only where the effects

@@ -278,9 +262,10 @@ void LSTMTrainer::DebugNetwork() {
 // Loads a set of lstmf files that were created using the lstm.train config to
 // tesseract into memory ready for training. Returns false if nothing was
 // loaded.
-bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames) {
+bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
+                                      CachingStrategy cache_strategy) {
   training_data_.Clear();
-  return training_data_.LoadDocuments(filenames, CacheStrategy(), file_reader_);
+  return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
 }
 
 // Keeps track of best and locally worst char error_rate and launches tests

@@ -908,6 +893,15 @@ bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager* mgr,
   return DeSerialize(mgr, &fp);
 }
 
+// Writes the full recognition traineddata to the given filename.
+bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
+  GenericVector<char> recognizer_data;
+  SaveRecognitionDump(&recognizer_data);
+  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
+                      recognizer_data.size());
+  return mgr_.SaveFile(filename, file_writer_);
+}
+
 // Writes the recognizer to memory, so that it can be used for testing later.
 void LSTMTrainer::SaveRecognitionDump(GenericVector<char>* data) const {
   TFile fp;

@@ -964,52 +958,6 @@ void LSTMTrainer::EmptyConstructor() {
   InitIterations();
 }
 
-// Sets the unicharset properties using the given script_dir as a source of
-// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
-// up the recoder_ to simplify the unicharset.
-void LSTMTrainer::SetUnicharsetProperties(const STRING& script_dir) {
-  tprintf("Setting unichar properties\n");
-  for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) {
-    if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0)
-      continue;
-    // Load the unicharset for the script if available.
-    STRING filename = script_dir + "/" +
-                      GetUnicharset().get_script_from_script_id(s) +
-                      ".unicharset";
-    UNICHARSET script_set;
-    GenericVector<char> data;
-    if ((*file_reader_)(filename, &data) &&
-        script_set.load_from_inmemory_file(&data[0], data.size())) {
-      tprintf("Setting properties for script %s\n",
-              GetUnicharset().get_script_from_script_id(s));
-      ccutil_.unicharset.SetPropertiesFromOther(script_set);
-    }
-  }
-  if (IsRecoding()) {
-    STRING filename = script_dir + "/radical-stroke.txt";
-    GenericVector<char> data;
-    if ((*file_reader_)(filename, &data)) {
-      data += '\0';
-      STRING stroke_table = &data[0];
-      if (recoder_.ComputeEncoding(GetUnicharset(), null_char_,
-                                   &stroke_table)) {
-        RecodedCharID code;
-        recoder_.EncodeUnichar(null_char_, &code);
-        null_char_ = code(0);
-        // Space should encode as itself.
-        recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
-        ASSERT_HOST(code(0) == UNICHAR_SPACE);
-        return;
-      }
-    } else {
-      tprintf("Failed to load radical-stroke info from: %s\n",
-              filename.string());
-    }
-  }
-  training_flags_ |= TF_COMPRESS_UNICHARSET;
-  recoder_.SetupPassThrough(GetUnicharset());
-}
-
 // Outputs the string and periodically displays the given network inputs
 // as an image in the given window, and the corresponding labels at the
 // corresponding x_starts.
lstm/lstmtrainer.h

@@ -101,14 +101,6 @@ class LSTMTrainer : public LSTMRecognizer {
   // false in case of failure.
   bool TryLoadingCheckpoint(const char* filename);
 
-  // Initializes the character set encode/decode mechanism.
-  // train_flags control training behavior according to the TrainingFlags
-  // enum, including character set encoding.
-  // script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
-  // fully initializes the unicharset from the universal unicharsets.
-  // Note: Call before InitNetwork!
-  void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir,
-                   int train_flags);
   // Initializes the character set encode/decode mechanism directly from a
   // previously setup traineddata containing dawgs, UNICHARSET and
   // UnicharCompress. Note: Call before InitNetwork!

@@ -186,7 +178,8 @@ class LSTMTrainer : public LSTMRecognizer {
   // Loads a set of lstmf files that were created using the lstm.train config to
   // tesseract into memory ready for training. Returns false if nothing was
   // loaded.
-  bool LoadAllTrainingData(const GenericVector<STRING>& filenames);
+  bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
+                           CachingStrategy cache_strategy);
 
   // Keeps track of best and locally worst error rate, using internally computed
   // values. See MaintainCheckpointsSpecific for more detail.

@@ -315,12 +308,12 @@ class LSTMTrainer : public LSTMRecognizer {
   // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
   void SetupCheckpointInfo();
 
+  // Writes the full recognition traineddata to the given filename.
+  bool SaveTraineddata(const STRING& filename);
+
   // Writes the recognizer to memory, so that it can be used for testing later.
   void SaveRecognitionDump(GenericVector<char>* data) const;
 
-  // Writes current best model to a file, unless it has already been written.
-  bool SaveBestModel(FileWriter writer) const;
-
   // Returns a suitable filename for a training dump, based on the model_base_,
   // the iteration and the error rates.
   STRING DumpFilename() const;

@@ -336,11 +329,6 @@ class LSTMTrainer : public LSTMRecognizer {
   // Factored sub-constructor sets up reasonable default values.
   void EmptyConstructor();
 
-  // Sets the unicharset properties using the given script_dir as a source of
-  // script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
-  // up the recoder_ to simplify the unicharset.
-  void SetUnicharsetProperties(const STRING& script_dir);
-
   // Outputs the string and periodically displays the given network inputs
   // as an image in the given window, and the corresponding labels at the
   // corresponding x_starts.
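For callers of LSTMTrainer, the new shape of a training run is: character set from a starter traineddata, an explicit caching strategy for data loading, and a complete traineddata as output. A sketch under those assumptions (paths and flag values are illustrative; constructor arguments follow lstmtraining.cpp below):

    LSTMTrainer trainer(/* callbacks, output basename, checkpoint, etc. */);
    trainer.InitCharSet("eng/eng.traineddata");  // unicharset+recoder+dawgs
    trainer.InitNetwork(net_spec.c_str(), /*append_index*/ -1, net_mode,
                        /*weight_range*/ 0.1, learning_rate, momentum);
    trainer.LoadAllTrainingData(filenames, tesseract::CS_ROUND_ROBIN);
    // ... training iterations ...
    trainer.SaveTraineddata("eng_trained.traineddata");  // full model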
training/Makefile.am

@@ -19,8 +19,8 @@ endif
 
 noinst_HEADERS = \
     boxchar.h commandlineflags.h commontraining.h degradeimage.h \
-    fileio.h icuerrorcode.h ligature_table.h lstmtester.h normstrngs.h \
-    mergenf.h pango_font_info.h stringrenderer.h \
+    fileio.h icuerrorcode.h lang_model_helpers.h ligature_table.h \
+    lstmtester.h mergenf.h normstrngs.h pango_font_info.h stringrenderer.h \
     tessopt.h tlog.h unicharset_training_utils.h util.h \
     validate_grapheme.h validate_indic.h validate_khmer.h \
     validate_myanmar.h validator.h

@@ -33,15 +33,15 @@ libtesseract_training_la_LIBADD = \
 
 libtesseract_training_la_SOURCES = \
     boxchar.cpp commandlineflags.cpp commontraining.cpp degradeimage.cpp \
-    fileio.cpp ligature_table.cpp lstmtester.cpp normstrngs.cpp pango_font_info.cpp \
-    stringrenderer.cpp tlog.cpp unicharset_training_utils.cpp \
+    fileio.cpp lang_model_helpers.cpp ligature_table.cpp lstmtester.cpp \
+    normstrngs.cpp pango_font_info.cpp stringrenderer.cpp tlog.cpp unicharset_training_utils.cpp \
     validate_grapheme.cpp validate_indic.cpp validate_khmer.cpp \
     validate_myanmar.cpp validator.cpp
 
 libtesseract_tessopt_la_SOURCES = \
     tessopt.cpp
 
-bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \
+bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_lang_model combine_tessdata \
     dawg2wordlist lstmeval lstmtraining mftraining set_unicharset_properties shapeclustering \
     text2image unicharset_extractor wordlist2dawg

@@ -94,11 +94,26 @@ classifier_tester_LDADD += \
     ../api/libtesseract.la
 endif
 
+combine_lang_model_SOURCES = combine_lang_model.cpp
+#combine_lang_model_LDFLAGS = -static
+combine_lang_model_LDADD = \
+    libtesseract_training.la \
+    libtesseract_tessopt.la \
+    $(ICU_I18N_LIBS) $(ICU_UC_LIBS)
+if USING_MULTIPLELIBS
+combine_lang_model_LDADD += \
+    ../ccutil/libtesseract_ccutil.la
+else
+combine_lang_model_LDADD += \
+    ../api/libtesseract.la
+endif
+
 combine_tessdata_SOURCES = combine_tessdata.cpp
 #combine_tessdata_LDFLAGS = -static
 if USING_MULTIPLELIBS
 combine_tessdata_LDADD = \
-    ../ccutil/libtesseract_ccutil.la
+    ../ccutil/libtesseract_ccutil.la \
+    ../lstm/libtesseract_lstm.la
 else
 combine_tessdata_LDADD = \
     ../api/libtesseract.la
training/combine_lang_model.cpp (new file, 87 lines)

@@ -0,0 +1,87 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+// Author: rays@google.com (Ray Smith)
+// Purpose: Program to generate a traineddata file that can be used to train an
+//          LSTM-based neural network model from a unicharset and an optional
+//          set of wordlists. Eliminates the need to run
+//          set_unicharset_properties, wordlist2dawg, some non-existent binary
+//          to generate the recoder, and finally combine_tessdata.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "commandlineflags.h"
+#include "lang_model_helpers.h"
+#include "tprintf.h"
+#include "unicharset_training_utils.h"
+
+STRING_PARAM_FLAG(input_unicharset, "",
+                  "Unicharset to complete and use in encoding");
+STRING_PARAM_FLAG(script_dir, "",
+                  "Directory name for input script unicharsets");
+STRING_PARAM_FLAG(words, "",
+                  "File listing words to use for the system dictionary");
+STRING_PARAM_FLAG(puncs, "", "File listing punctuation patterns");
+STRING_PARAM_FLAG(numbers, "", "File listing number patterns");
+STRING_PARAM_FLAG(output_dir, "", "Root directory for output files");
+STRING_PARAM_FLAG(version_str, "", "Version string to add to traineddata file");
+STRING_PARAM_FLAG(lang, "", "Name of language being processed");
+BOOL_PARAM_FLAG(lang_is_rtl, false,
+                "True if lang being processed is written right-to-left");
+BOOL_PARAM_FLAG(pass_through_recoder, false,
+                "If true, the recoder is a simple pass-through of the"
+                " unicharset. Otherwise, potentially a compression of it");
+
+int main(int argc, char** argv) {
+  tesseract::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+
+  // Check validity of input flags.
+  if (FLAGS_input_unicharset.empty() || FLAGS_script_dir.empty() ||
+      FLAGS_output_dir.empty() || FLAGS_lang.empty()) {
+    tprintf("Usage: %s --input_unicharset filename --script_dir dirname\n",
+            argv[0]);
+    tprintf("  --output_dir rootdir --lang lang [--lang_is_rtl]\n");
+    tprintf("  [--words file --puncs file --numbers file]\n");
+    tprintf("Sets properties on the input unicharset file, and writes:\n");
+    tprintf("rootdir/lang/lang.charset_size=ddd.txt\n");
+    tprintf("rootdir/lang/lang.traineddata\n");
+    tprintf("rootdir/lang/lang.unicharset\n");
+    tprintf("If the 3 word lists are provided, the dawgs are also added to");
+    tprintf(" the traineddata file.\n");
+    tprintf("The output unicharset and charset_size files are just for human");
+    tprintf(" readability.\n");
+    exit(1);
+  }
+  GenericVector<STRING> words, puncs, numbers;
+  // If these reads fail, we get a warning message and an empty list of words.
+  tesseract::ReadFile(FLAGS_words.c_str(), nullptr).split('\n', &words);
+  tesseract::ReadFile(FLAGS_puncs.c_str(), nullptr).split('\n', &puncs);
+  tesseract::ReadFile(FLAGS_numbers.c_str(), nullptr).split('\n', &numbers);
+  // Load the input unicharset
+  UNICHARSET unicharset;
+  if (!unicharset.load_from_file(FLAGS_input_unicharset.c_str(), false)) {
+    tprintf("Failed to load unicharset from %s\n",
+            FLAGS_input_unicharset.c_str());
+    return 1;
+  }
+  tprintf("Loaded unicharset of size %d from file %s\n", unicharset.size(),
+          FLAGS_input_unicharset.c_str());
+
+  // Set unichar properties
+  tprintf("Setting unichar properties\n");
+  tesseract::SetupBasicProperties(/*report_errors*/ true,
+                                  /*decompose (NFD)*/ false, &unicharset);
+  tprintf("Setting script properties\n");
+  tesseract::SetScriptProperties(FLAGS_script_dir.c_str(), &unicharset);
+  // Combine everything into a traineddata file.
+  return tesseract::CombineLangModel(
+      unicharset, FLAGS_script_dir.c_str(), FLAGS_version_str.c_str(),
+      FLAGS_output_dir.c_str(), FLAGS_lang.c_str(), FLAGS_pass_through_recoder,
+      words, puncs, numbers, FLAGS_lang_is_rtl, /*reader*/ nullptr,
+      /*writer*/ nullptr);
+}
training/lang_model_helpers.cpp (new file, 231 lines)

@@ -0,0 +1,231 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+// Author: rays@google.com (Ray Smith)
+// Purpose: Collection of convenience functions to simplify creation of the
+//          unicharset, recoder, and dawgs for an LSTM model.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lang_model_helpers.h"
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <cstdlib>
+#include "dawg.h"
+#include "fileio.h"
+#include "tessdatamanager.h"
+#include "trie.h"
+#include "unicharcompress.h"
+
+namespace tesseract {
+
+// Helper makes a filename (<output_dir>/<lang>/<lang><suffix>) and writes data
+// to the file, using writer if not null, otherwise, a default writer.
+// Default writer will overwrite any existing file, but a supplied writer
+// can do its own thing. If lang is empty, returns true but does nothing.
+// NOTE that suffix should contain any required . for the filename.
+bool WriteFile(const string& output_dir, const string& lang,
+               const string& suffix, const GenericVector<char>& data,
+               FileWriter writer) {
+  if (lang.empty()) return true;
+  string dirname = output_dir + "/" + lang;
+  // Attempt to make the directory, but ignore errors, as it may not be a
+  // standard filesystem, and the writer will complain if not successful.
+  mkdir(dirname.c_str(), S_IRWXU | S_IRWXG);
+  string filename = dirname + "/" + lang + suffix;
+  if (writer == nullptr)
+    return SaveDataToFile(data, filename.c_str());
+  else
+    return (*writer)(data, filename.c_str());
+}
+
+// Helper reads a file with optional reader and returns a STRING.
+// On failure emits a warning message and returns an empty STRING.
+STRING ReadFile(const string& filename, FileReader reader) {
+  if (filename.empty()) return STRING();
+  GenericVector<char> data;
+  bool read_result;
+  if (reader == nullptr)
+    read_result = LoadDataFromFile(filename.c_str(), &data);
+  else
+    read_result = (*reader)(filename.c_str(), &data);
+  if (read_result) return STRING(&data[0], data.size());
+  tprintf("Failed to read data from: %s\n", filename.c_str());
+  return STRING();
+}
+
+// Helper writes the unicharset to file and to the traineddata.
+bool WriteUnicharset(const UNICHARSET& unicharset, const string& output_dir,
+                     const string& lang, FileWriter writer,
+                     TessdataManager* traineddata) {
+  GenericVector<char> unicharset_data;
+  TFile fp;
+  fp.OpenWrite(&unicharset_data);
+  if (!unicharset.save_to_file(&fp)) return false;
+  traineddata->OverwriteEntry(TESSDATA_LSTM_UNICHARSET, &unicharset_data[0],
+                              unicharset_data.size());
+  return WriteFile(output_dir, lang, ".unicharset", unicharset_data, writer);
+}
+
+// Helper creates the recoder and writes it to the traineddata, and a human-
+// readable form to file.
+bool WriteRecoder(const UNICHARSET& unicharset, bool pass_through,
+                  const string& output_dir, const string& lang,
+                  FileWriter writer, STRING* radical_table_data,
+                  TessdataManager* traineddata) {
+  UnicharCompress recoder;
+  // Where the unicharset is carefully setup already to contain a good
+  // compact encoding, use a pass-through recoder that does nothing.
+  // For scripts that have a large number of unicodes (Han, Hangul) we want
+  // to use the recoder to compress the symbol space by re-encoding each
+  // unicode as multiple codes from a smaller 'alphabet' that are related to the
+  // shapes in the character. Hangul Jamo is a perfect example of this.
+  // See the Hangul Syllables section, sub-section "Equivalence" in:
+  // http://www.unicode.org/versions/Unicode10.0.0/ch18.pdf
+  if (pass_through) {
+    recoder.SetupPassThrough(unicharset);
+  } else {
+    int null_char =
+        unicharset.has_special_codes() ? UNICHAR_BROKEN : unicharset.size();
+    tprintf("Null char=%d\n", null_char);
+    if (!recoder.ComputeEncoding(unicharset, null_char, radical_table_data)) {
+      tprintf("Creation of encoded unicharset failed!!\n");
+      return false;
+    }
+  }
+  TFile fp;
+  GenericVector<char> recoder_data;
+  fp.OpenWrite(&recoder_data);
+  if (!recoder.Serialize(&fp)) return false;
+  traineddata->OverwriteEntry(TESSDATA_LSTM_RECODER, &recoder_data[0],
+                              recoder_data.size());
+  STRING encoding = recoder.GetEncodingAsString(unicharset);
+  recoder_data.init_to_size(encoding.length(), 0);
+  memcpy(&recoder_data[0], &encoding[0], encoding.length());
+  STRING suffix;
+  suffix.add_str_int(".charset_size=", recoder.code_range());
+  suffix += ".txt";
+  return WriteFile(output_dir, lang, suffix.string(), recoder_data, writer);
+}
+
+// Helper builds a dawg from the given words, using the unicharset as coding,
+// and reverse_policy for LTR/RTL, and overwrites file_type in the traineddata.
+static bool WriteDawg(const GenericVector<STRING>& words,
+                      const UNICHARSET& unicharset,
+                      Trie::RTLReversePolicy reverse_policy,
+                      TessdataType file_type, TessdataManager* traineddata) {
+  // The first 3 arguments are not used in this case.
+  Trie trie(DAWG_TYPE_WORD, "", SYSTEM_DAWG_PERM, unicharset.size(), 0);
+  trie.add_word_list(words, unicharset, reverse_policy);
+  tprintf("Reducing Trie to SquishedDawg\n");
+  std::unique_ptr<SquishedDawg> dawg(trie.trie_to_dawg());
+  if (dawg == nullptr || dawg->NumEdges() == 0) return false;
+  TFile fp;
+  GenericVector<char> dawg_data;
+  fp.OpenWrite(&dawg_data);
+  if (!dawg->write_squished_dawg(&fp)) return false;
+  traineddata->OverwriteEntry(file_type, &dawg_data[0], dawg_data.size());
+  return true;
+}
+
+// Builds and writes the dawgs, given a set of words, punctuation
+// patterns, number patterns, to the traineddata. Encoding uses the given
+// unicharset, and the punc dawg is reversed if lang_is_rtl.
+static bool WriteDawgs(const GenericVector<STRING>& words,
+                       const GenericVector<STRING>& puncs,
+                       const GenericVector<STRING>& numbers, bool lang_is_rtl,
+                       const UNICHARSET& unicharset,
+                       TessdataManager* traineddata) {
+  if (puncs.empty()) {
+    tprintf("Must have non-empty puncs list to use language models!!\n");
+    return false;
+  }
+  // For each of the dawg types, make the dawg, and write to traineddata.
+  // Dawgs are reversed as follows:
+  // Words: According to the word content.
+  // Puncs: According to lang_is_rtl.
+  // Numbers: Never.
+  // System dawg (main wordlist).
+  if (!words.empty() &&
+      !WriteDawg(words, unicharset, Trie::RRP_REVERSE_IF_HAS_RTL,
+                 TESSDATA_LSTM_SYSTEM_DAWG, traineddata)) {
+    return false;
+  }
+  // punc/punc-dawg.
+  Trie::RTLReversePolicy reverse_policy =
+      lang_is_rtl ? Trie::RRP_FORCE_REVERSE : Trie::RRP_DO_NO_REVERSE;
+  if (!WriteDawg(puncs, unicharset, reverse_policy, TESSDATA_LSTM_PUNC_DAWG,
+                 traineddata)) {
+    return false;
+  }
+  // numbers/number-dawg.
+  if (!numbers.empty() &&
+      !WriteDawg(numbers, unicharset, Trie::RRP_DO_NO_REVERSE,
+                 TESSDATA_LSTM_NUMBER_DAWG, traineddata)) {
+    return false;
+  }
+  return true;
+}
+
+// The main function for combine_lang_model.cpp.
+// Returns EXIT_SUCCESS or EXIT_FAILURE for error.
+int CombineLangModel(const UNICHARSET& unicharset, const string& script_dir,
+                     const string& version_str, const string& output_dir,
+                     const string& lang, bool pass_through_recoder,
+                     const GenericVector<STRING>& words,
+                     const GenericVector<STRING>& puncs,
+                     const GenericVector<STRING>& numbers, bool lang_is_rtl,
+                     FileReader reader, FileWriter writer) {
+  // Build the traineddata file.
+  TessdataManager traineddata;
+  if (!version_str.empty()) {
+    traineddata.SetVersionString(traineddata.VersionString() + ":" +
+                                 version_str);
+  }
+  // Unicharset and recoder.
+  if (!WriteUnicharset(unicharset, output_dir, lang, writer, &traineddata)) {
+    tprintf("Error writing unicharset!!\n");
+    return EXIT_FAILURE;
+  }
+  // If there is a config file, read it and add to traineddata.
+  string config_filename = script_dir + "/" + lang + "/" + lang + ".config";
+  STRING config_file = ReadFile(config_filename, reader);
+  if (config_file.length() > 0) {
+    traineddata.OverwriteEntry(TESSDATA_LANG_CONFIG, &config_file[0],
+                               config_file.length());
+  }
+  string radical_filename = script_dir + "/radical-stroke.txt";
+  STRING radical_data = ReadFile(radical_filename, reader);
+  if (radical_data.length() == 0) {
+    tprintf("Error reading radical code table %s\n", radical_filename.c_str());
+    return EXIT_FAILURE;
+  }
+  if (!WriteRecoder(unicharset, pass_through_recoder, output_dir, lang, writer,
+                    &radical_data, &traineddata)) {
+    tprintf("Error writing recoder!!\n");
+  }
+  if (!words.empty() || !puncs.empty() || !numbers.empty()) {
+    if (!WriteDawgs(words, puncs, numbers, lang_is_rtl, unicharset,
+                    &traineddata)) {
+      tprintf("Error during conversion of wordlists to DAWGs!!\n");
+      return EXIT_FAILURE;
+    }
+  }
+
+  // Traineddata file.
+  GenericVector<char> traineddata_data;
+  traineddata.Serialize(&traineddata_data);
+  if (!WriteFile(output_dir, lang, ".traineddata", traineddata_data, writer)) {
+    tprintf("Error writing output traineddata file!!\n");
+    return EXIT_FAILURE;
+  }
+  return EXIT_SUCCESS;
+}
+
+}  // namespace tesseract
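The shared naming convention is worth seeing concretely. A sketch (output_dir and lang are illustrative; data is assumed to hold previously serialized bytes; nullptr selects the default writer, which overwrites any existing file):

    GenericVector<char> data;  // e.g. a serialized unicharset
    // Creates output/eng/ if missing, then writes output/eng/eng.unicharset.
    tesseract::WriteFile("output", "eng", ".unicharset", data, nullptr);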
training/lang_model_helpers.h (new file, 84 lines)

@@ -0,0 +1,84 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+// Author: rays@google.com (Ray Smith)
+// Purpose: Collection of convenience functions to simplify creation of the
+//          unicharset, recoder, and dawgs for an LSTM model.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TESSERACT_TRAINING_LANG_MODEL_HELPERS_H_
+#define TESSERACT_TRAINING_LANG_MODEL_HELPERS_H_
+
+#include <string>
+#include "genericvector.h"
+#include "serialis.h"
+#include "strngs.h"
+#include "tessdatamanager.h"
+#include "unicharset.h"
+
+namespace tesseract {
+
+// Helper makes a filename (<output_dir>/<lang>/<lang><suffix>) and writes data
+// to the file, using writer if not null, otherwise, a default writer.
+// Default writer will overwrite any existing file, but a supplied writer
+// can do its own thing. If lang is empty, returns true but does nothing.
+// NOTE that suffix should contain any required . for the filename.
+bool WriteFile(const string& output_dir, const string& lang,
+               const string& suffix, const GenericVector<char>& data,
+               FileWriter writer);
+// Helper reads a file with optional reader and returns a STRING.
+// On failure emits a warning message and returns an empty STRING.
+STRING ReadFile(const string& filename, FileReader reader);
+
+// Helper writes the unicharset to file and to the traineddata.
+bool WriteUnicharset(const UNICHARSET& unicharset, const string& output_dir,
+                     const string& lang, FileWriter writer,
+                     TessdataManager* traineddata);
+// Helper creates the recoder from the unicharset and writes it to the
+// traineddata, with a human-readable form to file at:
+// <output_dir>/<lang>/<lang>.charset_size=<num> for some num being the size
+// of the re-encoded character set. The charset_size file is written using
+// writer if not null, or using a default file writer otherwise, overwriting
+// any existing content.
+// If pass_through is true, then the recoder will be a no-op, passing the
+// unicharset codes through unchanged. Otherwise, the recoder will "compress"
+// the unicharset by encoding Hangul in Jamos, decomposing multi-unicode
+// symbols into sequences of unicodes, and encoding Han using the data in the
+// radical_table_data, which must be the content of the file:
+// langdata/radical-stroke.txt.
+bool WriteRecoder(const UNICHARSET& unicharset, bool pass_through,
+                  const string& output_dir, const string& lang,
+                  FileWriter writer, STRING* radical_table_data,
+                  TessdataManager* traineddata);
+
+// The main function for combine_lang_model.cpp.
+// Returns EXIT_SUCCESS or EXIT_FAILURE for error.
+// unicharset: can be a hand-created file with incomplete fields. Its basic
+// and script properties will be set before it is used.
+// script_dir: should point to the langdata (github repo) directory.
+// version_str: arbitrary version label.
+// Output files will be written to <output_dir>/<lang>/<lang>.*
+// If pass_through_recoder is true, the unicharset will be used unchanged as
+// labels in the classifier, otherwise, the unicharset will be "compressed" to
+// make the recognition task simpler and faster.
+// The words/puncs/numbers lists may be all empty. If any are non-empty then
+// puncs must be non-empty.
+// lang_is_rtl indicates that the language is generally written from right
+// to left (eg Arabic/Hebrew).
+int CombineLangModel(const UNICHARSET& unicharset, const string& script_dir,
+                     const string& version_str, const string& output_dir,
+                     const string& lang, bool pass_through_recoder,
+                     const GenericVector<STRING>& words,
+                     const GenericVector<STRING>& puncs,
+                     const GenericVector<STRING>& numbers, bool lang_is_rtl,
+                     FileReader reader, FileWriter writer);
+
+}  // namespace tesseract
+
+#endif  // TESSERACT_TRAINING_LANG_MODEL_HELPERS_H_
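Putting the helpers together, a caller can drive the whole pipeline directly, on the pattern of combine_lang_model.cpp above (paths are illustrative; empty word/number lists are allowed, but dawgs are only written when puncs is non-empty):

    UNICHARSET unicharset;
    unicharset.load_from_file("langdata/eng/eng.unicharset", false);
    tesseract::SetupBasicProperties(true, false, &unicharset);
    tesseract::SetScriptProperties("langdata", &unicharset);
    GenericVector<STRING> words, puncs, numbers;
    tesseract::ReadFile("langdata/eng/eng.punc", nullptr).split('\n', &puncs);
    // Writes output/eng/eng.unicharset, .charset_size=N.txt and .traineddata.
    int result = tesseract::CombineLangModel(
        unicharset, "langdata", /*version_str*/ "", "output", "eng",
        /*pass_through_recoder*/ false, words, puncs, numbers,
        /*lang_is_rtl*/ false, /*reader*/ nullptr, /*writer*/ nullptr);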
training/lstmeval.cpp

@@ -32,6 +32,8 @@ STRING_PARAM_FLAG(traineddata, "",
 STRING_PARAM_FLAG(eval_listfile, "",
                   "File listing sample files in lstmf training format.");
 INT_PARAM_FLAG(max_image_MB, 2000, "Max memory to use for images.");
+INT_PARAM_FLAG(verbosity, 1,
+               "Amount of diagnostic information to output (0-2).");
 
 int main(int argc, char **argv) {
   ParseArguments(&argc, &argv);

@@ -45,6 +47,10 @@ int main(int argc, char **argv) {
   }
   tesseract::TessdataManager mgr;
   if (!mgr.Init(FLAGS_model.c_str())) {
+    if (FLAGS_traineddata.empty()) {
+      tprintf("Must supply --traineddata to eval a training checkpoint!\n");
+      return 1;
+    }
     tprintf("%s is not a recognition model, trying training checkpoint...\n",
             FLAGS_model.c_str());
     if (!mgr.Init(FLAGS_traineddata.c_str())) {

@@ -67,7 +73,9 @@ int main(int argc, char **argv) {
     return 1;
   }
   double errs = 0.0;
-  STRING result = tester.RunEvalSync(0, &errs, mgr, 0);
+  STRING result =
+      tester.RunEvalSync(0, &errs, mgr,
+                         /*training_stage (irrelevant)*/ 0, FLAGS_verbosity);
   tprintf("%s\n", result.string());
   return 0;
 } /* main */
training/lstmtester.cpp

@@ -81,7 +81,7 @@ STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors,
 // describing the results.
 STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors,
                                const TessdataManager& model_mgr,
-                               int training_stage) {
+                               int training_stage, int verbosity) {
   LSTMTrainer trainer;
   trainer.InitCharSet(model_mgr);
   TFile fp;

@@ -97,11 +97,20 @@ STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors,
     const ImageData* trainingdata = test_data_.GetPageBySerial(eval_iteration);
     trainer.SetIteration(++eval_iteration);
     NetworkIO fwd_outputs, targets;
-    if (trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets) !=
-        UNENCODABLE) {
+    Trainability result =
+        trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets);
+    if (result != UNENCODABLE) {
       char_error += trainer.NewSingleError(tesseract::ET_CHAR_ERROR);
       word_error += trainer.NewSingleError(tesseract::ET_WORD_RECERR);
       ++error_count;
+      if (verbosity > 1 || (verbosity > 0 && result != PERFECT)) {
+        tprintf("Truth:%s\n", trainingdata->transcription().string());
+        GenericVector<int> ocr_labels;
+        GenericVector<int> xcoords;
+        trainer.LabelsFromOutputs(fwd_outputs, &ocr_labels, &xcoords);
+        STRING ocr_text = trainer.DecodeLabels(ocr_labels);
+        tprintf("OCR :%s\n", ocr_text.string());
+      }
     }
   }
   char_error *= 100.0 / total_pages_;

@@ -125,7 +134,8 @@ void* LSTMTester::ThreadFunc(void* lstmtester_void) {
   LSTMTester* lstmtester = static_cast<LSTMTester*>(lstmtester_void);
   lstmtester->test_result_ = lstmtester->RunEvalSync(
       lstmtester->test_iteration_, lstmtester->test_training_errors_,
-      lstmtester->test_model_mgr_, lstmtester->test_training_stage_);
+      lstmtester->test_model_mgr_, lstmtester->test_training_stage_,
+      /*verbosity*/ 0);
   lstmtester->UnlockRunning();
   return lstmtester_void;
 }
training/lstmtester.h

@@ -55,9 +55,11 @@ class LSTMTester {
   STRING RunEvalAsync(int iteration, const double* training_errors,
                       const TessdataManager& model_mgr, int training_stage);
   // Runs an evaluation synchronously on the stored eval data and returns a
-  // string describing the results. Args as RunEvalAsync.
+  // string describing the results. Args as RunEvalAsync, except verbosity,
+  // which outputs errors if 1, or all results if 2.
   STRING RunEvalSync(int iteration, const double* training_errors,
-                     const TessdataManager& model_mgr, int training_stage);
+                     const TessdataManager& model_mgr, int training_stage,
+                     int verbosity);
 
  private:
  // Static helper thread function for RunEvalAsync, with a specific signature
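As lstmeval.cpp uses it above, callers now pass a verbosity level. Given an initialized LSTMTester tester and TessdataManager mgr, a sketch:

    double errs = 0.0;
    // 0 is silent; 1 prints Truth/OCR for imperfect samples; 2 prints all.
    STRING result = tester.RunEvalSync(/*iteration*/ 0, &errs, mgr,
                                       /*training_stage*/ 0, /*verbosity*/ 1);
    tprintf("%s\n", result.string());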
training/lstmtraining.cpp

@@ -29,9 +29,8 @@
 
 INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
 STRING_PARAM_FLAG(net_spec, "", "Network specification");
-INT_PARAM_FLAG(train_mode, 80, "Controls gross training behavior.");
 INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
-INT_PARAM_FLAG(perfect_sample_delay, 4,
+INT_PARAM_FLAG(perfect_sample_delay, 0,
                "How many imperfect samples between perfect ones.");
 DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
 DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");

@@ -40,21 +39,23 @@ DOUBLE_PARAM_FLAG(momentum, 0.9, "Decay factor for repeating deltas.");
 INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
 STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
 STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
-STRING_PARAM_FLAG(script_dir, "",
-                  "Required to set unicharset properties or"
-                  " use unicharset compression.");
 STRING_PARAM_FLAG(train_listfile, "",
                   "File listing training files in lstmf training format.");
 STRING_PARAM_FLAG(eval_listfile, "",
                   "File listing eval files in lstmf training format.");
 BOOL_PARAM_FLAG(stop_training, false,
                 "Just convert the training model to a runtime model.");
+BOOL_PARAM_FLAG(convert_to_int, false,
+                "Convert the recognition model to an integer model.");
+BOOL_PARAM_FLAG(sequential_training, false,
+                "Use the training files sequentially instead of round-robin.");
 INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
                " attach the new network defined by net_spec");
 BOOL_PARAM_FLAG(debug_network, false,
                 "Get info on distribution of weight values");
 INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
-DECLARE_STRING_PARAM_FLAG(U);
+STRING_PARAM_FLAG(traineddata, "",
+                  "Combined Dawgs/Unicharset/Recoder for language model");
 
 // Number of training images to train between calls to MaintainCheckpoints.
 const int kNumPagesPerBatch = 100;

@@ -85,6 +86,7 @@ int main(int argc, char **argv) {
       nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(),
       checkpoint_file.c_str(), FLAGS_debug_interval,
       static_cast<inT64>(FLAGS_max_image_MB) * 1048576);
+  trainer.InitCharSet(FLAGS_traineddata.c_str());
 
   // Reading something from an existing model doesn't require many flags,
   // so do it now and exit.

@@ -97,12 +99,8 @@ int main(int argc, char **argv) {
   if (FLAGS_debug_network) {
     trainer.DebugNetwork();
   } else {
-    if (FLAGS_train_mode & tesseract::TF_INT_MODE)
-      trainer.ConvertToInt();
-    GenericVector<char> recognizer_data;
-    trainer.SaveRecognitionDump(&recognizer_data);
-    if (!tesseract::SaveDataToFile(recognizer_data,
-                                   FLAGS_model_output.c_str())) {
+    if (FLAGS_convert_to_int) trainer.ConvertToInt();
+    if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
       tprintf("Failed to write recognition model : %s\n",
               FLAGS_model_output.c_str());
     }

@@ -123,7 +121,6 @@ int main(int argc, char **argv) {
     return 1;
   }
 
-  UNICHARSET unicharset;
   // Checkpoints always take priority if they are available.
   if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) ||
       trainer.TryLoadingCheckpoint(checkpoint_bak.string())) {

@@ -140,14 +137,6 @@ int main(int argc, char **argv) {
     trainer.InitIterations();
   }
   if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
-    // We need a unicharset to start from scratch or append.
-    string unicharset_str;
-    // Character coding to be used by the classifier.
-    if (!unicharset.load_from_file(FLAGS_U.c_str())) {
-      tprintf("Error: must provide a -U unicharset!\n");
-      return 1;
-    }
-    tesseract::SetupBasicProperties(true, &unicharset);
     if (FLAGS_append_index >= 0) {
       tprintf("Appending a new network to an old one!!");
       if (FLAGS_continue_from.empty()) {

@@ -156,8 +145,6 @@ int main(int argc, char **argv) {
       }
     }
     // We are initializing from scratch.
-    trainer.InitCharSet(unicharset, FLAGS_script_dir.c_str(),
-                        FLAGS_train_mode);
     if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
                              FLAGS_net_mode, FLAGS_weight_range,
                              FLAGS_learning_rate, FLAGS_momentum)) {

@@ -168,7 +155,9 @@ int main(int argc, char **argv) {
       trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
     }
   }
-  if (!trainer.LoadAllTrainingData(filenames)) {
+  if (!trainer.LoadAllTrainingData(
+          filenames, FLAGS_sequential_training ? tesseract::CS_SEQUENTIAL
+                                               : tesseract::CS_ROUND_ROBIN)) {
     tprintf("Load of images failed!!\n");
     return 1;
   }
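In short, lstmtraining no longer takes -U/--script_dir/--train_mode: the unicharset, recoder and dawgs all come from the --traineddata file produced by combine_lang_model, and the old train_mode bits are replaced by the explicit --convert_to_int and --sequential_training flags. A sketch of the corresponding startup order in C++, as main() above now performs it:

    LSTMTrainer trainer(/* args as in main() above */);
    trainer.InitCharSet(FLAGS_traineddata.c_str());  // before any network init
    // Then either load a checkpoint, or InitNetwork() from scratch/append,
    // and finally LoadAllTrainingData() with the chosen CachingStrategy.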
training/tesstrain.sh

@@ -60,11 +60,11 @@ initialize_fontconfig
 
 phase_I_generate_image 8
 phase_UP_generate_unicharset
-phase_D_generate_dawg
 if ((LINEDATA)); then
   phase_E_extract_features "lstm.train" 8 "lstmf"
   make__lstmdata
 else
+  phase_D_generate_dawg
   phase_E_extract_features "box.train" 8 "tr"
   phase_C_cluster_prototypes "${TRAINING_DIR}/${LANG_CODE}.normproto"
   if [[ "${ENABLE_SHAPE_CLUSTERING}" == "y" ]]; then
training/tesstrain_utils.sh

@@ -44,11 +44,19 @@ err_exit() {
 run_command() {
   local cmd=$(which $1)
   if [[ -z ${cmd} ]]; then
-    err_exit "$1 not found"
+    for d in api training; do
+      cmd=$(which $d/$1)
+      if [[ ! -z ${cmd} ]]; then
+        break
+      fi
+    done
+    if [[ -z ${cmd} ]]; then
+      err_exit "$1 not found"
+    fi
   fi
   shift
   tlog "[$(date)] ${cmd} $@"
-  ${cmd} "$@" 2>&1 1>&2 | tee -a ${LOG_FILE}
+  "${cmd}" "$@" 2>&1 1>&2 | tee -a ${LOG_FILE}
   # check completion status
   if [[ $? -gt 0 ]]; then
     err_exit "Program $(basename ${cmd}) failed. Abort."

@@ -204,7 +212,7 @@ generate_font_image() {
   common_args+=" --fonts_dir=${FONTS_DIR} --strip_unrenderable_words"
   common_args+=" --leading=${LEADING}"
   common_args+=" --char_spacing=${CHAR_SPACING} --exposure=${EXPOSURE}"
-  common_args+=" --outputbase=${outbase}"
+  common_args+=" --outputbase=${outbase} --max_pages=3"
 
   # add --writing_mode=vertical-upright to common_args if the font is
   # specified to be rendered vertically.

@@ -490,36 +498,43 @@ phase_B_generate_ambiguities() {
 
 make__lstmdata() {
   tlog "\n=== Constructing LSTM training data ==="
-  local lang_prefix=${LANGDATA_ROOT}/${LANG_CODE}/${LANG_CODE}
-  if [[ ! -d ${OUTPUT_DIR} ]]; then
+  local lang_prefix="${LANGDATA_ROOT}/${LANG_CODE}/${LANG_CODE}"
+  if [[ ! -d "${OUTPUT_DIR}" ]]; then
     tlog "Creating new directory ${OUTPUT_DIR}"
-    mkdir -p ${OUTPUT_DIR}
+    mkdir -p "${OUTPUT_DIR}"
   fi
+  local lang_is_rtl=""
+  # TODO(rays) set using script lang lists.
+  case "${LANG_CODE}" in
+    ara | div| fas | pus | snd | syr | uig | urd | kur_ara | heb | yid )
+      lang_is_rtl="--lang_is_rtl" ;;
+    * ) ;;
+  esac
+  local pass_through=""
+  # TODO(rays) set using script lang lists.
+  case "${LANG_CODE}" in
+    asm | ben | bih | hin | mar | nep | guj | kan | mal | tam | tel | pan | \
+    dzo | sin | san | bod | ori | khm | mya | tha | lao | heb | yid | ara | \
+    fas | pus | snd | urd | div | syr | uig | kur_ara )
+      pass_through="--pass_through_recoder" ;;
+    * ) ;;
+  esac
 
-  # Copy available files for this language from the langdata dir.
-  if [[ -r ${lang_prefix}.config ]]; then
-    tlog "Copying ${lang_prefix}.config to ${OUTPUT_DIR}"
-    cp ${lang_prefix}.config ${OUTPUT_DIR}
-    chmod u+w ${OUTPUT_DIR}/${LANG_CODE}.config
-  fi
-  if [[ -r "${TRAINING_DIR}/${LANG_CODE}.unicharset" ]]; then
-    tlog "Moving ${TRAINING_DIR}/${LANG_CODE}.unicharset to ${OUTPUT_DIR}"
-    mv "${TRAINING_DIR}/${LANG_CODE}.unicharset" "${OUTPUT_DIR}"
-  fi
-  for ext in number-dawg punc-dawg word-dawg; do
-    local src="${TRAINING_DIR}/${LANG_CODE}.${ext}"
-    if [[ -r "${src}" ]]; then
-      dest="${OUTPUT_DIR}/${LANG_CODE}.lstm-${ext}"
-      tlog "Moving ${src} to ${dest}"
-      mv "${src}" "${dest}"
-    fi
-  done
+  # Build the starter traineddata from the inputs.
+  run_command combine_lang_model \
+    --input_unicharset "${TRAINING_DIR}/${LANG_CODE}.unicharset" \
+    --script_dir "${LANGDATA_ROOT}" \
+    --words "${lang_prefix}.wordlist" \
+    --numbers "${lang_prefix}.numbers" \
+    --puncs "${lang_prefix}.punc" \
+    --output_dir "${OUTPUT_DIR}" --lang "${LANG_CODE}" \
+    "${pass_through}" "${lang_is_rtl}"
   for f in "${TRAINING_DIR}/${LANG_CODE}".*.lstmf; do
     tlog "Moving ${f} to ${OUTPUT_DIR}"
    mv "${f}" "${OUTPUT_DIR}"
   done
   local lstm_list="${OUTPUT_DIR}/${LANG_CODE}.training_files.txt"
-  ls -1 "${OUTPUT_DIR}"/*.lstmf > "${lstm_list}"
+  ls -1 "${OUTPUT_DIR}/${LANG_CODE}".*.lstmf > "${lstm_list}"
 }
 
 make__traineddata() {
training/text2image.cpp

@@ -79,6 +79,9 @@ INT_PARAM_FLAG(xsize, 3600, "Width of output image");
 // Max height of output image (in pixels).
 INT_PARAM_FLAG(ysize, 4800, "Height of output image");
 
+// Max number of pages to produce.
+INT_PARAM_FLAG(max_pages, 0, "Maximum number of pages to output (0=unlimited)");
+
 // Margin around text (in pixels).
 INT_PARAM_FLAG(margin, 100, "Margin round edges of image");
 

@@ -579,7 +582,10 @@ int Main() {
   for (int pass = 0; pass < num_pass; ++pass) {
     int page_num = 0;
     string font_used;
-    for (size_t offset = 0; offset < strlen(to_render_utf8); ++im, ++page_num) {
+    for (size_t offset = 0;
+         offset < strlen(to_render_utf8) &&
+         (FLAGS_max_pages == 0 || page_num < FLAGS_max_pages);
+         ++im, ++page_num) {
       tlog(1, "Starting page %d\n", im);
       Pix* pix = nullptr;
       if (FLAGS_find_fonts) {
training/unicharset_training_utils.cpp

@@ -139,6 +139,42 @@ void SetupBasicProperties(bool report_errors, bool decompose,
   unicharset->post_load_setup();
 }
 
+// Helper sets the properties from universal script unicharsets, if found.
+void SetScriptProperties(const string& script_dir, UNICHARSET* unicharset) {
+  for (int s = 0; s < unicharset->get_script_table_size(); ++s) {
+    // Load the unicharset for the script if available.
+    string filename = script_dir + "/" +
+                      unicharset->get_script_from_script_id(s) + ".unicharset";
+    UNICHARSET script_set;
+    if (script_set.load_from_file(filename.c_str())) {
+      unicharset->SetPropertiesFromOther(script_set);
+    } else if (s != unicharset->common_sid() && s != unicharset->null_sid()) {
+      tprintf("Failed to load script unicharset from:%s\n", filename.c_str());
+    }
+  }
+  for (int c = SPECIAL_UNICHAR_CODES_COUNT; c < unicharset->size(); ++c) {
+    if (unicharset->PropertiesIncomplete(c)) {
+      tprintf("Warning: properties incomplete for index %d = %s\n", c,
+              unicharset->id_to_unichar(c));
+    }
+  }
+}
+
+// Helper gets the combined x-heights string.
+string GetXheightString(const string& script_dir,
+                        const UNICHARSET& unicharset) {
+  string xheights_str;
+  for (int s = 0; s < unicharset.get_script_table_size(); ++s) {
+    // Load the xheights for the script if available.
+    string filename = script_dir + "/" +
+                      unicharset.get_script_from_script_id(s) + ".xheights";
+    string script_heights;
+    if (File::ReadFileToString(filename, &script_heights))
+      xheights_str += script_heights;
+  }
+  return xheights_str;
+}
+
 // Helper to set the properties for an input unicharset file, writes to the
 // output file. If an appropriate script unicharset can be found in the
 // script_dir directory, then the tops and bottoms are expanded using the

@@ -158,29 +194,11 @@ void SetPropertiesForInputFile(const string& script_dir,
   // Set unichar properties
   tprintf("Setting unichar properties\n");
   SetupBasicProperties(true, false, &unicharset);
-  string xheights_str;
-  for (int s = 0; s < unicharset.get_script_table_size(); ++s) {
-    // Load the unicharset for the script if available.
-    string filename = script_dir + "/" +
-                      unicharset.get_script_from_script_id(s) + ".unicharset";
-    UNICHARSET script_set;
-    if (script_set.load_from_file(filename.c_str())) {
-      unicharset.SetPropertiesFromOther(script_set);
-    }
-    // Load the xheights for the script if available.
-    filename = script_dir + "/" + unicharset.get_script_from_script_id(s) +
-               ".xheights";
-    string script_heights;
-    if (File::ReadFileToString(filename, &script_heights))
-      xheights_str += script_heights;
-  }
-  if (!output_xheights_file.empty())
+  tprintf("Setting script properties\n");
+  SetScriptProperties(script_dir, &unicharset);
+  if (!output_xheights_file.empty()) {
+    string xheights_str = GetXheightString(script_dir, unicharset);
     File::WriteStringToFileOrDie(xheights_str, output_xheights_file);
-  for (int c = SPECIAL_UNICHAR_CODES_COUNT; c < unicharset.size(); ++c) {
-    if (unicharset.PropertiesIncomplete(c)) {
-      tprintf("Warning: properties incomplete for index %d = %s\n",
-              c, unicharset.id_to_unichar(c));
-    }
-  }
+  }
 
   // Write the output unicharset
training/unicharset_training_utils.h

@@ -38,6 +38,10 @@ void SetupBasicProperties(bool report_errors, bool decompose,
 inline void SetupBasicProperties(bool report_errors, UNICHARSET* unicharset) {
   SetupBasicProperties(report_errors, false, unicharset);
 }
+// Helper sets the properties from universal script unicharsets, if found.
+void SetScriptProperties(const string& script_dir, UNICHARSET* unicharset);
+// Helper gets the combined x-heights string.
+string GetXheightString(const string& script_dir, const UNICHARSET& unicharset);
 
 // Helper to set the properties for an input unicharset file, writes to the
 // output file. If an appropriate script unicharset can be found in the
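These two helpers are the pieces factored out of SetPropertiesForInputFile so that combine_lang_model can reuse them. A sketch of standalone use (paths are illustrative):

    UNICHARSET unicharset;
    unicharset.load_from_file("eng.unicharset");
    tesseract::SetupBasicProperties(/*report_errors*/ true, &unicharset);
    tesseract::SetScriptProperties("langdata", &unicharset);  // script props
    string xheights = tesseract::GetXheightString("langdata", unicharset);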