From 2633fef0b6ac9b616eae3d457bf796076eb8f43c Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Wed, 2 Aug 2017 13:29:23 -0700 Subject: [PATCH] Part 2 of separating out the unicharset from the LSTM model, fixing command line for training --- dict/dawg.cpp | 29 ++-- dict/dawg.h | 26 +-- dict/trie.cpp | 39 ++--- dict/trie.h | 12 +- lstm/lstmtrainer.cpp | 76 ++------ lstm/lstmtrainer.h | 22 +-- training/Makefile.am | 27 ++- training/combine_lang_model.cpp | 87 ++++++++++ training/lang_model_helpers.cpp | 231 +++++++++++++++++++++++++ training/lang_model_helpers.h | 84 +++++++++ training/lstmeval.cpp | 10 +- training/lstmtester.cpp | 18 +- training/lstmtester.h | 6 +- training/lstmtraining.cpp | 37 ++-- training/tesstrain.sh | 2 +- training/tesstrain_utils.sh | 65 ++++--- training/text2image.cpp | 8 +- training/unicharset_training_utils.cpp | 62 ++++--- training/unicharset_training_utils.h | 4 + 19 files changed, 624 insertions(+), 221 deletions(-) create mode 100644 training/combine_lang_model.cpp create mode 100644 training/lang_model_helpers.cpp create mode 100644 training/lang_model_helpers.h diff --git a/dict/dawg.cpp b/dict/dawg.cpp index 45e84b0c..85d2471d 100644 --- a/dict/dawg.cpp +++ b/dict/dawg.cpp @@ -339,16 +339,15 @@ bool SquishedDawg::read_squished_dawg(TFile *file) { return true; } -NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const { +std::unique_ptr SquishedDawg::build_node_map( + inT32 *num_nodes) const { EDGE_REF edge; - NODE_MAP node_map; + std::unique_ptr node_map(new EDGE_REF[num_edges_]); inT32 node_counter; inT32 num_edges; - node_map = (NODE_MAP) malloc(sizeof(EDGE_REF) * num_edges_); - for (edge = 0; edge < num_edges_; edge++) // init all slots - node_map [edge] = -1; + node_map[edge] = -1; node_counter = num_forward_edges(0); @@ -366,25 +365,25 @@ NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const { edge--; } } - return (node_map); + return node_map; } -void SquishedDawg::write_squished_dawg(FILE *file) { +bool 
SquishedDawg::write_squished_dawg(TFile *file) { EDGE_REF edge; inT32 num_edges; inT32 node_count = 0; - NODE_MAP node_map; EDGE_REF old_index; EDGE_RECORD temp_record; if (debug_level_) tprintf("write_squished_dawg\n"); - node_map = build_node_map(&node_count); + std::unique_ptr node_map(build_node_map(&node_count)); // Write the magic number to help detecting a change in endianness. inT16 magic = kDawgMagicNumber; - fwrite(&magic, sizeof(inT16), 1, file); - fwrite(&unicharset_size_, sizeof(inT32), 1, file); + if (file->FWrite(&magic, sizeof(magic), 1) != 1) return false; + if (file->FWrite(&unicharset_size_, sizeof(unicharset_size_), 1) != 1) + return false; // Count the number of edges in this Dawg. num_edges = 0; @@ -392,7 +391,8 @@ void SquishedDawg::write_squished_dawg(FILE *file) { if (forward_edge(edge)) num_edges++; - fwrite(&num_edges, sizeof(inT32), 1, file); // write edge count to file + // Write edge count to file. + if (file->FWrite(&num_edges, sizeof(num_edges), 1) != 1) return false; if (debug_level_) { tprintf("%d nodes in DAWG\n", node_count); @@ -405,7 +405,8 @@ void SquishedDawg::write_squished_dawg(FILE *file) { old_index = next_node_from_edge_rec(edges_[edge]); set_next_node(edge, node_map[old_index]); temp_record = edges_[edge]; - fwrite(&(temp_record), sizeof(EDGE_RECORD), 1, file); + if (file->FWrite(&temp_record, sizeof(temp_record), 1) != 1) + return false; set_next_node(edge, old_index); } while (!last_edge(edge++)); @@ -416,7 +417,7 @@ void SquishedDawg::write_squished_dawg(FILE *file) { edge--; } } - free(node_map); + return true; } } // namespace tesseract diff --git a/dict/dawg.h b/dict/dawg.h index c36e7ba4..c3ee7d24 100644 --- a/dict/dawg.h +++ b/dict/dawg.h @@ -31,9 +31,10 @@ I n c l u d e s ----------------------------------------------------------------------*/ +#include #include "elst.h" -#include "ratngs.h" #include "params.h" +#include "ratngs.h" #include "tesscallback.h" #ifndef __GNUC__ @@ -483,18 +484,22 @@ class 
SquishedDawg : public Dawg { void print_node(NODE_REF node, int max_num_edges) const; /// Writes the squished/reduced Dawg to a file. - void write_squished_dawg(FILE *file); + bool write_squished_dawg(TFile *file); /// Opens the file with the given filename and writes the /// squished/reduced Dawg to the file. - void write_squished_dawg(const char *filename) { - FILE *file = fopen(filename, "wb"); - if (file == NULL) { - tprintf("Error opening %s\n", filename); - exit(1); + bool write_squished_dawg(const char *filename) { + TFile file; + file.OpenWrite(nullptr); + if (!this->write_squished_dawg(&file)) { + tprintf("Error serializing %s\n", filename); + return false; } - this->write_squished_dawg(file); - fclose(file); + if (!file.CloseWrite(filename, nullptr)) { + tprintf("Error writing file %s\n", filename); + return false; + } + return true; } private: @@ -549,8 +554,7 @@ class SquishedDawg : public Dawg { tprintf("__________________________\n"); } /// Constructs a mapping from the memory node indices to disk node indices. - NODE_MAP build_node_map(inT32 *num_nodes) const; - + std::unique_ptr build_node_map(inT32 *num_nodes) const; // Member variables. 
EDGE_ARRAY edges_; diff --git a/dict/trie.cpp b/dict/trie.cpp index a4406664..3bee2ea8 100644 --- a/dict/trie.cpp +++ b/dict/trie.cpp @@ -290,40 +290,27 @@ bool Trie::read_and_add_word_list(const char *filename, const UNICHARSET &unicharset, Trie::RTLReversePolicy reverse_policy) { GenericVector word_list; - if (!read_word_list(filename, unicharset, reverse_policy, &word_list)) - return false; + if (!read_word_list(filename, &word_list)) return false; word_list.sort(sort_strings_by_dec_length); - return add_word_list(word_list, unicharset); + return add_word_list(word_list, unicharset, reverse_policy); } bool Trie::read_word_list(const char *filename, - const UNICHARSET &unicharset, - Trie::RTLReversePolicy reverse_policy, GenericVector* words) { FILE *word_file; - char string[CHARS_PER_LINE]; + char line_str[CHARS_PER_LINE]; int word_count = 0; word_file = fopen(filename, "rb"); if (word_file == NULL) return false; - while (fgets(string, CHARS_PER_LINE, word_file) != NULL) { - chomp_string(string); // remove newline - WERD_CHOICE word(string, unicharset); - if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL && - word.has_rtl_unichar_id()) || - reverse_policy == RRP_FORCE_REVERSE) { - word.reverse_and_mirror_unichar_ids(); - } + while (fgets(line_str, sizeof(line_str), word_file) != NULL) { + chomp_string(line_str); // remove newline + STRING word_str(line_str); ++word_count; if (debug_level_ && word_count % 10000 == 0) tprintf("Read %d words so far\n", word_count); - if (word.length() != 0 && !word.contains_unichar_id(INVALID_UNICHAR_ID)) { - words->push_back(word.unichar_string()); - } else if (debug_level_) { - tprintf("Skipping invalid word %s\n", string); - if (debug_level_ >= 3) word.print(); - } + words->push_back(word_str); } if (debug_level_) tprintf("Read %d words total.\n", word_count); @@ -331,10 +318,18 @@ bool Trie::read_word_list(const char *filename, return true; } -bool Trie::add_word_list(const GenericVector& words, - const UNICHARSET &unicharset) { 
+bool Trie::add_word_list(const GenericVector &words, + const UNICHARSET &unicharset, + Trie::RTLReversePolicy reverse_policy) { for (int i = 0; i < words.size(); ++i) { WERD_CHOICE word(words[i].string(), unicharset); + if (word.length() == 0 || word.contains_unichar_id(INVALID_UNICHAR_ID)) + continue; + if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL && + word.has_rtl_unichar_id()) || + reverse_policy == RRP_FORCE_REVERSE) { + word.reverse_and_mirror_unichar_ids(); + } if (!word_in_dawg(word)) { add_word_to_dawg(word); if (!word_in_dawg(word)) { diff --git a/dict/trie.h b/dict/trie.h index 8428ebba..554cf4ed 100644 --- a/dict/trie.h +++ b/dict/trie.h @@ -177,18 +177,16 @@ class Trie : public Dawg { const UNICHARSET &unicharset, Trie::RTLReversePolicy reverse); - // Reads a list of words from the given file, applying the reverse_policy, - // according to information in the unicharset. + // Reads a list of words from the given file. // Returns false on error. bool read_word_list(const char *filename, - const UNICHARSET &unicharset, - Trie::RTLReversePolicy reverse_policy, GenericVector* words); // Adds a list of words previously read using read_word_list to the trie - // using the given unicharset to convert to unichar-ids. + // using the given unicharset and reverse_policy to convert to unichar-ids. // Returns false on error. - bool add_word_list(const GenericVector& words, - const UNICHARSET &unicharset); + bool add_word_list(const GenericVector &words, + const UNICHARSET &unicharset, + Trie::RTLReversePolicy reverse_policy); // Inserts the list of patterns from the given file into the Trie. // The pattern list file should contain one pattern per line in UTF-8 format. 
diff --git a/lstm/lstmtrainer.cpp b/lstm/lstmtrainer.cpp index f13b278a..78affe14 100644 --- a/lstm/lstmtrainer.cpp +++ b/lstm/lstmtrainer.cpp @@ -130,22 +130,6 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) { return checkpoint_reader_->Run(data, this); } -// Initializes the character set encode/decode mechanism. -// train_flags control training behavior according to the TrainingFlags -// enum, including character set encoding. -// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided, -// fully initializes the unicharset from the universal unicharsets. -// Note: Call before InitNetwork! -void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset, - const STRING& script_dir, int train_flags) { - EmptyConstructor(); - training_flags_ = train_flags; - ccutil_.unicharset.CopyFrom(unicharset); - null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN - : GetUnicharset().size(); - SetUnicharsetProperties(script_dir); -} - // Initializes the trainer with a network_spec in the network description // net_flags control network behavior according to the NetworkFlags enum. // There isn't really much difference between them - only where the effects @@ -278,9 +262,10 @@ void LSTMTrainer::DebugNetwork() { // Loads a set of lstmf files that were created using the lstm.train config to // tesseract into memory ready for training. Returns false if nothing was // loaded. 
-bool LSTMTrainer::LoadAllTrainingData(const GenericVector& filenames) { +bool LSTMTrainer::LoadAllTrainingData(const GenericVector& filenames, + CachingStrategy cache_strategy) { training_data_.Clear(); - return training_data_.LoadDocuments(filenames, CacheStrategy(), file_reader_); + return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_); } // Keeps track of best and locally worst char error_rate and launches tests @@ -908,6 +893,15 @@ bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager* mgr, return DeSerialize(mgr, &fp); } +// Writes the full recognition traineddata to the given filename. +bool LSTMTrainer::SaveTraineddata(const STRING& filename) { + GenericVector recognizer_data; + SaveRecognitionDump(&recognizer_data); + mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0], + recognizer_data.size()); + return mgr_.SaveFile(filename, file_writer_); +} + // Writes the recognizer to memory, so that it can be used for testing later. void LSTMTrainer::SaveRecognitionDump(GenericVector* data) const { TFile fp; @@ -964,52 +958,6 @@ void LSTMTrainer::EmptyConstructor() { InitIterations(); } -// Sets the unicharset properties using the given script_dir as a source of -// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets -// up the recoder_ to simplify the unicharset. -void LSTMTrainer::SetUnicharsetProperties(const STRING& script_dir) { - tprintf("Setting unichar properties\n"); - for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) { - if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0) - continue; - // Load the unicharset for the script if available. 
- STRING filename = script_dir + "/" + - GetUnicharset().get_script_from_script_id(s) + - ".unicharset"; - UNICHARSET script_set; - GenericVector data; - if ((*file_reader_)(filename, &data) && - script_set.load_from_inmemory_file(&data[0], data.size())) { - tprintf("Setting properties for script %s\n", - GetUnicharset().get_script_from_script_id(s)); - ccutil_.unicharset.SetPropertiesFromOther(script_set); - } - } - if (IsRecoding()) { - STRING filename = script_dir + "/radical-stroke.txt"; - GenericVector data; - if ((*file_reader_)(filename, &data)) { - data += '\0'; - STRING stroke_table = &data[0]; - if (recoder_.ComputeEncoding(GetUnicharset(), null_char_, - &stroke_table)) { - RecodedCharID code; - recoder_.EncodeUnichar(null_char_, &code); - null_char_ = code(0); - // Space should encode as itself. - recoder_.EncodeUnichar(UNICHAR_SPACE, &code); - ASSERT_HOST(code(0) == UNICHAR_SPACE); - return; - } - } else { - tprintf("Failed to load radical-stroke info from: %s\n", - filename.string()); - } - } - training_flags_ |= TF_COMPRESS_UNICHARSET; - recoder_.SetupPassThrough(GetUnicharset()); -} - // Outputs the string and periodically displays the given network inputs // as an image in the given window, and the corresponding labels at the // corresponding x_starts. diff --git a/lstm/lstmtrainer.h b/lstm/lstmtrainer.h index 65df18af..377e0015 100644 --- a/lstm/lstmtrainer.h +++ b/lstm/lstmtrainer.h @@ -101,14 +101,6 @@ class LSTMTrainer : public LSTMRecognizer { // false in case of failure. bool TryLoadingCheckpoint(const char* filename); - // Initializes the character set encode/decode mechanism. - // train_flags control training behavior according to the TrainingFlags - // enum, including character set encoding. - // script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided, - // fully initializes the unicharset from the universal unicharsets. - // Note: Call before InitNetwork! 
- void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir, - int train_flags); // Initializes the character set encode/decode mechanism directly from a // previously setup traineddata containing dawgs, UNICHARSET and // UnicharCompress. Note: Call before InitNetwork! @@ -186,7 +178,8 @@ class LSTMTrainer : public LSTMRecognizer { // Loads a set of lstmf files that were created using the lstm.train config to // tesseract into memory ready for training. Returns false if nothing was // loaded. - bool LoadAllTrainingData(const GenericVector& filenames); + bool LoadAllTrainingData(const GenericVector& filenames, + CachingStrategy cache_strategy); // Keeps track of best and locally worst error rate, using internally computed // values. See MaintainCheckpointsSpecific for more detail. @@ -315,12 +308,12 @@ class LSTMTrainer : public LSTMRecognizer { // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump. void SetupCheckpointInfo(); + // Writes the full recognition traineddata to the given filename. + bool SaveTraineddata(const STRING& filename); + // Writes the recognizer to memory, so that it can be used for testing later. void SaveRecognitionDump(GenericVector* data) const; - // Writes current best model to a file, unless it has already been written. - bool SaveBestModel(FileWriter writer) const; - // Returns a suitable filename for a training dump, based on the model_base_, // the iteration and the error rates. STRING DumpFilename() const; @@ -336,11 +329,6 @@ class LSTMTrainer : public LSTMRecognizer { // Factored sub-constructor sets up reasonable default values. void EmptyConstructor(); - // Sets the unicharset properties using the given script_dir as a source of - // script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets - // up the recoder_ to simplify the unicharset. 
- void SetUnicharsetProperties(const STRING& script_dir); - // Outputs the string and periodically displays the given network inputs // as an image in the given window, and the corresponding labels at the // corresponding x_starts. diff --git a/training/Makefile.am b/training/Makefile.am index 8d06d945..bc9c5e4b 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -19,8 +19,8 @@ endif noinst_HEADERS = \ boxchar.h commandlineflags.h commontraining.h degradeimage.h \ - fileio.h icuerrorcode.h ligature_table.h lstmtester.h normstrngs.h \ - mergenf.h pango_font_info.h stringrenderer.h \ + fileio.h icuerrorcode.h lang_model_helpers.h ligature_table.h \ + lstmtester.h mergenf.h normstrngs.h pango_font_info.h stringrenderer.h \ tessopt.h tlog.h unicharset_training_utils.h util.h \ validate_grapheme.h validate_indic.h validate_khmer.h \ validate_myanmar.h validator.h @@ -33,15 +33,15 @@ libtesseract_training_la_LIBADD = \ libtesseract_training_la_SOURCES = \ boxchar.cpp commandlineflags.cpp commontraining.cpp degradeimage.cpp \ - fileio.cpp ligature_table.cpp lstmtester.cpp normstrngs.cpp pango_font_info.cpp \ - stringrenderer.cpp tlog.cpp unicharset_training_utils.cpp \ + fileio.cpp lang_model_helpers.cpp ligature_table.cpp lstmtester.cpp \ + normstrngs.cpp pango_font_info.cpp stringrenderer.cpp tlog.cpp unicharset_training_utils.cpp \ validate_grapheme.cpp validate_indic.cpp validate_khmer.cpp \ validate_myanmar.cpp validator.cpp libtesseract_tessopt_la_SOURCES = \ tessopt.cpp -bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \ +bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_lang_model combine_tessdata \ dawg2wordlist lstmeval lstmtraining mftraining set_unicharset_properties shapeclustering \ text2image unicharset_extractor wordlist2dawg @@ -94,11 +94,26 @@ classifier_tester_LDADD += \ ../api/libtesseract.la endif +combine_lang_model_SOURCES = combine_lang_model.cpp +#combine_lang_model_LDFLAGS = -static 
+combine_lang_model_LDADD = \ + libtesseract_training.la \ + libtesseract_tessopt.la \ + $(ICU_I18N_LIBS) $(ICU_UC_LIBS) +if USING_MULTIPLELIBS +combine_lang_model_LDADD += \ + ../ccutil/libtesseract_ccutil.la +else +combine_lang_model_LDADD += \ + ../api/libtesseract.la +endif + combine_tessdata_SOURCES = combine_tessdata.cpp #combine_tessdata_LDFLAGS = -static if USING_MULTIPLELIBS combine_tessdata_LDADD = \ - ../ccutil/libtesseract_ccutil.la + ../ccutil/libtesseract_ccutil.la \ + ../lstm/libtesseract_lstm.la else combine_tessdata_LDADD = \ ../api/libtesseract.la diff --git a/training/combine_lang_model.cpp b/training/combine_lang_model.cpp new file mode 100644 index 00000000..4ce15d2c --- /dev/null +++ b/training/combine_lang_model.cpp @@ -0,0 +1,87 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// Purpose: Program to generate a traineddata file that can be used to train an +// LSTM-based neural network model from a unicharset and an optional +// set of wordlists. Eliminates the need to run +// set_unicharset_properties, wordlist2dawg, some non-existent binary +// to generate the recoder, and finally combine_tessdata. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "commandlineflags.h" +#include "lang_model_helpers.h" +#include "tprintf.h" +#include "unicharset_training_utils.h" + +STRING_PARAM_FLAG(input_unicharset, "", + "Unicharset to complete and use in encoding"); +STRING_PARAM_FLAG(script_dir, "", + "Directory name for input script unicharsets"); +STRING_PARAM_FLAG(words, "", + "File listing words to use for the system dictionary"); +STRING_PARAM_FLAG(puncs, "", "File listing punctuation patterns"); +STRING_PARAM_FLAG(numbers, "", "File listing number patterns"); +STRING_PARAM_FLAG(output_dir, "", "Root directory for output files"); +STRING_PARAM_FLAG(version_str, "", "Version string to add to traineddata file"); +STRING_PARAM_FLAG(lang, "", "Name of language being processed"); +BOOL_PARAM_FLAG(lang_is_rtl, false, + "True if lang being processed is written right-to-left"); +BOOL_PARAM_FLAG(pass_through_recoder, false, + "If true, the recoder is a simple pass-through of the" + " unicharset. Otherwise, potentially a compression of it"); + +int main(int argc, char** argv) { + tesseract::ParseCommandLineFlags(argv[0], &argc, &argv, true); + + // Check validity of input flags. 
+ if (FLAGS_input_unicharset.empty() || FLAGS_script_dir.empty() || + FLAGS_output_dir.empty() || FLAGS_lang.empty()) { + tprintf("Usage: %s --input_unicharset filename --script_dir dirname\n", + argv[0]); + tprintf(" --output_dir rootdir --lang lang [--lang_is_rtl]\n"); + tprintf(" [--words file --puncs file --numbers file]\n"); + tprintf("Sets properties on the input unicharset file, and writes:\n"); + tprintf("rootdir/lang/lang.charset_size=ddd.txt\n"); + tprintf("rootdir/lang/lang.traineddata\n"); + tprintf("rootdir/lang/lang.unicharset\n"); + tprintf("If the 3 word lists are provided, the dawgs are also added to"); + tprintf(" the traineddata file.\n"); + tprintf("The output unicharset and charset_size files are just for human"); + tprintf(" readability.\n"); + exit(1); + } + GenericVector words, puncs, numbers; + // If these reads fail, we get a warning message and an empty list of words. + tesseract::ReadFile(FLAGS_words.c_str(), nullptr).split('\n', &words); + tesseract::ReadFile(FLAGS_puncs.c_str(), nullptr).split('\n', &puncs); + tesseract::ReadFile(FLAGS_numbers.c_str(), nullptr).split('\n', &numbers); + // Load the input unicharset + UNICHARSET unicharset; + if (!unicharset.load_from_file(FLAGS_input_unicharset.c_str(), false)) { + tprintf("Failed to load unicharset from %s\n", + FLAGS_input_unicharset.c_str()); + return 1; + } + tprintf("Loaded unicharset of size %d from file %s\n", unicharset.size(), + FLAGS_input_unicharset.c_str()); + + // Set unichar properties + tprintf("Setting unichar properties\n"); + tesseract::SetupBasicProperties(/*report_errors*/ true, + /*decompose (NFD)*/ false, &unicharset); + tprintf("Setting script properties\n"); + tesseract::SetScriptProperties(FLAGS_script_dir.c_str(), &unicharset); + // Combine everything into a traineddata file. 
+ return tesseract::CombineLangModel( + unicharset, FLAGS_script_dir.c_str(), FLAGS_version_str.c_str(), + FLAGS_output_dir.c_str(), FLAGS_lang.c_str(), FLAGS_pass_through_recoder, + words, puncs, numbers, FLAGS_lang_is_rtl, /*reader*/ nullptr, + /*writer*/ nullptr); +} diff --git a/training/lang_model_helpers.cpp b/training/lang_model_helpers.cpp new file mode 100644 index 00000000..bdbf1e62 --- /dev/null +++ b/training/lang_model_helpers.cpp @@ -0,0 +1,231 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// Purpose: Collection of convenience functions to simplify creation of the +// unicharset, recoder, and dawgs for an LSTM model. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lang_model_helpers.h" + +#include +#include +#include +#include "dawg.h" +#include "fileio.h" +#include "tessdatamanager.h" +#include "trie.h" +#include "unicharcompress.h" + +namespace tesseract { + +// Helper makes a filename (//) and writes data +// to the file, using writer if not null, otherwise, a default writer. +// Default writer will overwrite any existing file, but a supplied writer +// can do its own thing. If lang is empty, returns true but does nothing. +// NOTE that suffix should contain any required . for the filename. 
+bool WriteFile(const string& output_dir, const string& lang, + const string& suffix, const GenericVector& data, + FileWriter writer) { + if (lang.empty()) return true; + string dirname = output_dir + "/" + lang; + // Attempt to make the directory, but ignore errors, as it may not be a + // standard filesystem, and the writer will complain if not successful. + mkdir(dirname.c_str(), S_IRWXU | S_IRWXG); + string filename = dirname + "/" + lang + suffix; + if (writer == nullptr) + return SaveDataToFile(data, filename.c_str()); + else + return (*writer)(data, filename.c_str()); +} + +// Helper reads a file with optional reader and returns a STRING. +// On failure emits a warning message and returns and empty STRING. +STRING ReadFile(const string& filename, FileReader reader) { + if (filename.empty()) return STRING(); + GenericVector data; + bool read_result; + if (reader == nullptr) + read_result = LoadDataFromFile(filename.c_str(), &data); + else + read_result = (*reader)(filename.c_str(), &data); + if (read_result) return STRING(&data[0], data.size()); + tprintf("Failed to read data from: %s\n", filename.c_str()); + return STRING(); +} + +// Helper writes the unicharset to file and to the traineddata. +bool WriteUnicharset(const UNICHARSET& unicharset, const string& output_dir, + const string& lang, FileWriter writer, + TessdataManager* traineddata) { + GenericVector unicharset_data; + TFile fp; + fp.OpenWrite(&unicharset_data); + if (!unicharset.save_to_file(&fp)) return false; + traineddata->OverwriteEntry(TESSDATA_LSTM_UNICHARSET, &unicharset_data[0], + unicharset_data.size()); + return WriteFile(output_dir, lang, ".unicharset", unicharset_data, writer); +} + +// Helper creates the recoder and writes it to the traineddata, and a human- +// readable form to file. 
+bool WriteRecoder(const UNICHARSET& unicharset, bool pass_through, + const string& output_dir, const string& lang, + FileWriter writer, STRING* radical_table_data, + TessdataManager* traineddata) { + UnicharCompress recoder; + // Where the unicharset is carefully setup already to contain a good + // compact encoding, use a pass-through recoder that does nothing. + // For scripts that have a large number of unicodes (Han, Hangul) we want + // to use the recoder to compress the symbol space by re-encoding each + // unicode as multiple codes from a smaller 'alphabet' that are related to the + // shapes in the character. Hangul Jamo is a perfect example of this. + // See the Hangul Syllables section, sub-section "Equivalence" in: + // http://www.unicode.org/versions/Unicode10.0.0/ch18.pdf + if (pass_through) { + recoder.SetupPassThrough(unicharset); + } else { + int null_char = + unicharset.has_special_codes() ? UNICHAR_BROKEN : unicharset.size(); + tprintf("Null char=%d\n", null_char); + if (!recoder.ComputeEncoding(unicharset, null_char, radical_table_data)) { + tprintf("Creation of encoded unicharset failed!!\n"); + return false; + } + } + TFile fp; + GenericVector recoder_data; + fp.OpenWrite(&recoder_data); + if (!recoder.Serialize(&fp)) return false; + traineddata->OverwriteEntry(TESSDATA_LSTM_RECODER, &recoder_data[0], + recoder_data.size()); + STRING encoding = recoder.GetEncodingAsString(unicharset); + recoder_data.init_to_size(encoding.length(), 0); + memcpy(&recoder_data[0], &encoding[0], encoding.length()); + STRING suffix; + suffix.add_str_int(".charset_size=", recoder.code_range()); + suffix += ".txt"; + return WriteFile(output_dir, lang, suffix.string(), recoder_data, writer); +} + +// Helper builds a dawg from the given words, using the unicharset as coding, +// and reverse_policy for LTR/RTL, and overwrites file_type in the traineddata. 
+static bool WriteDawg(const GenericVector& words, + const UNICHARSET& unicharset, + Trie::RTLReversePolicy reverse_policy, + TessdataType file_type, TessdataManager* traineddata) { + // The first 3 arguments are not used in this case. + Trie trie(DAWG_TYPE_WORD, "", SYSTEM_DAWG_PERM, unicharset.size(), 0); + trie.add_word_list(words, unicharset, reverse_policy); + tprintf("Reducing Trie to SquishedDawg\n"); + std::unique_ptr dawg(trie.trie_to_dawg()); + if (dawg == nullptr || dawg->NumEdges() == 0) return false; + TFile fp; + GenericVector dawg_data; + fp.OpenWrite(&dawg_data); + if (!dawg->write_squished_dawg(&fp)) return false; + traineddata->OverwriteEntry(file_type, &dawg_data[0], dawg_data.size()); + return true; +} + +// Builds and writes the dawgs, given a set of words, punctuation +// patterns, number patterns, to the traineddata. Encoding uses the given +// unicharset, and the punc dawgs is reversed if lang_is_rtl. +static bool WriteDawgs(const GenericVector& words, + const GenericVector& puncs, + const GenericVector& numbers, bool lang_is_rtl, + const UNICHARSET& unicharset, + TessdataManager* traineddata) { + if (puncs.empty()) { + tprintf("Must have non-empty puncs list to use language models!!\n"); + return false; + } + // For each of the dawg types, make the dawg, and write to traineddata. + // Dawgs are reversed as follows: + // Words: According to the word content. + // Puncs: According to lang_is_rtl. + // Numbers: Never. + // System dawg (main wordlist). + if (!words.empty() && + !WriteDawg(words, unicharset, Trie::RRP_REVERSE_IF_HAS_RTL, + TESSDATA_LSTM_SYSTEM_DAWG, traineddata)) { + return false; + } + // punc/punc-dawg. + Trie::RTLReversePolicy reverse_policy = + lang_is_rtl ? Trie::RRP_FORCE_REVERSE : Trie::RRP_DO_NO_REVERSE; + if (!WriteDawg(puncs, unicharset, reverse_policy, TESSDATA_LSTM_PUNC_DAWG, + traineddata)) { + return false; + } + // numbers/number-dawg. 
+ if (!numbers.empty() && + !WriteDawg(numbers, unicharset, Trie::RRP_DO_NO_REVERSE, + TESSDATA_LSTM_NUMBER_DAWG, traineddata)) { + return false; + } + return true; +} + +// The main function for combine_lang_model.cpp. +// Returns EXIT_SUCCESS or EXIT_FAILURE for error. +int CombineLangModel(const UNICHARSET& unicharset, const string& script_dir, + const string& version_str, const string& output_dir, + const string& lang, bool pass_through_recoder, + const GenericVector& words, + const GenericVector& puncs, + const GenericVector& numbers, bool lang_is_rtl, + FileReader reader, FileWriter writer) { + // Build the traineddata file. + TessdataManager traineddata; + if (!version_str.empty()) { + traineddata.SetVersionString(traineddata.VersionString() + ":" + + version_str); + } + // Unicharset and recoder. + if (!WriteUnicharset(unicharset, output_dir, lang, writer, &traineddata)) { + tprintf("Error writing unicharset!!\n"); + return EXIT_FAILURE; + } + // If there is a config file, read it and add to traineddata. + string config_filename = script_dir + "/" + lang + "/" + lang + ".config"; + STRING config_file = ReadFile(config_filename, reader); + if (config_file.length() > 0) { + traineddata.OverwriteEntry(TESSDATA_LANG_CONFIG, &config_file[0], + config_file.length()); + } + string radical_filename = script_dir + "/radical-stroke.txt"; + STRING radical_data = ReadFile(radical_filename, reader); + if (radical_data.length() == 0) { + tprintf("Error reading radical code table %s\n", radical_filename.c_str()); + return EXIT_FAILURE; + } + if (!WriteRecoder(unicharset, pass_through_recoder, output_dir, lang, writer, + &radical_data, &traineddata)) { + tprintf("Error writing recoder!!\n"); + } + if (!words.empty() || !puncs.empty() || !numbers.empty()) { + if (!WriteDawgs(words, puncs, numbers, lang_is_rtl, unicharset, + &traineddata)) { + tprintf("Error during conversion of wordlists to DAWGs!!\n"); + return EXIT_FAILURE; + } + } + + // Traineddata file. 
+ GenericVector traineddata_data; + traineddata.Serialize(&traineddata_data); + if (!WriteFile(output_dir, lang, ".traineddata", traineddata_data, writer)) { + tprintf("Error writing output traineddata file!!\n"); + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +} + +} // namespace tesseract diff --git a/training/lang_model_helpers.h b/training/lang_model_helpers.h new file mode 100644 index 00000000..e674be64 --- /dev/null +++ b/training/lang_model_helpers.h @@ -0,0 +1,84 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// Purpose: Collection of convenience functions to simplify creation of the +// unicharset, recoder, and dawgs for an LSTM model. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef TESSERACT_TRAINING_LANG_MODEL_HELPERS_H_ +#define TESSERACT_TRAINING_LANG_MODEL_HELPERS_H_ + +#include +#include "genericvector.h" +#include "serialis.h" +#include "strngs.h" +#include "tessdatamanager.h" +#include "unicharset.h" + +namespace tesseract { + +// Helper makes a filename (//) and writes data +// to the file, using writer if not null, otherwise, a default writer. +// Default writer will overwrite any existing file, but a supplied writer +// can do its own thing. If lang is empty, returns true but does nothing. +// NOTE that suffix should contain any required . for the filename. 
+bool WriteFile(const string& output_dir, const string& lang, + const string& suffix, const GenericVector<char>& data, + FileWriter writer); +// Helper reads a file with optional reader and returns a STRING. +// On failure emits a warning message and returns an empty STRING. +STRING ReadFile(const string& filename, FileReader reader); + +// Helper writes the unicharset to file and to the traineddata. +bool WriteUnicharset(const UNICHARSET& unicharset, const string& output_dir, + const string& lang, FileWriter writer, + TessdataManager* traineddata); +// Helper creates the recoder from the unicharset and writes it to the +// traineddata, with a human-readable form to file at: +// <output_dir>/<lang>/<lang>.charset_size=<num> for some num being the size +// of the re-encoded character set. The charset_size file is written using +// writer if not null, or using a default file writer otherwise, overwriting +// any existing content. +// If pass_through is true, then the recoder will be a no-op, passing the +// unicharset codes through unchanged. Otherwise, the recoder will "compress" +// the unicharset by encoding Hangul in Jamos, decomposing multi-unicode +// symbols into sequences of unicodes, and encoding Han using the data in the +// radical_table_data, which must be the content of the file: +// langdata/radical-stroke.txt. +bool WriteRecoder(const UNICHARSET& unicharset, bool pass_through, + const string& output_dir, const string& lang, + FileWriter writer, STRING* radical_table_data, + TessdataManager* traineddata); + +// The main function for combine_lang_model.cpp. +// Returns EXIT_SUCCESS or EXIT_FAILURE for error. +// unicharset: can be a hand-created file with incomplete fields. Its basic +// and script properties will be set before it is used. +// script_dir: should point to the langdata (github repo) directory. +// version_str: arbitrary version label.
+// Output files will be written to <output_dir>/<lang>/<lang>.* +// If pass_through_recoder is true, the unicharset will be used unchanged as +// labels in the classifier, otherwise, the unicharset will be "compressed" to +// make the recognition task simpler and faster. +// The words/puncs/numbers lists may be all empty. If any are non-empty then +// puncs must be non-empty. +// lang_is_rtl indicates that the language is generally written from right +// to left (eg Arabic/Hebrew). +int CombineLangModel(const UNICHARSET& unicharset, const string& script_dir, + const string& version_str, const string& output_dir, + const string& lang, bool pass_through_recoder, + const GenericVector<STRING>& words, + const GenericVector<STRING>& puncs, + const GenericVector<STRING>& numbers, bool lang_is_rtl, + FileReader reader, FileWriter writer); + +} // namespace tesseract + +#endif  // TESSERACT_TRAINING_LANG_MODEL_HELPERS_H_ diff --git a/training/lstmeval.cpp b/training/lstmeval.cpp index 130c2c4a..197f15e1 100644 --- a/training/lstmeval.cpp +++ b/training/lstmeval.cpp @@ -32,6 +32,8 @@ STRING_PARAM_FLAG(traineddata, "", STRING_PARAM_FLAG(eval_listfile, "", "File listing sample files in lstmf training format."); INT_PARAM_FLAG(max_image_MB, 2000, "Max memory to use for images."); +INT_PARAM_FLAG(verbosity, 1, + "Amount of diagnostic information to output (0-2)."); int main(int argc, char **argv) { ParseArguments(&argc, &argv); @@ -45,6 +47,10 @@ int main(int argc, char **argv) { } tesseract::TessdataManager mgr; if (!mgr.Init(FLAGS_model.c_str())) { + if (FLAGS_traineddata.empty()) { + tprintf("Must supply --traineddata to eval a training checkpoint!\n"); + return 1; + } tprintf("%s is not a recognition model, trying training checkpoint...\n", FLAGS_model.c_str()); if (!mgr.Init(FLAGS_traineddata.c_str())) { @@ -67,7 +73,9 @@ int main(int argc, char **argv) { return 1; } double errs = 0.0; - STRING result = tester.RunEvalSync(0, &errs, mgr, 0); + STRING result = + tester.RunEvalSync(0, &errs, mgr, + /*training_stage 
(irrelevant)*/ 0, FLAGS_verbosity); tprintf("%s\n", result.string()); return 0; } /* main */ diff --git a/training/lstmtester.cpp b/training/lstmtester.cpp index 50e2c562..b837c9cb 100644 --- a/training/lstmtester.cpp +++ b/training/lstmtester.cpp @@ -81,7 +81,7 @@ STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors, // describing the results. STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors, const TessdataManager& model_mgr, - int training_stage) { + int training_stage, int verbosity) { LSTMTrainer trainer; trainer.InitCharSet(model_mgr); TFile fp; @@ -97,11 +97,20 @@ STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors, const ImageData* trainingdata = test_data_.GetPageBySerial(eval_iteration); trainer.SetIteration(++eval_iteration); NetworkIO fwd_outputs, targets; - if (trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets) != - UNENCODABLE) { + Trainability result = + trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets); + if (result != UNENCODABLE) { char_error += trainer.NewSingleError(tesseract::ET_CHAR_ERROR); word_error += trainer.NewSingleError(tesseract::ET_WORD_RECERR); ++error_count; + if (verbosity > 1 || (verbosity > 0 && result != PERFECT)) { + tprintf("Truth:%s\n", trainingdata->transcription().string()); + GenericVector ocr_labels; + GenericVector xcoords; + trainer.LabelsFromOutputs(fwd_outputs, &ocr_labels, &xcoords); + STRING ocr_text = trainer.DecodeLabels(ocr_labels); + tprintf("OCR :%s\n", ocr_text.string()); + } } } char_error *= 100.0 / total_pages_; @@ -125,7 +134,8 @@ void* LSTMTester::ThreadFunc(void* lstmtester_void) { LSTMTester* lstmtester = static_cast(lstmtester_void); lstmtester->test_result_ = lstmtester->RunEvalSync( lstmtester->test_iteration_, lstmtester->test_training_errors_, - lstmtester->test_model_mgr_, lstmtester->test_training_stage_); + lstmtester->test_model_mgr_, lstmtester->test_training_stage_, + /*verbosity*/ 0); 
lstmtester->UnlockRunning(); return lstmtester_void; } diff --git a/training/lstmtester.h b/training/lstmtester.h index e43dd26e..5937eb7b 100644 --- a/training/lstmtester.h +++ b/training/lstmtester.h @@ -55,9 +55,11 @@ class LSTMTester { STRING RunEvalAsync(int iteration, const double* training_errors, const TessdataManager& model_mgr, int training_stage); // Runs an evaluation synchronously on the stored eval data and returns a - // string describing the results. Args as RunEvalAsync. + // string describing the results. Args as RunEvalAsync, except verbosity, + // which outputs errors, if 1, or all results if 2. STRING RunEvalSync(int iteration, const double* training_errors, - const TessdataManager& model_mgr, int training_stage); + const TessdataManager& model_mgr, int training_stage, + int verbosity); private: // Static helper thread function for RunEvalAsync, with a specific signature diff --git a/training/lstmtraining.cpp b/training/lstmtraining.cpp index 15635ca1..4d639b03 100644 --- a/training/lstmtraining.cpp +++ b/training/lstmtraining.cpp @@ -29,9 +29,8 @@ INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment."); STRING_PARAM_FLAG(net_spec, "", "Network specification"); -INT_PARAM_FLAG(train_mode, 80, "Controls gross training behavior."); INT_PARAM_FLAG(net_mode, 192, "Controls network behavior."); -INT_PARAM_FLAG(perfect_sample_delay, 4, +INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples between perfect ones."); DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent."); DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights."); @@ -40,21 +39,23 @@ DOUBLE_PARAM_FLAG(momentum, 0.9, "Decay factor for repeating deltas."); INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images."); STRING_PARAM_FLAG(continue_from, "", "Existing model to extend"); STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models"); -STRING_PARAM_FLAG(script_dir, "", - "Required to set 
unicharset properties or" - " use unicharset compression."); STRING_PARAM_FLAG(train_listfile, "", "File listing training files in lstmf training format."); STRING_PARAM_FLAG(eval_listfile, "", "File listing eval files in lstmf training format."); BOOL_PARAM_FLAG(stop_training, false, "Just convert the training model to a runtime model."); +BOOL_PARAM_FLAG(convert_to_int, false, + "Convert the recognition model to an integer model."); +BOOL_PARAM_FLAG(sequential_training, false, + "Use the training files sequentially instead of round-robin."); INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to" " attach the new network defined by net_spec"); BOOL_PARAM_FLAG(debug_network, false, "Get info on distribution of weight values"); INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations"); -DECLARE_STRING_PARAM_FLAG(U); +STRING_PARAM_FLAG(traineddata, "", + "Combined Dawgs/Unicharset/Recoder for language model"); // Number of training images to train between calls to MaintainCheckpoints. const int kNumPagesPerBatch = 100; @@ -85,6 +86,7 @@ int main(int argc, char **argv) { nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(), checkpoint_file.c_str(), FLAGS_debug_interval, static_cast(FLAGS_max_image_MB) * 1048576); + trainer.InitCharSet(FLAGS_traineddata.c_str()); // Reading something from an existing model doesn't require many flags, // so do it now and exit. 
@@ -97,12 +99,8 @@ int main(int argc, char **argv) { if (FLAGS_debug_network) { trainer.DebugNetwork(); } else { - if (FLAGS_train_mode & tesseract::TF_INT_MODE) - trainer.ConvertToInt(); - GenericVector recognizer_data; - trainer.SaveRecognitionDump(&recognizer_data); - if (!tesseract::SaveDataToFile(recognizer_data, - FLAGS_model_output.c_str())) { + if (FLAGS_convert_to_int) trainer.ConvertToInt(); + if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) { tprintf("Failed to write recognition model : %s\n", FLAGS_model_output.c_str()); } @@ -123,7 +121,6 @@ int main(int argc, char **argv) { return 1; } - UNICHARSET unicharset; // Checkpoints always take priority if they are available. if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) || trainer.TryLoadingCheckpoint(checkpoint_bak.string())) { @@ -140,14 +137,6 @@ int main(int argc, char **argv) { trainer.InitIterations(); } if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) { - // We need a unicharset to start from scratch or append. - string unicharset_str; - // Character coding to be used by the classifier. - if (!unicharset.load_from_file(FLAGS_U.c_str())) { - tprintf("Error: must provide a -U unicharset!\n"); - return 1; - } - tesseract::SetupBasicProperties(true, &unicharset); if (FLAGS_append_index >= 0) { tprintf("Appending a new network to an old one!!"); if (FLAGS_continue_from.empty()) { @@ -156,8 +145,6 @@ int main(int argc, char **argv) { } } // We are initializing from scratch. - trainer.InitCharSet(unicharset, FLAGS_script_dir.c_str(), - FLAGS_train_mode); if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, FLAGS_net_mode, FLAGS_weight_range, FLAGS_learning_rate, FLAGS_momentum)) { @@ -168,7 +155,9 @@ int main(int argc, char **argv) { trainer.set_perfect_delay(FLAGS_perfect_sample_delay); } } - if (!trainer.LoadAllTrainingData(filenames)) { + if (!trainer.LoadAllTrainingData( + filenames, FLAGS_sequential_training ? 
tesseract::CS_SEQUENTIAL + : tesseract::CS_ROUND_ROBIN)) { tprintf("Load of images failed!!\n"); return 1; } diff --git a/training/tesstrain.sh b/training/tesstrain.sh index c55b646f..18a57b00 100755 --- a/training/tesstrain.sh +++ b/training/tesstrain.sh @@ -60,11 +60,11 @@ initialize_fontconfig phase_I_generate_image 8 phase_UP_generate_unicharset -phase_D_generate_dawg if ((LINEDATA)); then phase_E_extract_features "lstm.train" 8 "lstmf" make__lstmdata else + phase_D_generate_dawg phase_E_extract_features "box.train" 8 "tr" phase_C_cluster_prototypes "${TRAINING_DIR}/${LANG_CODE}.normproto" if [[ "${ENABLE_SHAPE_CLUSTERING}" == "y" ]]; then diff --git a/training/tesstrain_utils.sh b/training/tesstrain_utils.sh index b319bbc4..de52a841 100755 --- a/training/tesstrain_utils.sh +++ b/training/tesstrain_utils.sh @@ -44,11 +44,19 @@ err_exit() { run_command() { local cmd=$(which $1) if [[ -z ${cmd} ]]; then - err_exit "$1 not found" + for d in api training; do + cmd=$(which $d/$1) + if [[ ! -z ${cmd} ]]; then + break + fi + done + if [[ -z ${cmd} ]]; then + err_exit "$1 not found" + fi fi shift tlog "[$(date)] ${cmd} $@" - ${cmd} "$@" 2>&1 1>&2 | tee -a ${LOG_FILE} + "${cmd}" "$@" 2>&1 1>&2 | tee -a ${LOG_FILE} # check completion status if [[ $? -gt 0 ]]; then err_exit "Program $(basename ${cmd}) failed. Abort." @@ -204,7 +212,7 @@ generate_font_image() { common_args+=" --fonts_dir=${FONTS_DIR} --strip_unrenderable_words" common_args+=" --leading=${LEADING}" common_args+=" --char_spacing=${CHAR_SPACING} --exposure=${EXPOSURE}" - common_args+=" --outputbase=${outbase}" + common_args+=" --outputbase=${outbase} --max_pages=3" # add --writing_mode=vertical-upright to common_args if the font is # specified to be rendered vertically. @@ -490,36 +498,43 @@ phase_B_generate_ambiguities() { make__lstmdata() { tlog "\n=== Constructing LSTM training data ===" - local lang_prefix=${LANGDATA_ROOT}/${LANG_CODE}/${LANG_CODE} - if [[ ! 
-d ${OUTPUT_DIR} ]]; then + local lang_prefix="${LANGDATA_ROOT}/${LANG_CODE}/${LANG_CODE}" + if [[ ! -d "${OUTPUT_DIR}" ]]; then tlog "Creating new directory ${OUTPUT_DIR}" - mkdir -p ${OUTPUT_DIR} + mkdir -p "${OUTPUT_DIR}" fi + local lang_is_rtl="" + # TODO(rays) set using script lang lists. + case "${LANG_CODE}" in + ara | div| fas | pus | snd | syr | uig | urd | kur_ara | heb | yid ) + lang_is_rtl="--lang_is_rtl" ;; + * ) ;; + esac + local pass_through="" + # TODO(rays) set using script lang lists. + case "${LANG_CODE}" in + asm | ben | bih | hin | mar | nep | guj | kan | mal | tam | tel | pan | \ + dzo | sin | san | bod | ori | khm | mya | tha | lao | heb | yid | ara | \ + fas | pus | snd | urd | div | syr | uig | kur_ara ) + pass_through="--pass_through_recoder" ;; + * ) ;; + esac - # Copy available files for this language from the langdata dir. - if [[ -r ${lang_prefix}.config ]]; then - tlog "Copying ${lang_prefix}.config to ${OUTPUT_DIR}" - cp ${lang_prefix}.config ${OUTPUT_DIR} - chmod u+w ${OUTPUT_DIR}/${LANG_CODE}.config - fi - if [[ -r "${TRAINING_DIR}/${LANG_CODE}.unicharset" ]]; then - tlog "Moving ${TRAINING_DIR}/${LANG_CODE}.unicharset to ${OUTPUT_DIR}" - mv "${TRAINING_DIR}/${LANG_CODE}.unicharset" "${OUTPUT_DIR}" - fi - for ext in number-dawg punc-dawg word-dawg; do - local src="${TRAINING_DIR}/${LANG_CODE}.${ext}" - if [[ -r "${src}" ]]; then - dest="${OUTPUT_DIR}/${LANG_CODE}.lstm-${ext}" - tlog "Moving ${src} to ${dest}" - mv "${src}" "${dest}" - fi - done + # Build the starter traineddata from the inputs. 
+ run_command combine_lang_model \ + --input_unicharset "${TRAINING_DIR}/${LANG_CODE}.unicharset" \ + --script_dir "${LANGDATA_ROOT}" \ + --words "${lang_prefix}.wordlist" \ + --numbers "${lang_prefix}.numbers" \ + --puncs "${lang_prefix}.punc" \ + --output_dir "${OUTPUT_DIR}" --lang "${LANG_CODE}" \ + "${pass_through}" "${lang_is_rtl}" for f in "${TRAINING_DIR}/${LANG_CODE}".*.lstmf; do tlog "Moving ${f} to ${OUTPUT_DIR}" mv "${f}" "${OUTPUT_DIR}" done local lstm_list="${OUTPUT_DIR}/${LANG_CODE}.training_files.txt" - ls -1 "${OUTPUT_DIR}"/*.lstmf > "${lstm_list}" + ls -1 "${OUTPUT_DIR}/${LANG_CODE}".*.lstmf > "${lstm_list}" } make__traineddata() { diff --git a/training/text2image.cpp b/training/text2image.cpp index 0858d480..70496d68 100644 --- a/training/text2image.cpp +++ b/training/text2image.cpp @@ -79,6 +79,9 @@ INT_PARAM_FLAG(xsize, 3600, "Width of output image"); // Max height of output image (in pixels). INT_PARAM_FLAG(ysize, 4800, "Height of output image"); +// Max number of pages to produce. +INT_PARAM_FLAG(max_pages, 0, "Maximum number of pages to output (0=unlimited)"); + // Margin around text (in pixels). 
INT_PARAM_FLAG(margin, 100, "Margin round edges of image"); @@ -579,7 +582,10 @@ int Main() { for (int pass = 0; pass < num_pass; ++pass) { int page_num = 0; string font_used; - for (size_t offset = 0; offset < strlen(to_render_utf8); ++im, ++page_num) { + for (size_t offset = 0; + offset < strlen(to_render_utf8) && + (FLAGS_max_pages == 0 || page_num < FLAGS_max_pages); + ++im, ++page_num) { tlog(1, "Starting page %d\n", im); Pix* pix = nullptr; if (FLAGS_find_fonts) { diff --git a/training/unicharset_training_utils.cpp b/training/unicharset_training_utils.cpp index 9a720329..f8e38353 100644 --- a/training/unicharset_training_utils.cpp +++ b/training/unicharset_training_utils.cpp @@ -139,6 +139,42 @@ void SetupBasicProperties(bool report_errors, bool decompose, unicharset->post_load_setup(); } +// Helper sets the properties from universal script unicharsets, if found. +void SetScriptProperties(const string& script_dir, UNICHARSET* unicharset) { + for (int s = 0; s < unicharset->get_script_table_size(); ++s) { + // Load the unicharset for the script if available. + string filename = script_dir + "/" + + unicharset->get_script_from_script_id(s) + ".unicharset"; + UNICHARSET script_set; + if (script_set.load_from_file(filename.c_str())) { + unicharset->SetPropertiesFromOther(script_set); + } else if (s != unicharset->common_sid() && s != unicharset->null_sid()) { + tprintf("Failed to load script unicharset from:%s\n", filename.c_str()); + } + } + for (int c = SPECIAL_UNICHAR_CODES_COUNT; c < unicharset->size(); ++c) { + if (unicharset->PropertiesIncomplete(c)) { + tprintf("Warning: properties incomplete for index %d = %s\n", c, + unicharset->id_to_unichar(c)); + } + } +} + +// Helper gets the combined x-heights string. +string GetXheightString(const string& script_dir, + const UNICHARSET& unicharset) { + string xheights_str; + for (int s = 0; s < unicharset.get_script_table_size(); ++s) { + // Load the xheights for the script if available. 
+ string filename = script_dir + "/" + + unicharset.get_script_from_script_id(s) + ".xheights"; + string script_heights; + if (File::ReadFileToString(filename, &script_heights)) + xheights_str += script_heights; + } + return xheights_str; +} + // Helper to set the properties for an input unicharset file, writes to the // output file. If an appropriate script unicharset can be found in the // script_dir directory, then the tops and bottoms are expanded using the @@ -158,29 +194,11 @@ void SetPropertiesForInputFile(const string& script_dir, // Set unichar properties tprintf("Setting unichar properties\n"); SetupBasicProperties(true, false, &unicharset); - string xheights_str; - for (int s = 0; s < unicharset.get_script_table_size(); ++s) { - // Load the unicharset for the script if available. - string filename = script_dir + "/" + - unicharset.get_script_from_script_id(s) + ".unicharset"; - UNICHARSET script_set; - if (script_set.load_from_file(filename.c_str())) { - unicharset.SetPropertiesFromOther(script_set); - } - // Load the xheights for the script if available. 
- filename = script_dir + "/" + unicharset.get_script_from_script_id(s) + - ".xheights"; - string script_heights; - if (File::ReadFileToString(filename, &script_heights)) - xheights_str += script_heights; - } - if (!output_xheights_file.empty()) + tprintf("Setting script properties\n"); + SetScriptProperties(script_dir, &unicharset); + if (!output_xheights_file.empty()) { + string xheights_str = GetXheightString(script_dir, unicharset); File::WriteStringToFileOrDie(xheights_str, output_xheights_file); - for (int c = SPECIAL_UNICHAR_CODES_COUNT; c < unicharset.size(); ++c) { - if (unicharset.PropertiesIncomplete(c)) { - tprintf("Warning: properties incomplete for index %d = %s\n", - c, unicharset.id_to_unichar(c)); - } } // Write the output unicharset diff --git a/training/unicharset_training_utils.h b/training/unicharset_training_utils.h index 0e42c35a..2f1a3986 100644 --- a/training/unicharset_training_utils.h +++ b/training/unicharset_training_utils.h @@ -38,6 +38,10 @@ void SetupBasicProperties(bool report_errors, bool decompose, inline void SetupBasicProperties(bool report_errors, UNICHARSET* unicharset) { SetupBasicProperties(report_errors, false, unicharset); } +// Helper sets the properties from universal script unicharsets, if found. +void SetScriptProperties(const string& script_dir, UNICHARSET* unicharset); +// Helper gets the combined x-heights string. +string GetXheightString(const string& script_dir, const UNICHARSET& unicharset); // Helper to set the properties for an input unicharset file, writes to the // output file. If an appropriate script unicharset can be found in the