///////////////////////////////////////////////////////////////////////
// File: lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author: Ray Smith
// Created: Thu May 02 08:57:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_
#include "ccutil.h"
#include "helpers.h"
#include "imagedata.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "recodebeam.h"
#include "series.h"
#include "strngs.h"
#include "unicharcompress.h"
class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;
namespace tesseract {
class Dict;
class ImageData;
// Enum indicating training mode control flags.
enum TrainingFlags {
TF_INT_MODE = 1,
TF_COMPRESS_UNICHARSET = 64,
};
// Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer is used for training.
class LSTMRecognizer {
public:
LSTMRecognizer();
~LSTMRecognizer();
int NumOutputs() const {
return network_->NumOutputs();
}
int training_iteration() const {
return training_iteration_;
}
int sample_iteration() const {
return sample_iteration_;
}
double learning_rate() const {
return learning_rate_;
}
LossType OutputLossType() const {
if (network_ == nullptr) return LT_NONE;
StaticShape shape;
shape = network_->OutputShape(shape);
return shape.loss_type();
}
bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
// True if recoder_ is active to re-encode text to a smaller space.
bool IsRecoding() const {
return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
}
// Returns true if the network is a TensorFlow network.
bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
// Returns a vector of layer ids that can be passed to other layer functions
// to access a specific layer.
GenericVector<STRING> EnumerateLayers() const {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
Series* series = static_cast<Series*>(network_);
GenericVector<STRING> layers;
series->EnumerateLayers(NULL, &layers);
return layers;
}
// Returns a specific layer from its id (from EnumerateLayers).
Network* GetLayer(const STRING& id) const {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
ASSERT_HOST(id.length() > 1 && id[0] == ':');
Series* series = static_cast<Series*>(network_);
return series->GetLayer(&id[1]);
}
// Returns the learning rate of the layer from its id.
float GetLayerLearningRate(const STRING& id) const {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
ASSERT_HOST(id.length() > 1 && id[0] == ':');
Series* series = static_cast<Series*>(network_);
return series->LayerLearningRate(&id[1]);
} else {
return learning_rate_;
}
}
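// Example (illustrative sketch, not part of this API): listing the layers of
// a loaded recognizer together with their current learning rates. The
// variable name 'recognizer' is a placeholder.
//
//   GenericVector<STRING> layer_ids = recognizer.EnumerateLayers();
//   for (int i = 0; i < layer_ids.size(); ++i) {
//     tprintf("layer %s: lr=%g\n", layer_ids[i].string(),
//             recognizer.GetLayerLearningRate(layer_ids[i]));
//   }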
// Multiplies all the learning rates by the given factor.
void ScaleLearningRate(double factor) {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
learning_rate_ *= factor;
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
GenericVector<STRING> layers = EnumerateLayers();
for (int i = 0; i < layers.size(); ++i) {
ScaleLayerLearningRate(layers[i], factor);
}
}
}
// Multiplies the learning rate of the layer with id, by the given factor.
void ScaleLayerLearningRate(const STRING& id, double factor) {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
ASSERT_HOST(id.length() > 1 && id[0] == ':');
Series* series = static_cast<Series*>(network_);
series->ScaleLayerLearningRate(&id[1], factor);
}
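// Example (illustrative sketch): halving every learning rate, then further
// reducing one specific layer. The layer id comes from EnumerateLayers(), as
// in the sketch above; 'recognizer' and 'layer_ids' are placeholders.
//
//   recognizer.ScaleLearningRate(0.5);
//   recognizer.ScaleLayerLearningRate(layer_ids[0], 0.1);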
// Converts the network to int if not already.
void ConvertToInt() {
if ((training_flags_ & TF_INT_MODE) == 0) {
network_->ConvertToInt();
training_flags_ |= TF_INT_MODE;
}
}
// Provides access to the UNICHARSET that this classifier works with.
const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
// Provides access to the UnicharCompress that this classifier works with.
const UnicharCompress& GetRecoder() const { return recoder_; }
// Provides access to the Dict that this classifier works with.
const Dict* GetDict() const { return dict_; }
// Sets the sample iteration to the given value. The sample_iteration_
// determines the seed for the random number generator. The training
// iteration is incremented only by a successful training iteration.
void SetIteration(int iteration) {
sample_iteration_ = iteration;
}
// Accessors for textline image normalization.
int NumInputs() const {
return network_->NumInputs();
}
int null_char() const { return null_char_; }
// Loads a model from mgr, including the dictionary only if lang is not null.
bool Load(const char* lang, TessdataManager* mgr);
// Writes to the given file. Returns false in case of error.
// If mgr contains a unicharset and recoder, then they are not encoded to fp.
bool Serialize(const TessdataManager* mgr, TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If mgr contains a unicharset and recoder, then they are taken from there,
// otherwise, they are part of the serialization in fp.
bool DeSerialize(const TessdataManager* mgr, TFile* fp);
// Loads the charsets from mgr.
bool LoadCharsets(const TessdataManager* mgr);
// Loads the Recoder.
bool LoadRecoder(TFile* fp);
// Loads the dictionary if possible from the traineddata file.
// If loading fails, prints a warning message and returns false, but otherwise
// fails silently and continues to work without the dictionary.
// Note that dictionary load is independent from DeSerialize, but dependent
// on the unicharset matching. This enables training to deserialize a model
// from checkpoint or restore without having to go back and reload the
// dictionary.
bool LoadDictionary(const char* lang, TessdataManager* mgr);
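// Example (illustrative sketch): loading a model for recognition. The file
// name, language code and TessdataManager::Init usage are assumptions based
// on tessdatamanager.h, not guaranteed by this header.
//
//   TessdataManager mgr;
//   LSTMRecognizer recognizer;
//   if (!mgr.Init("eng.traineddata") || !recognizer.Load("eng", &mgr)) {
//     // Handle the load failure.
//   }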
// Recognizes the line image, contained within image_data, returning the
// recognized tesseract WERD_RES for the words.
// If invert, tries inverted as well if the normal interpretation doesn't
// produce a good enough result. The line_box is used for computing the
// box_word in the output words. worst_dict_cert is the worst certainty that
// will be used in a dictionary word.
void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
double worst_dict_cert, const TBOX& line_box,
PointerVector<WERD_RES>* words);
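// Example (illustrative sketch): recognizing one pre-segmented text line.
// 'image_data', 'line_box' and the certainty threshold are placeholders.
//
//   PointerVector<WERD_RES> words;
//   recognizer.RecognizeLine(*image_data, /*invert=*/true, /*debug=*/false,
//                            /*worst_dict_cert=*/-25.0, line_box, &words);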
// Helper computes min and mean best results in the output.
void OutputStats(const NetworkIO& outputs,
float* min_output, float* mean_output, float* sd);
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
// Returned in scale_factor is the reduction factor
// between the image and the output coords, for computing bounding boxes.
// If re_invert is true, the input is inverted back to its original
// photometric interpretation if inversion is attempted but fails to
// improve the results. This ensures that outputs contains the correct
// forward outputs for the best photometric interpretation.
// inputs is filled with the used inputs to the network.
bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
bool re_invert, bool upside_down, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs);
// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
STRING DecodeLabels(const GenericVector<int>& labels);
// Displays the forward results in a window with the characters and
// boundaries as determined by the labels and label_coords.
void DisplayForward(const NetworkIO& inputs,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
const char* window_name,
ScrollView** window);
// Converts the network output to a sequence of labels. Outputs the labels and
// the start xcoord of each char and each null_char_, with an additional final
// xcoord for the end of the output.
// The conversion method is determined by internal state.
void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
GenericVector<int>* xcoords);
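// Example (illustrative sketch): using the lower-level interface to get raw
// labels and decoded text. All variable names are placeholders.
//
//   NetworkIO inputs, outputs;
//   float scale_factor;
//   if (recognizer.RecognizeLine(*image_data, false, false, false, false,
//                                &scale_factor, &inputs, &outputs)) {
//     GenericVector<int> labels, xcoords;
//     recognizer.LabelsFromOutputs(outputs, &labels, &xcoords);
//     STRING text = recognizer.DecodeLabels(labels);
//   }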
protected:
// Sets the random seed from sample_iteration_.
void SetRandomSeed() {
inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
randomizer_.set_seed(seed);
randomizer_.IntRand();
}
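// For example, sample_iteration_ == 3 yields seed 3 * 0x10000001 = 0x30000003,
// so each sample index maps to a distinct, reproducible random stream.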
// Displays the labels and cuts at the corresponding xcoords.
// Size of labels should match xcoords.
void DisplayLSTMOutput(const GenericVector<int>& labels,
const GenericVector<int>& xcoords,
int height, ScrollView* window);
// Prints debug output detailing the activation path that is implied by the
// xcoords.
void DebugActivationPath(const NetworkIO& outputs,
const GenericVector<int>& labels,
const GenericVector<int>& xcoords);
// Prints debug output detailing activations and 2nd choice over a range
// of positions.
void DebugActivationRange(const NetworkIO& outputs, const char* label,
int best_choice, int x_start, int x_end);
// Converts the network output to a sequence of labels, constructing the best
// path that contains only legal sequences of subcodes for recoder_.
void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Converts the network output to a sequence of labels, with scores, using
// the simple character model (each position is a char, and the null_char_ is
// mainly intended for tail padding.)
void LabelsViaSimpleText(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
int* decoded);
// Returns a string corresponding to a given single label id, falling back to
// a default of ".." for part of a multi-label unichar-id.
const char* DecodeSingleLabel(int label);
protected:
// The network hierarchy.
Network* network_;
// The unicharset. Only the unicharset element is serialized.
// Has to be a CCUtil, so Dict can point to it.
CCUtil ccutil_;
// For backward compatibility, recoder_ is serialized iff
// training_flags_ & TF_COMPRESS_UNICHARSET.
// Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
UnicharCompress recoder_;
// ==Training parameters that are serialized to provide a record of them.==
STRING network_str_;
// Flags used to determine the training method of the network.
// See enum TrainingFlags above.
inT32 training_flags_;
// Number of actual backward training steps used.
inT32 training_iteration_;
// Index into training sample set. sample_iteration_ >= training_iteration_.
inT32 sample_iteration_;
// Index in softmax of null character. May take the value UNICHAR_BROKEN or
// ccutil_.unicharset.size().
inT32 null_char_;
// Learning rate and momentum multipliers of deltas in backprop.
float learning_rate_;
float momentum_;
// Smoothing factor for 2nd moment of gradients.
float adam_beta_;
// === NOT SERIALIZED.
TRand randomizer_;
NetworkScratch scratch_space_;
// Language model (optional) to use with the beam search.
Dict* dict_;
// Beam search held between uses to optimize memory allocation/use.
RecodeBeamSearch* search_;
// == Debugging parameters.==
// Recognition debug display window.
ScrollView* debug_win_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_