/////////////////////////////////////////////////////////////////////// // File: lstmrecognizer.cpp // Description: Top-level line recognizer class for LSTM-based networks. // Author: Ray Smith // Created: Thu May 02 10:59:06 PST 2013 // // (C) Copyright 2013, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /////////////////////////////////////////////////////////////////////// // Include automatically generated configuration file if running autoconf. #ifdef HAVE_CONFIG_H #include "config_auto.h" #endif #include "lstmrecognizer.h" #include "allheaders.h" #include "callcpp.h" #include "dict.h" #include "genericheap.h" #include "helpers.h" #include "imagedata.h" #include "input.h" #include "lstm.h" #include "normalis.h" #include "pageres.h" #include "ratngs.h" #include "recodebeam.h" #include "scrollview.h" #include "shapetable.h" #include "statistc.h" #include "tprintf.h" namespace tesseract { // Max number of blob choices to return in any given position. const int kMaxChoices = 4; // Default ratio between dict and non-dict words. const double kDictRatio = 2.25; // Default certainty offset to give the dictionary a chance. const double kCertOffset = -0.085; LSTMRecognizer::LSTMRecognizer() : network_(nullptr), training_flags_(0), training_iteration_(0), sample_iteration_(0), null_char_(UNICHAR_BROKEN), learning_rate_(0.0f), momentum_(0.0f), adam_beta_(0.0f), dict_(nullptr), search_(nullptr), debug_win_(nullptr) {} LSTMRecognizer::~LSTMRecognizer() { delete network_; delete dict_; delete search_; } // Loads a model from mgr, including the dictionary only if lang is not null. bool LSTMRecognizer::Load(const char* lang, TessdataManager* mgr) { TFile fp; if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) return false; if (!DeSerialize(mgr, &fp)) return false; if (lang == nullptr) return true; // Allow it to run without a dictionary. LoadDictionary(lang, mgr); return true; } // Writes to the given file. Returns false in case of error. bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const { bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); if (!network_->Serialize(fp)) return false; if (include_charsets && !GetUnicharset().save_to_file(fp)) return false; if (!network_str_.Serialize(fp)) return false; if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1) return false; if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1) return false; if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) return false; if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false; if (fp->FWrite(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false; if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false; if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false; return true; } // Reads from the given file. Returns false in case of error. bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) { delete network_; network_ = Network::CreateFromFile(fp); if (network_ == nullptr) return false; bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) return false; if (!network_str_.DeSerialize(fp)) return false; if (fp->FReadEndian(&training_flags_, sizeof(training_flags_), 1) != 1) return false; if (fp->FReadEndian(&training_iteration_, sizeof(training_iteration_), 1) != 1) return false; if (fp->FReadEndian(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) return false; if (fp->FReadEndian(&null_char_, sizeof(null_char_), 1) != 1) return false; if (fp->FReadEndian(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false; if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false; if (include_charsets && !LoadRecoder(fp)) return false; if (!include_charsets && !LoadCharsets(mgr)) return false; network_->SetRandomizer(&randomizer_); network_->CacheXScaleFactor(network_->XScaleFactor()); return true; } // Loads the charsets from mgr. bool LSTMRecognizer::LoadCharsets(const TessdataManager* mgr) { TFile fp; if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false; if (!ccutil_.unicharset.load_from_file(&fp, false)) return false; if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false; if (!LoadRecoder(&fp)) return false; return true; } // Loads the Recoder. bool LSTMRecognizer::LoadRecoder(TFile* fp) { if (IsRecoding()) { if (!recoder_.DeSerialize(fp)) return false; RecodedCharID code; recoder_.EncodeUnichar(UNICHAR_SPACE, &code); if (code(0) != UNICHAR_SPACE) { tprintf("Space was garbled in recoding!!\n"); return false; } } else { recoder_.SetupPassThrough(GetUnicharset()); training_flags_ |= TF_COMPRESS_UNICHARSET; } return true; } // Loads the dictionary if possible from the traineddata file. // Prints a warning message, and returns false but otherwise fails silently // and continues to work without it if loading fails. // Note that dictionary load is independent from DeSerialize, but dependent // on the unicharset matching. This enables training to deserialize a model // from checkpoint or restore without having to go back and reload the // dictionary. bool LSTMRecognizer::LoadDictionary(const char* lang, TessdataManager* mgr) { delete dict_; dict_ = new Dict(&ccutil_); dict_->SetupForLoad(Dict::GlobalDawgCache()); dict_->LoadLSTM(lang, mgr); if (dict_->FinishLoad()) return true; // Success. tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", lang); delete dict_; dict_ = nullptr; return false; } // Recognizes the line image, contained within image_data, returning the // ratings matrix and matching box_word for each WERD_RES in the output. void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert, bool debug, double worst_dict_cert, const TBOX& line_box, PointerVector* words) { NetworkIO outputs; float scale_factor; NetworkIO inputs; if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor, &inputs, &outputs)) return; if (search_ == nullptr) { search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); } search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, nullptr); search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words); } // Helper computes min and mean best results in the output. void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output, float* mean_output, float* sd) { const int kOutputScale = INT8_MAX; STATS stats(0, kOutputScale + 1); for (int t = 0; t < outputs.Width(); ++t) { int best_label = outputs.BestLabel(t, nullptr); if (best_label != null_char_) { float best_output = outputs.f(t)[best_label]; stats.add(static_cast(kOutputScale * best_output), 1); } } // If the output is all nulls it could be that the photometric interpretation // is wrong, so make it look bad, so the other way can win, even if not great. if (stats.get_total() == 0) { *min_output = 0.0f; *mean_output = 0.0f; *sd = 1.0f; } else { *min_output = static_cast(stats.min_bucket()) / kOutputScale; *mean_output = stats.mean() / kOutputScale; *sd = stats.sd() / kOutputScale; } } // Recognizes the image_data, returning the labels, // scores, and corresponding pairs of start, end x-coords in coords. bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert, bool debug, bool re_invert, bool upside_down, float* scale_factor, NetworkIO* inputs, NetworkIO* outputs) { // Maximum width of image to train on. const int kMaxImageWidth = 2560; // This ensures consistent recognition results. SetRandomSeed(); int min_width = network_->XScaleFactor(); Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width, &randomizer_, scale_factor); if (pix == nullptr) { tprintf("Line cannot be recognized!!\n"); return false; } if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) { tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix)); pixDestroy(&pix); return false; } if (upside_down) pixRotate180(pix, pix); // Reduction factor from image to coords. *scale_factor = min_width / *scale_factor; inputs->set_int_mode(IsIntMode()); SetRandomSeed(); Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs); network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs); // Check for auto inversion. float pos_min, pos_mean, pos_sd; OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd); if (invert && pos_min < 0.5) { // Run again inverted and see if it is any better. NetworkIO inv_inputs, inv_outputs; inv_inputs.set_int_mode(IsIntMode()); SetRandomSeed(); pixInvert(pix, pix); Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs); network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs); float inv_min, inv_mean, inv_sd; OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd); if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) { // Inverted did better. Use inverted data. if (debug) { tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd); } *outputs = inv_outputs; *inputs = inv_inputs; } else if (re_invert) { // Inverting was not an improvement, so undo and run again, so the // outputs match the best forward result. SetRandomSeed(); network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs); } } pixDestroy(&pix); if (debug) { GenericVector labels, coords; LabelsFromOutputs(*outputs, &labels, &coords); DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_); DebugActivationPath(*outputs, labels, coords); } return true; } // Converts an array of labels to utf-8, whether or not the labels are // augmented with character boundaries. STRING LSTMRecognizer::DecodeLabels(const GenericVector& labels) { STRING result; int end = 1; for (int start = 0; start < labels.size(); start = end) { if (labels[start] == null_char_) { end = start + 1; } else { result += DecodeLabel(labels, start, &end, nullptr); } } return result; } // Displays the forward results in a window with the characters and // boundaries as determined by the labels and label_coords. void LSTMRecognizer::DisplayForward(const NetworkIO& inputs, const GenericVector& labels, const GenericVector& label_coords, const char* window_name, ScrollView** window) { #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics Pix* input_pix = inputs.ToPix(); Network::ClearWindow(false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window); int line_height = Network::DisplayImage(input_pix, *window); DisplayLSTMOutput(labels, label_coords, line_height, *window); #endif // GRAPHICS_DISABLED } // Displays the labels and cuts at the corresponding xcoords. // Size of labels should match xcoords. void LSTMRecognizer::DisplayLSTMOutput(const GenericVector& labels, const GenericVector& xcoords, int height, ScrollView* window) { #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics int x_scale = network_->XScaleFactor(); window->TextAttributes("Arial", height / 4, false, false, false); int end = 1; for (int start = 0; start < labels.size(); start = end) { int xpos = xcoords[start] * x_scale; if (labels[start] == null_char_) { end = start + 1; window->Pen(ScrollView::RED); } else { window->Pen(ScrollView::GREEN); const char* str = DecodeLabel(labels, start, &end, nullptr); if (*str == '\\') str = "\\\\"; xpos = xcoords[(start + end) / 2] * x_scale; window->Text(xpos, height, str); } window->Line(xpos, 0, xpos, height * 3 / 2); } window->Update(); #endif // GRAPHICS_DISABLED } // Prints debug output detailing the activation path that is implied by the // label_coords. void LSTMRecognizer::DebugActivationPath(const NetworkIO& outputs, const GenericVector& labels, const GenericVector& xcoords) { if (xcoords[0] > 0) DebugActivationRange(outputs, "", null_char_, 0, xcoords[0]); int end = 1; for (int start = 0; start < labels.size(); start = end) { if (labels[start] == null_char_) { end = start + 1; DebugActivationRange(outputs, "", null_char_, xcoords[start], xcoords[end]); continue; } else { int decoded; const char* label = DecodeLabel(labels, start, &end, &decoded); DebugActivationRange(outputs, label, labels[start], xcoords[start], xcoords[start + 1]); for (int i = start + 1; i < end; ++i) { DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], xcoords[i], xcoords[i + 1]); } } } } // Prints debug output detailing activations and 2nd choice over a range // of positions. void LSTMRecognizer::DebugActivationRange(const NetworkIO& outputs, const char* label, int best_choice, int x_start, int x_end) { tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end); double max_score = 0.0; double mean_score = 0.0; int width = x_end - x_start; for (int x = x_start; x < x_end; ++x) { const float* line = outputs.f(x); double score = line[best_choice] * 100.0; if (score > max_score) max_score = score; mean_score += score / width; int best_c = 0; double best_score = 0.0; for (int c = 0; c < outputs.NumFeatures(); ++c) { if (c != best_choice && line[c] > best_score) { best_c = c; best_score = line[c]; } } tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, best_score * 100.0); } tprintf(", Mean=%g, max=%g\n", mean_score, max_score); } // Helper returns true if the null_char is the winner at t, and it beats the // null_threshold, or the next choice is space, in which case we will use the // null anyway. static bool NullIsBest(const NetworkIO& output, float null_thr, int null_char, int t) { if (output.f(t)[null_char] >= null_thr) return true; if (output.BestLabel(t, null_char, null_char, nullptr) != UNICHAR_SPACE) return false; return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE]; } // Converts the network output to a sequence of labels. Outputs labels, scores // and start xcoords of each char, and each null_char_, with an additional // final xcoord for the end of the output. // The conversion method is determined by internal state. void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs, GenericVector* labels, GenericVector* xcoords) { if (SimpleTextOutput()) { LabelsViaSimpleText(outputs, labels, xcoords); } else { LabelsViaReEncode(outputs, labels, xcoords); } } // As LabelsViaCTC except that this function constructs the best path that // contains only legal sequences of subcodes for CJK. void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output, GenericVector* labels, GenericVector* xcoords) { if (search_ == nullptr) { search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); } search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, nullptr); search_->ExtractBestPathAsLabels(labels, xcoords); } // Converts the network output to a sequence of labels, with scores, using // the simple character model (each position is a char, and the null_char_ is // mainly intended for tail padding.) void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output, GenericVector* labels, GenericVector* xcoords) { labels->truncate(0); xcoords->truncate(0); int width = output.Width(); for (int t = 0; t < width; ++t) { float score = 0.0f; int label = output.BestLabel(t, &score); if (label != null_char_) { labels->push_back(label); xcoords->push_back(t); } } xcoords->push_back(width); } // Returns a string corresponding to the label starting at start. Sets *end // to the next start and if non-null, *decoded to the unichar id. const char* LSTMRecognizer::DecodeLabel(const GenericVector& labels, int start, int* end, int* decoded) { *end = start + 1; if (IsRecoding()) { // Decode labels via recoder_. RecodedCharID code; if (labels[start] == null_char_) { if (decoded != nullptr) { code.Set(0, null_char_); *decoded = recoder_.DecodeUnichar(code); } return ""; } int index = start; while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) { code.Set(code.length(), labels[index++]); while (index < labels.size() && labels[index] == null_char_) ++index; int uni_id = recoder_.DecodeUnichar(code); // If the next label isn't a valid first code, then we need to continue // extending even if we have a valid uni_id from this prefix. if (uni_id != INVALID_UNICHAR_ID && (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen || recoder_.IsValidFirstCode(labels[index]))) { *end = index; if (decoded != nullptr) *decoded = uni_id; if (uni_id == UNICHAR_SPACE) return " "; return GetUnicharset().get_normed_unichar(uni_id); } } return ""; } else { if (decoded != nullptr) *decoded = labels[start]; if (labels[start] == null_char_) return ""; if (labels[start] == UNICHAR_SPACE) return " "; return GetUnicharset().get_normed_unichar(labels[start]); } } // Returns a string corresponding to a given single label id, falling back to // a default of ".." for part of a multi-label unichar-id. const char* LSTMRecognizer::DecodeSingleLabel(int label) { if (label == null_char_) return ""; if (IsRecoding()) { // Decode label via recoder_. RecodedCharID code; code.Set(0, label); label = recoder_.DecodeUnichar(code); if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code. } if (label == UNICHAR_SPACE) return " "; return GetUnicharset().get_normed_unichar(label); } } // namespace tesseract.