2016-11-08 07:38:07 +08:00
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
// File: lstmrecognizer.h
|
|
|
|
// Description: Top-level line recognizer class for LSTM-based networks.
|
|
|
|
// Author: Ray Smith
|
|
|
|
// Created: Thu May 02 08:57:06 PST 2013
|
|
|
|
//
|
|
|
|
// (C) Copyright 2013, Google Inc.
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
|
|
|
|
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_
|
|
|
|
|
|
|
|
#include "ccutil.h"
|
|
|
|
#include "helpers.h"
|
|
|
|
#include "imagedata.h"
|
|
|
|
#include "matrix.h"
|
|
|
|
#include "network.h"
|
|
|
|
#include "networkscratch.h"
|
|
|
|
#include "recodebeam.h"
|
|
|
|
#include "series.h"
|
|
|
|
#include "strngs.h"
|
|
|
|
#include "unicharcompress.h"
|
|
|
|
|
|
|
|
class BLOB_CHOICE_IT;
|
|
|
|
struct Pix;
|
|
|
|
class ROW_RES;
|
|
|
|
class ScrollView;
|
|
|
|
class TBOX;
|
|
|
|
class WERD_RES;
|
|
|
|
|
|
|
|
namespace tesseract {
|
|
|
|
|
|
|
|
class Dict;
|
|
|
|
class ImageData;
|
|
|
|
|
|
|
|
// Training mode control flags, OR-ed together into
// LSTMRecognizer::training_flags_. Each value is a distinct bit.
enum TrainingFlags {
  // Set once the network has been converted to integer mode (ConvertToInt).
  TF_INT_MODE = 1 << 0,
  // Set when recoder_ is active to re-encode text to a smaller space.
  TF_COMPRESS_UNICHARSET = 1 << 6,
};
|
|
|
|
|
|
|
|
// Top-level line recognizer class for LSTM-based networks.
|
|
|
|
// Note that a sub-class, LSTMTrainer is used for training.
|
|
|
|
class LSTMRecognizer {
|
|
|
|
public:
|
|
|
|
LSTMRecognizer();
|
|
|
|
~LSTMRecognizer();
|
|
|
|
|
|
|
|
int NumOutputs() const {
|
|
|
|
return network_->NumOutputs();
|
|
|
|
}
|
|
|
|
int training_iteration() const {
|
|
|
|
return training_iteration_;
|
|
|
|
}
|
|
|
|
int sample_iteration() const {
|
|
|
|
return sample_iteration_;
|
|
|
|
}
|
|
|
|
double learning_rate() const {
|
|
|
|
return learning_rate_;
|
|
|
|
}
|
|
|
|
LossType OutputLossType() const {
|
|
|
|
if (network_ == nullptr) return LT_NONE;
|
|
|
|
StaticShape shape;
|
|
|
|
shape = network_->OutputShape(shape);
|
|
|
|
return shape.loss_type();
|
|
|
|
}
|
|
|
|
bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
|
|
|
|
bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
|
|
|
|
// True if recoder_ is active to re-encode text to a smaller space.
|
|
|
|
bool IsRecoding() const {
|
|
|
|
return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
|
|
|
|
}
|
|
|
|
// Returns true if the network is a TensorFlow network.
|
|
|
|
bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
|
|
|
|
// Returns a vector of layer ids that can be passed to other layer functions
|
|
|
|
// to access a specific layer.
|
|
|
|
GenericVector<STRING> EnumerateLayers() const {
|
2018-03-25 23:19:27 +08:00
|
|
|
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
|
2017-05-11 06:40:31 +08:00
|
|
|
Series* series = static_cast<Series*>(network_);
|
2016-11-08 07:38:07 +08:00
|
|
|
GenericVector<STRING> layers;
|
2018-03-25 23:19:27 +08:00
|
|
|
series->EnumerateLayers(nullptr, &layers);
|
2016-11-08 07:38:07 +08:00
|
|
|
return layers;
|
|
|
|
}
|
|
|
|
// Returns a specific layer from its id (from EnumerateLayers).
|
|
|
|
Network* GetLayer(const STRING& id) const {
|
2018-03-25 23:19:27 +08:00
|
|
|
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
|
2016-11-08 07:38:07 +08:00
|
|
|
ASSERT_HOST(id.length() > 1 && id[0] == ':');
|
2017-05-11 06:40:31 +08:00
|
|
|
Series* series = static_cast<Series*>(network_);
|
2016-11-08 07:38:07 +08:00
|
|
|
return series->GetLayer(&id[1]);
|
|
|
|
}
|
|
|
|
// Returns the learning rate of the layer from its id.
|
|
|
|
float GetLayerLearningRate(const STRING& id) const {
|
2018-03-25 23:19:27 +08:00
|
|
|
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
|
2016-11-08 07:38:07 +08:00
|
|
|
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
|
|
|
|
ASSERT_HOST(id.length() > 1 && id[0] == ':');
|
2017-05-11 06:40:31 +08:00
|
|
|
Series* series = static_cast<Series*>(network_);
|
2016-11-08 07:38:07 +08:00
|
|
|
return series->LayerLearningRate(&id[1]);
|
|
|
|
} else {
|
|
|
|
return learning_rate_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Multiplies the all the learning rate(s) by the given factor.
|
|
|
|
void ScaleLearningRate(double factor) {
|
2018-03-25 23:19:27 +08:00
|
|
|
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
|
2016-11-08 07:38:07 +08:00
|
|
|
learning_rate_ *= factor;
|
|
|
|
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
|
|
|
|
GenericVector<STRING> layers = EnumerateLayers();
|
|
|
|
for (int i = 0; i < layers.size(); ++i) {
|
|
|
|
ScaleLayerLearningRate(layers[i], factor);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Multiplies the learning rate of the layer with id, by the given factor.
|
|
|
|
void ScaleLayerLearningRate(const STRING& id, double factor) {
|
2018-03-25 23:19:27 +08:00
|
|
|
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
|
2016-11-08 07:38:07 +08:00
|
|
|
ASSERT_HOST(id.length() > 1 && id[0] == ':');
|
2017-05-11 06:40:31 +08:00
|
|
|
Series* series = static_cast<Series*>(network_);
|
2016-11-08 07:38:07 +08:00
|
|
|
series->ScaleLayerLearningRate(&id[1], factor);
|
|
|
|
}
|
|
|
|
|
2017-08-03 05:53:07 +08:00
|
|
|
// Converts the network to int if not already.
|
|
|
|
void ConvertToInt() {
|
|
|
|
if ((training_flags_ & TF_INT_MODE) == 0) {
|
|
|
|
network_->ConvertToInt();
|
|
|
|
training_flags_ |= TF_INT_MODE;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-11-08 07:38:07 +08:00
|
|
|
// Provides access to the UNICHARSET that this classifier works with.
|
|
|
|
const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
|
2017-08-03 05:03:50 +08:00
|
|
|
// Provides access to the UnicharCompress that this classifier works with.
|
|
|
|
const UnicharCompress& GetRecoder() const { return recoder_; }
|
2016-12-06 06:41:43 +08:00
|
|
|
// Provides access to the Dict that this classifier works with.
|
|
|
|
const Dict* GetDict() const { return dict_; }
|
2016-11-08 07:38:07 +08:00
|
|
|
// Sets the sample iteration to the given value. The sample_iteration_
|
|
|
|
// determines the seed for the random number generator. The training
|
|
|
|
// iteration is incremented only by a successful training iteration.
|
|
|
|
void SetIteration(int iteration) {
|
|
|
|
sample_iteration_ = iteration;
|
|
|
|
}
|
|
|
|
// Accessors for textline image normalization.
|
|
|
|
int NumInputs() const {
|
|
|
|
return network_->NumInputs();
|
|
|
|
}
|
|
|
|
int null_char() const { return null_char_; }
|
|
|
|
|
2017-07-15 02:14:23 +08:00
|
|
|
// Loads a model from mgr, including the dictionary only if lang is not null.
|
|
|
|
bool Load(const char* lang, TessdataManager* mgr);
|
|
|
|
|
2016-11-08 07:38:07 +08:00
|
|
|
// Writes to the given file. Returns false in case of error.
|
2017-07-15 02:14:23 +08:00
|
|
|
// If mgr contains a unicharset and recoder, then they are not encoded to fp.
|
|
|
|
bool Serialize(const TessdataManager* mgr, TFile* fp) const;
|
2016-11-08 07:38:07 +08:00
|
|
|
// Reads from the given file. Returns false in case of error.
|
2017-07-15 02:14:23 +08:00
|
|
|
// If mgr contains a unicharset and recoder, then they are taken from there,
|
|
|
|
// otherwise, they are part of the serialization in fp.
|
|
|
|
bool DeSerialize(const TessdataManager* mgr, TFile* fp);
|
|
|
|
// Loads the charsets from mgr.
|
|
|
|
bool LoadCharsets(const TessdataManager* mgr);
|
|
|
|
// Loads the Recoder.
|
|
|
|
bool LoadRecoder(TFile* fp);
|
2016-11-08 07:38:07 +08:00
|
|
|
// Loads the dictionary if possible from the traineddata file.
|
|
|
|
// Prints a warning message, and returns false but otherwise fails silently
|
|
|
|
// and continues to work without it if loading fails.
|
|
|
|
// Note that dictionary load is independent from DeSerialize, but dependent
|
|
|
|
// on the unicharset matching. This enables training to deserialize a model
|
|
|
|
// from checkpoint or restore without having to go back and reload the
|
|
|
|
// dictionary.
|
2017-04-28 06:48:23 +08:00
|
|
|
bool LoadDictionary(const char* lang, TessdataManager* mgr);
|
2016-11-08 07:38:07 +08:00
|
|
|
|
|
|
|
// Recognizes the line image, contained within image_data, returning the
|
2017-07-15 01:58:21 +08:00
|
|
|
// recognized tesseract WERD_RES for the words.
|
2016-11-08 07:38:07 +08:00
|
|
|
// If invert, tries inverted as well if the normal interpretation doesn't
|
2017-07-15 01:58:21 +08:00
|
|
|
// produce a good enough result. The line_box is used for computing the
|
|
|
|
// box_word in the output words. worst_dict_cert is the worst certainty that
|
|
|
|
// will be used in a dictionary word.
|
2016-11-08 07:38:07 +08:00
|
|
|
void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
|
2017-07-15 01:58:21 +08:00
|
|
|
double worst_dict_cert, const TBOX& line_box,
|
2016-11-08 07:38:07 +08:00
|
|
|
PointerVector<WERD_RES>* words);
|
|
|
|
|
|
|
|
// Helper computes min and mean best results in the output.
|
|
|
|
void OutputStats(const NetworkIO& outputs,
|
|
|
|
float* min_output, float* mean_output, float* sd);
|
|
|
|
// Recognizes the image_data, returning the labels,
|
|
|
|
// scores, and corresponding pairs of start, end x-coords in coords.
|
2017-07-15 01:58:21 +08:00
|
|
|
// Returned in scale_factor is the reduction factor
|
2016-11-08 07:38:07 +08:00
|
|
|
// between the image and the output coords, for computing bounding boxes.
|
2016-11-22 15:20:05 +08:00
|
|
|
// If re_invert is true, the input is inverted back to its original
|
2016-11-08 07:38:07 +08:00
|
|
|
// photometric interpretation if inversion is attempted but fails to
|
|
|
|
// improve the results. This ensures that outputs contains the correct
|
|
|
|
// forward outputs for the best photometric interpretation.
|
2017-07-15 01:58:21 +08:00
|
|
|
// inputs is filled with the used inputs to the network.
|
2016-11-08 07:38:07 +08:00
|
|
|
bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
|
2017-09-08 19:42:57 +08:00
|
|
|
bool re_invert, bool upside_down, float* scale_factor,
|
|
|
|
NetworkIO* inputs, NetworkIO* outputs);
|
2016-11-08 07:38:07 +08:00
|
|
|
|
|
|
|
// Converts an array of labels to utf-8, whether or not the labels are
|
|
|
|
// augmented with character boundaries.
|
|
|
|
STRING DecodeLabels(const GenericVector<int>& labels);
|
|
|
|
|
|
|
|
// Displays the forward results in a window with the characters and
|
|
|
|
// boundaries as determined by the labels and label_coords.
|
|
|
|
void DisplayForward(const NetworkIO& inputs,
|
|
|
|
const GenericVector<int>& labels,
|
|
|
|
const GenericVector<int>& label_coords,
|
|
|
|
const char* window_name,
|
|
|
|
ScrollView** window);
|
2017-08-03 05:03:50 +08:00
|
|
|
// Converts the network output to a sequence of labels. Outputs labels, scores
|
|
|
|
// and start xcoords of each char, and each null_char_, with an additional
|
|
|
|
// final xcoord for the end of the output.
|
|
|
|
// The conversion method is determined by internal state.
|
|
|
|
void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
|
|
|
|
GenericVector<int>* xcoords);
|
2016-11-08 07:38:07 +08:00
|
|
|
|
|
|
|
protected:
|
|
|
|
// Sets the random seed from the sample_iteration_;
|
|
|
|
void SetRandomSeed() {
|
2018-03-14 02:01:40 +08:00
|
|
|
int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
|
2016-11-08 07:38:07 +08:00
|
|
|
randomizer_.set_seed(seed);
|
|
|
|
randomizer_.IntRand();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Displays the labels and cuts at the corresponding xcoords.
|
|
|
|
// Size of labels should match xcoords.
|
|
|
|
void DisplayLSTMOutput(const GenericVector<int>& labels,
|
|
|
|
const GenericVector<int>& xcoords,
|
|
|
|
int height, ScrollView* window);
|
|
|
|
|
|
|
|
// Prints debug output detailing the activation path that is implied by the
|
|
|
|
// xcoords.
|
|
|
|
void DebugActivationPath(const NetworkIO& outputs,
|
|
|
|
const GenericVector<int>& labels,
|
|
|
|
const GenericVector<int>& xcoords);
|
|
|
|
|
|
|
|
// Prints debug output detailing activations and 2nd choice over a range
|
|
|
|
// of positions.
|
|
|
|
void DebugActivationRange(const NetworkIO& outputs, const char* label,
|
|
|
|
int best_choice, int x_start, int x_end);
|
|
|
|
|
|
|
|
// As LabelsViaCTC except that this function constructs the best path that
|
|
|
|
// contains only legal sequences of subcodes for recoder_.
|
|
|
|
void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
|
|
|
|
GenericVector<int>* xcoords);
|
|
|
|
// Converts the network output to a sequence of labels, with scores, using
|
|
|
|
// the simple character model (each position is a char, and the null_char_ is
|
|
|
|
// mainly intended for tail padding.)
|
|
|
|
void LabelsViaSimpleText(const NetworkIO& output,
|
|
|
|
GenericVector<int>* labels,
|
|
|
|
GenericVector<int>* xcoords);
|
|
|
|
|
|
|
|
// Returns a string corresponding to the label starting at start. Sets *end
|
|
|
|
// to the next start and if non-null, *decoded to the unichar id.
|
|
|
|
const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
|
|
|
|
int* decoded);
|
|
|
|
|
|
|
|
// Returns a string corresponding to a given single label id, falling back to
|
|
|
|
// a default of ".." for part of a multi-label unichar-id.
|
|
|
|
const char* DecodeSingleLabel(int label);
|
|
|
|
|
|
|
|
protected:
|
|
|
|
// The network hierarchy.
|
|
|
|
Network* network_;
|
|
|
|
// The unicharset. Only the unicharset element is serialized.
|
|
|
|
// Has to be a CCUtil, so Dict can point to it.
|
|
|
|
CCUtil ccutil_;
|
2016-11-22 15:20:05 +08:00
|
|
|
// For backward compatibility, recoder_ is serialized iff
|
2016-11-08 07:38:07 +08:00
|
|
|
// training_flags_ & TF_COMPRESS_UNICHARSET.
|
|
|
|
// Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
|
|
|
|
UnicharCompress recoder_;
|
|
|
|
|
|
|
|
// ==Training parameters that are serialized to provide a record of them.==
|
|
|
|
STRING network_str_;
|
|
|
|
// Flags used to determine the training method of the network.
|
|
|
|
// See enum TrainingFlags above.
|
2018-03-14 02:01:40 +08:00
|
|
|
int32_t training_flags_;
|
2016-11-08 07:38:07 +08:00
|
|
|
// Number of actual backward training steps used.
|
2018-03-14 02:01:40 +08:00
|
|
|
int32_t training_iteration_;
|
2016-11-08 07:38:07 +08:00
|
|
|
// Index into training sample set. sample_iteration >= training_iteration_.
|
2018-03-14 02:01:40 +08:00
|
|
|
int32_t sample_iteration_;
|
2016-11-08 07:38:07 +08:00
|
|
|
// Index in softmax of null character. May take the value UNICHAR_BROKEN or
|
|
|
|
// ccutil_.unicharset.size().
|
2018-03-14 02:01:40 +08:00
|
|
|
int32_t null_char_;
|
2016-11-08 07:38:07 +08:00
|
|
|
// Learning rate and momentum multipliers of deltas in backprop.
|
|
|
|
float learning_rate_;
|
|
|
|
float momentum_;
|
2017-08-03 05:03:50 +08:00
|
|
|
// Smoothing factor for 2nd moment of gradients.
|
|
|
|
float adam_beta_;
|
2016-11-08 07:38:07 +08:00
|
|
|
|
|
|
|
// === NOT SERIALIZED.
|
|
|
|
TRand randomizer_;
|
|
|
|
NetworkScratch scratch_space_;
|
|
|
|
// Language model (optional) to use with the beam search.
|
|
|
|
Dict* dict_;
|
|
|
|
// Beam search held between uses to optimize memory allocation/use.
|
|
|
|
RecodeBeamSearch* search_;
|
|
|
|
|
|
|
|
// == Debugging parameters.==
|
|
|
|
// Recognition debug display window.
|
|
|
|
ScrollView* debug_win_;
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace tesseract.
|
|
|
|
|
|
|
|
#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
|