/////////////////////////////////////////////////////////////////////// // File: lstmrecognizer.h // Description: Top-level line recognizer class for LSTM-based networks. // Author: Ray Smith // Created: Thu May 02 08:57:06 PST 2013 // // (C) Copyright 2013, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /////////////////////////////////////////////////////////////////////// #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_ #define TESSERACT_LSTM_LSTMRECOGNIZER_H_ #include "ccutil.h" #include "helpers.h" #include "imagedata.h" #include "matrix.h" #include "network.h" #include "networkscratch.h" #include "recodebeam.h" #include "series.h" #include "strngs.h" #include "unicharcompress.h" class BLOB_CHOICE_IT; struct Pix; class ROW_RES; class ScrollView; class TBOX; class WERD_RES; namespace tesseract { class Dict; class ImageData; // Enum indicating training mode control flags. enum TrainingFlags { TF_INT_MODE = 1, TF_AUTO_HARDEN = 2, TF_ROUND_ROBIN_TRAINING = 16, TF_COMPRESS_UNICHARSET = 64, }; // Top-level line recognizer class for LSTM-based networks. // Note that a sub-class, LSTMTrainer is used for training. class LSTMRecognizer { public: LSTMRecognizer(); ~LSTMRecognizer(); int NumOutputs() const { return network_->NumOutputs(); } int training_iteration() const { return training_iteration_; } int sample_iteration() const { return sample_iteration_; } double learning_rate() const { return learning_rate_; } bool IsHardening() const { return (training_flags_ & TF_AUTO_HARDEN) != 0; } LossType OutputLossType() const { if (network_ == nullptr) return LT_NONE; StaticShape shape; shape = network_->OutputShape(shape); return shape.loss_type(); } bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; } bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; } // True if recoder_ is active to re-encode text to a smaller space. bool IsRecoding() const { return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0; } // Returns the cache strategy for the DocumentCache. CachingStrategy CacheStrategy() const { return training_flags_ & TF_ROUND_ROBIN_TRAINING ? CS_ROUND_ROBIN : CS_SEQUENTIAL; } // Returns true if the network is a TensorFlow network. bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; } // Returns a vector of layer ids that can be passed to other layer functions // to access a specific layer. GenericVector EnumerateLayers() const { ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); Series* series = reinterpret_cast(network_); GenericVector layers; series->EnumerateLayers(NULL, &layers); return layers; } // Returns a specific layer from its id (from EnumerateLayers). Network* GetLayer(const STRING& id) const { ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); ASSERT_HOST(id.length() > 1 && id[0] == ':'); Series* series = reinterpret_cast(network_); return series->GetLayer(&id[1]); } // Returns the learning rate of the layer from its id. float GetLayerLearningRate(const STRING& id) const { ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { ASSERT_HOST(id.length() > 1 && id[0] == ':'); Series* series = reinterpret_cast(network_); return series->LayerLearningRate(&id[1]); } else { return learning_rate_; } } // Multiplies the all the learning rate(s) by the given factor. void ScaleLearningRate(double factor) { ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); learning_rate_ *= factor; if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { GenericVector layers = EnumerateLayers(); for (int i = 0; i < layers.size(); ++i) { ScaleLayerLearningRate(layers[i], factor); } } } // Multiplies the learning rate of the layer with id, by the given factor. void ScaleLayerLearningRate(const STRING& id, double factor) { ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES); ASSERT_HOST(id.length() > 1 && id[0] == ':'); Series* series = reinterpret_cast(network_); series->ScaleLayerLearningRate(&id[1], factor); } // True if the network is using adagrad to train. bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); } // Provides access to the UNICHARSET that this classifier works with. const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; } // Provides access to the Dict that this classifier works with. const Dict* GetDict() const { return dict_; } // Sets the sample iteration to the given value. The sample_iteration_ // determines the seed for the random number generator. The training // iteration is incremented only by a successful training iteration. void SetIteration(int iteration) { sample_iteration_ = iteration; } // Accessors for textline image normalization. int NumInputs() const { return network_->NumInputs(); } int null_char() const { return null_char_; } // Writes to the given file. Returns false in case of error. bool Serialize(TFile* fp) const; // Reads from the given file. Returns false in case of error. // If swap is true, assumes a big/little-endian swap is needed. bool DeSerialize(bool swap, TFile* fp); // Loads the dictionary if possible from the traineddata file. // Prints a warning message, and returns false but otherwise fails silently // and continues to work without it if loading fails. // Note that dictionary load is independent from DeSerialize, but dependent // on the unicharset matching. This enables training to deserialize a model // from checkpoint or restore without having to go back and reload the // dictionary. bool LoadDictionary(const char* lang, TessdataManager* mgr); // Recognizes the line image, contained within image_data, returning the // ratings matrix and matching box_word for each WERD_RES in the output. // If invert, tries inverted as well if the normal interpretation doesn't // produce a good enough result. If use_alternates, the ratings matrix is // filled with segmentation and classifier alternatives that may be searched // using the standard beam search, otherwise, just a diagonal and prebuilt // best_choice. The line_box is used for computing the box_word in the // output words. Score_ratio is used to determine the classifier alternates. // If one_word, then a single WERD_RES is formed, regardless of the spaces // found during recognition. // If not NULL, we attempt to translate the output to target_unicharset, but // do not guarantee success, due to mismatches. In that case the output words // are marked with our UNICHARSET, not the caller's. void RecognizeLine(const ImageData& image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET* target_unicharset, const TBOX& line_box, float score_ratio, bool one_word, PointerVector* words); // Builds a set of tesseract-compatible WERD_RESs aligned to line_box, // corresponding to the network output in outputs, labels, label_coords. // one_word generates a single word output, that may include spaces inside. // use_alternates generates alternative BLOB_CHOICEs and segmentation paths, // with cut-offs determined by scale_factor. // If not NULL, we attempt to translate the output to target_unicharset, but // do not guarantee success, due to mismatches. In that case the output words // are marked with our UNICHARSET, not the caller's. void WordsFromOutputs(const NetworkIO& outputs, const GenericVector& labels, const GenericVector label_coords, const TBOX& line_box, bool debug, bool use_alternates, bool one_word, float score_ratio, float scale_factor, const UNICHARSET* target_unicharset, PointerVector* words); // Helper computes min and mean best results in the output. void OutputStats(const NetworkIO& outputs, float* min_output, float* mean_output, float* sd); // Recognizes the image_data, returning the labels, // scores, and corresponding pairs of start, end x-coords in coords. // If label_threshold is positive, uses it for making the labels, otherwise // uses standard ctc. Returned in scale_factor is the reduction factor // between the image and the output coords, for computing bounding boxes. // If re_invert is true, the input is inverted back to its original // photometric interpretation if inversion is attempted but fails to // improve the results. This ensures that outputs contains the correct // forward outputs for the best photometric interpretation. // inputs is filled with the used inputs to the network, and if not null, // target boxes is filled with scaled truth boxes if present in image_data. bool RecognizeLine(const ImageData& image_data, bool invert, bool debug, bool re_invert, float label_threshold, float* scale_factor, NetworkIO* inputs, NetworkIO* outputs); // Returns a tesseract-compatible WERD_RES from the line recognizer outputs. // line_box should be the bounding box of the line image in the main image, // outputs the output of the network, // [word_start, word_end) the interval over which to convert, // score_ratio for choosing alternate classifier choices, // use_alternates to control generation of alternative segmentations, // labels, label_coords, scale_factor from RecognizeLine above. // If target_unicharset is not NULL, attempts to translate the internal // unichar_ids to the target_unicharset, but falls back to untranslated ids // if the translation should fail. WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs, int word_start, int word_end, float score_ratio, float space_certainty, bool debug, bool use_alternates, const UNICHARSET* target_unicharset, const GenericVector& labels, const GenericVector& label_coords, float scale_factor); // Sets up a word with the ratings matrix and fake blobs with boxes in the // right places. WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end, float space_certainty, bool use_alternates, const UNICHARSET* target_unicharset, const GenericVector& labels, const GenericVector& label_coords, float scale_factor); // Converts an array of labels to utf-8, whether or not the labels are // augmented with character boundaries. STRING DecodeLabels(const GenericVector& labels); // Displays the forward results in a window with the characters and // boundaries as determined by the labels and label_coords. void DisplayForward(const NetworkIO& inputs, const GenericVector& labels, const GenericVector& label_coords, const char* window_name, ScrollView** window); protected: // Sets the random seed from the sample_iteration_; void SetRandomSeed() { inT64 seed = static_cast(sample_iteration_) * 0x10000001; randomizer_.set_seed(seed); randomizer_.IntRand(); } // Displays the labels and cuts at the corresponding xcoords. // Size of labels should match xcoords. void DisplayLSTMOutput(const GenericVector& labels, const GenericVector& xcoords, int height, ScrollView* window); // Prints debug output detailing the activation path that is implied by the // xcoords. void DebugActivationPath(const NetworkIO& outputs, const GenericVector& labels, const GenericVector& xcoords); // Prints debug output detailing activations and 2nd choice over a range // of positions. void DebugActivationRange(const NetworkIO& outputs, const char* label, int best_choice, int x_start, int x_end); // Converts the network output to a sequence of labels. Outputs labels, scores // and start xcoords of each char, and each null_char_, with an additional // final xcoord for the end of the output. // The conversion method is determined by internal state. void LabelsFromOutputs(const NetworkIO& outputs, float null_thr, GenericVector* labels, GenericVector* xcoords); // Converts the network output to a sequence of labels, using a threshold // on the null_char_ to determine character boundaries. Outputs labels, scores // and start xcoords of each char, and each null_char_, with an additional // final xcoord for the end of the output. // The label output is the one with the highest score in the interval between // null_chars_. void LabelsViaThreshold(const NetworkIO& output, float null_threshold, GenericVector* labels, GenericVector* xcoords); // Converts the network output to a sequence of labels, with scores and // start x-coords of the character labels. Retains the null_char_ character as // the end x-coord, where already present, otherwise the start of the next // character is the end. // The number of labels, scores, and xcoords is always matched, except that // there is always an additional xcoord for the last end position. void LabelsViaCTC(const NetworkIO& output, GenericVector* labels, GenericVector* xcoords); // As LabelsViaCTC except that this function constructs the best path that // contains only legal sequences of subcodes for recoder_. void LabelsViaReEncode(const NetworkIO& output, GenericVector* labels, GenericVector* xcoords); // Converts the network output to a sequence of labels, with scores, using // the simple character model (each position is a char, and the null_char_ is // mainly intended for tail padding.) void LabelsViaSimpleText(const NetworkIO& output, GenericVector* labels, GenericVector* xcoords); // Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range. // Handles either LSTM labels or direct unichar-ids. // Score ratio determines the worst ratio between top choice and remainder. // If target_unicharset is not NULL, attempts to translate to the target // unicharset, returning NULL on failure. BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug, const NetworkIO& output, const UNICHARSET* target_unicharset, int x_start, int x_end, float score_ratio); // Adds to the given iterator, the blob choices for the target_unicharset // that correspond to the given LSTM unichar_id. // Returns false if unicharset translation failed. bool AddBlobChoices(int unichar_id, float rating, float certainty, int col, int row, const UNICHARSET* target_unicharset, BLOB_CHOICE_IT* bc_it); // Returns a string corresponding to the label starting at start. Sets *end // to the next start and if non-null, *decoded to the unichar id. const char* DecodeLabel(const GenericVector& labels, int start, int* end, int* decoded); // Returns a string corresponding to a given single label id, falling back to // a default of ".." for part of a multi-label unichar-id. const char* DecodeSingleLabel(int label); protected: // The network hierarchy. Network* network_; // The unicharset. Only the unicharset element is serialized. // Has to be a CCUtil, so Dict can point to it. CCUtil ccutil_; // For backward compatibility, recoder_ is serialized iff // training_flags_ & TF_COMPRESS_UNICHARSET. // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset. UnicharCompress recoder_; // ==Training parameters that are serialized to provide a record of them.== STRING network_str_; // Flags used to determine the training method of the network. // See enum TrainingFlags above. inT32 training_flags_; // Number of actual backward training steps used. inT32 training_iteration_; // Index into training sample set. sample_iteration >= training_iteration_. inT32 sample_iteration_; // Index in softmax of null character. May take the value UNICHAR_BROKEN or // ccutil_.unicharset.size(). inT32 null_char_; // Range used for the initial random numbers in the weights. float weight_range_; // Learning rate and momentum multipliers of deltas in backprop. float learning_rate_; float momentum_; // === NOT SERIALIZED. TRand randomizer_; NetworkScratch scratch_space_; // Language model (optional) to use with the beam search. Dict* dict_; // Beam search held between uses to optimize memory allocation/use. RecodeBeamSearch* search_; // == Debugging parameters.== // Recognition debug display window. ScrollView* debug_win_; }; } // namespace tesseract. #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_