Added new LSTM-based neural network line recognizer

This commit is contained in:
Ray Smith 2016-11-07 15:38:07 -08:00
parent 5d21ecfad3
commit c1c1e426b3
107 changed files with 15414 additions and 358 deletions

View File

@ -16,7 +16,7 @@ endif
.PHONY: install-langs ScrollView.jar install-jars training
SUBDIRS = ccutil viewer cutil opencl ccstruct dict classify wordrec textord
SUBDIRS = arch ccutil viewer cutil opencl ccstruct dict classify wordrec textord lstm
if !NO_CUBE_BUILD
SUBDIRS += neural_networks/runtime cube
endif

View File

@ -1,5 +1,6 @@
AM_CPPFLAGS += -DLOCALEDIR=\"$(localedir)\"\
-DUSE_STD_NAMESPACE \
-I$(top_srcdir)/arch -I$(top_srcdir)/lstm \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct -I$(top_srcdir)/cube \
-I$(top_srcdir)/viewer \
-I$(top_srcdir)/textord -I$(top_srcdir)/dict \
@ -27,6 +28,9 @@ libtesseract_api_la_LIBADD = \
../wordrec/libtesseract_wordrec.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -57,6 +61,9 @@ libtesseract_la_LIBADD = \
../wordrec/libtesseract_wordrec.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \

View File

@ -121,7 +121,6 @@ TessBaseAPI::TessBaseAPI()
block_list_(NULL),
page_res_(NULL),
input_file_(NULL),
input_image_(NULL),
output_file_(NULL),
datapath_(NULL),
language_(NULL),
@ -515,9 +514,7 @@ void TessBaseAPI::ClearAdaptiveClassifier() {
/**
* Provide an image for Tesseract to recognize. Format is as
* TesseractRect above. Does not copy the image buffer, or take
* ownership. The source image may be destroyed after Recognize is called,
* either explicitly or implicitly via one of the Get*Text functions.
* TesseractRect above. Copies the image buffer and converts to Pix.
* SetImage clears all recognition results, and sets the rectangle to the
* full image, so it may be followed immediately by a GetUTF8Text, and it
* will automatically perform recognition.
@ -525,9 +522,11 @@ void TessBaseAPI::ClearAdaptiveClassifier() {
void TessBaseAPI::SetImage(const unsigned char* imagedata,
int width, int height,
int bytes_per_pixel, int bytes_per_line) {
if (InternalSetImage())
if (InternalSetImage()) {
thresholder_->SetImage(imagedata, width, height,
bytes_per_pixel, bytes_per_line);
SetInputImage(thresholder_->GetPixRect());
}
}
void TessBaseAPI::SetSourceResolution(int ppi) {
@ -539,18 +538,17 @@ void TessBaseAPI::SetSourceResolution(int ppi) {
/**
* Provide an image for Tesseract to recognize. As with SetImage above,
* Tesseract doesn't take a copy or ownership or pixDestroy the image, so
* it must persist until after Recognize.
* Tesseract takes its own copy of the image, so it need not persist until
* after Recognize.
* Pix vs raw, which to use?
* Use Pix where possible. A future version of Tesseract may choose to use Pix
* as its internal representation and discard IMAGE altogether.
* Because of that, an implementation that sources and targets Pix may end up
* with less copies than an implementation that does not.
* Use Pix where possible. Tesseract uses Pix as its internal representation
* and it is therefore more efficient to provide a Pix directly.
*/
void TessBaseAPI::SetImage(Pix* pix) {
if (InternalSetImage())
if (InternalSetImage()) {
thresholder_->SetImage(pix);
SetInputImage(pix);
SetInputImage(thresholder_->GetPixRect());
}
}
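Under the new contract documented above, the caller no longer needs to keep the image alive through Recognize. A minimal usage sketch (standalone; "page.png" and the include paths are illustrative):

#include <cstdio>
#include "allheaders.h"  // Leptonica
#include "baseapi.h"     // tesseract::TessBaseAPI

int main() {
  tesseract::TessBaseAPI api;
  if (api.Init(NULL, "eng") != 0) return 1;
  Pix* pix = pixRead("page.png");  // hypothetical input image
  if (pix == NULL) return 1;
  api.SetImage(pix);
  pixDestroy(&pix);  // Safe now: SetImage keeps its own copy of the image.
  char* text = api.GetUTF8Text();
  printf("%s", text);
  delete[] text;
  api.End();
  return 0;
}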
/**
@ -693,8 +691,8 @@ Boxa* TessBaseAPI::GetComponentImages(PageIteratorLevel level,
if (pixa != NULL) {
Pix* pix = NULL;
if (raw_image) {
pix = page_it->GetImage(level, raw_padding, input_image_,
&left, &top);
pix = page_it->GetImage(level, raw_padding, GetInputImage(), &left,
&top);
} else {
pix = page_it->GetBinaryImage(level);
}
@ -849,13 +847,17 @@ int TessBaseAPI::Recognize(ETEXT_DESC* monitor) {
} else if (tesseract_->tessedit_resegment_from_boxes) {
page_res_ = tesseract_->ApplyBoxes(*input_file_, false, block_list_);
} else {
// TODO(rays) LSTM here.
page_res_ = new PAGE_RES(false,
page_res_ = new PAGE_RES(tesseract_->AnyLSTMLang(),
block_list_, &tesseract_->prev_word_best_choice_);
}
if (page_res_ == NULL) {
return -1;
}
if (tesseract_->tessedit_train_line_recognizer) {
tesseract_->TrainLineRecognizer(*input_file_, *output_file_, block_list_);
tesseract_->CorrectClassifyWords(page_res_);
return 0;
}
if (tesseract_->tessedit_make_boxes_from_boxes) {
tesseract_->CorrectClassifyWords(page_res_);
return 0;
@ -938,17 +940,10 @@ int TessBaseAPI::RecognizeForChopTest(ETEXT_DESC* monitor) {
return 0;
}
void TessBaseAPI::SetInputImage(Pix *pix) {
if (input_image_)
pixDestroy(&input_image_);
input_image_ = NULL;
if (pix)
input_image_ = pixCopy(NULL, pix);
}
// Takes ownership of the input pix.
void TessBaseAPI::SetInputImage(Pix* pix) { tesseract_->set_pix_original(pix); }
Pix* TessBaseAPI::GetInputImage() {
return input_image_;
}
Pix* TessBaseAPI::GetInputImage() { return tesseract_->pix_original(); }
const char * TessBaseAPI::GetInputName() {
if (input_file_)
@ -992,8 +987,7 @@ bool TessBaseAPI::ProcessPagesFileList(FILE *flist,
}
// Begin producing output
const char* kUnknownTitle = "";
if (renderer && !renderer->BeginDocument(kUnknownTitle)) {
if (renderer && !renderer->BeginDocument(unknown_title_)) {
return false;
}
@ -1105,7 +1099,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
const char* retry_config,
int timeout_millisec,
TessResultRenderer* renderer) {
#ifndef ANDROID_BUILD
PERF_COUNT_START("ProcessPages")
bool stdInput = !strcmp(filename, "stdin") || !strcmp(filename, "-");
if (stdInput) {
@ -1162,8 +1155,7 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
}
// Begin the output
const char* kUnknownTitle = "";
if (renderer && !renderer->BeginDocument(kUnknownTitle)) {
if (renderer && !renderer->BeginDocument(unknown_title_)) {
pixDestroy(&pix);
return false;
}
@ -1185,9 +1177,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
}
PERF_COUNT_END
return true;
#else
return false;
#endif
}
bool TessBaseAPI::ProcessPage(Pix* pix, int page_index, const char* filename,
@ -2107,10 +2096,6 @@ void TessBaseAPI::End() {
delete input_file_;
input_file_ = NULL;
}
if (input_image_ != NULL) {
pixDestroy(&input_image_);
input_image_ = NULL;
}
if (output_file_ != NULL) {
delete output_file_;
output_file_ = NULL;

View File

@ -20,8 +20,8 @@
#ifndef TESSERACT_API_BASEAPI_H__
#define TESSERACT_API_BASEAPI_H__
#define TESSERACT_VERSION_STR "3.05.00dev"
#define TESSERACT_VERSION 0x030500
#define TESSERACT_VERSION_STR "4.00.00alpha"
#define TESSERACT_VERSION 0x040000
#define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \
(patch))
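For reference, the packing arithmetic behind these constants; a standalone sketch, not part of the commit:

#include <cstdio>

#define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \
                                           (patch))

int main() {
  // 4 << 16 == 0x040000 and minor/patch are zero, so this prints 0x040000,
  // matching TESSERACT_VERSION for "4.00.00alpha".
  printf("0x%06x\n", MAKE_VERSION(4, 0, 0));
  return 0;
}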
@ -142,6 +142,7 @@ class TESS_API TessBaseAPI {
* is stored in the PDF so we need that as well.
*/
const char* GetInputName();
// Takes ownership of the input pix.
void SetInputImage(Pix *pix);
Pix* GetInputImage();
int GetSourceYResolution();
@ -333,9 +334,7 @@ class TESS_API TessBaseAPI {
/**
* Provide an image for Tesseract to recognize. Format is as
* TesseractRect above. Does not copy the image buffer, or take
* ownership. The source image may be destroyed after Recognize is called,
* either explicitly or implicitly via one of the Get*Text functions.
* TesseractRect above. Copies the image buffer and converts to Pix.
* SetImage clears all recognition results, and sets the rectangle to the
* full image, so it may be followed immediately by a GetUTF8Text, and it
* will automatically perform recognition.
@ -345,13 +344,11 @@ class TESS_API TessBaseAPI {
/**
* Provide an image for Tesseract to recognize. As with SetImage above,
* Tesseract doesn't take a copy or ownership or pixDestroy the image, so
* it must persist until after Recognize.
* Tesseract takes its own copy of the image, so it need not persist until
* after Recognize.
* Pix vs raw, which to use?
* Use Pix where possible. A future version of Tesseract may choose to use Pix
* as its internal representation and discard IMAGE altogether.
* Because of that, an implementation that sources and targets Pix may end up
* with less copies than an implementation that does not.
* Use Pix where possible. Tesseract uses Pix as its internal representation
* and it is therefore more efficient to provide a Pix directly.
*/
void SetImage(Pix* pix);
@ -866,7 +863,6 @@ class TESS_API TessBaseAPI {
BLOCK_LIST* block_list_; ///< The page layout.
PAGE_RES* page_res_; ///< The page-level data.
STRING* input_file_; ///< Name used by training code.
Pix* input_image_; ///< Image used for searchable PDF
STRING* output_file_; ///< Name used by debug code.
STRING* datapath_; ///< Current location of tessdata.
STRING* language_; ///< Last initialized language.
@ -902,6 +898,12 @@ class TESS_API TessBaseAPI {
int timeout_millisec,
TessResultRenderer* renderer,
int tessedit_page_number);
// There's currently no way to pass a document title from the
// Tesseract command line, and we have multiple places that choose
// to set the title to an empty string. Using a single named
// variable will hopefully reduce confusion if the situation changes
// in the future.
const char *unknown_title_ = "";
}; // class TessBaseAPI.
/** Escape a char string - remove &<>"' with HTML codes. */

View File

@ -620,7 +620,6 @@ bool TessPDFRenderer::BeginDocumentHandler() {
AppendPDFObject(buf);
// FONT DESCRIPTOR
const int kCharHeight = 2; // Effect: highlights are half height
n = snprintf(buf, sizeof(buf),
"7 0 obj\n"
"<<\n"
@ -636,10 +635,10 @@ bool TessPDFRenderer::BeginDocumentHandler() {
" /Type /FontDescriptor\n"
">>\n"
"endobj\n",
1000 / kCharHeight,
1000 / kCharHeight,
1000,
1000,
1000 / kCharWidth,
1000 / kCharHeight,
1000,
8L // Font data
);
if (n >= sizeof(buf)) return false;

View File

@ -77,7 +77,7 @@ class TESS_API TessResultRenderer {
bool EndDocument();
const char* file_extension() const { return file_extension_; }
const char* title() const { return title_; }
const char* title() const { return title_.c_str(); }
/**
* Returns the index of the last image given to AddImage
@ -126,7 +126,7 @@ class TESS_API TessResultRenderer {
private:
const char* file_extension_; // standard extension for generated output
const char* title_; // title of document being rendered
STRING title_; // title of document being rendered
int imagenum_; // index of last image added
FILE* fout_; // output file pointer

arch/Makefile.am Normal file
View File

@ -0,0 +1,29 @@
AM_CPPFLAGS += -I$(top_srcdir)/ccutil
AUTOMAKE_OPTIONS = subdir-objects
SUBDIRS =
AM_CXXFLAGS =
if VISIBILITY
AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
AM_CPPFLAGS += -DTESS_EXPORTS
endif
include_HEADERS = \
dotproductavx.h dotproductsse.h
noinst_HEADERS =
if !USING_MULTIPLELIBS
noinst_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la
else
lib_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la
libtesseract_avx_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
libtesseract_sse_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
endif
libtesseract_avx_la_CXXFLAGS = -mavx
libtesseract_sse_la_CXXFLAGS = -msse4.1
libtesseract_avx_la_SOURCES = dotproductavx.cpp
libtesseract_sse_la_SOURCES = dotproductsse.cpp
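Only these two translation units are compiled with -mavx and -msse4.1, so the rest of the code base stays safe to run on older CPUs and a runtime check picks the kernel. The dispatch itself is not shown in this excerpt; the sketch below is only an illustration of the pattern, using a GCC/Clang builtin:

#include "dotproductavx.h"
#include "dotproductsse.h"

// Hypothetical dispatcher: compiled WITHOUT -mavx/-msse4.1, so it can run on
// any x86 CPU and only enters the specialised units when the CPU allows it.
static double DotProduct(const double* u, const double* v, int n) {
  if (__builtin_cpu_supports("avx")) return tesseract::DotProductAVX(u, v, n);
  if (__builtin_cpu_supports("sse4.1"))
    return tesseract::DotProductSSE(u, v, n);
  double total = 0.0;  // Scalar fallback.
  for (int i = 0; i < n; ++i) total += u[i] * v[i];
  return total;
}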

arch/dotproductavx.cpp Normal file
View File

@ -0,0 +1,103 @@
///////////////////////////////////////////////////////////////////////
// File: dotproductavx.cpp
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:48:05 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#if !defined(__AVX__)
// Implementation for non-avx archs.
#include "dotproductavx.h"
#include <stdio.h>
#include <stdlib.h>
namespace tesseract {
double DotProductAVX(const double* u, const double* v, int n) {
fprintf(stderr, "DotProductAVX can't be used on Android\n");
abort();
}
} // namespace tesseract
#else // !defined(__AVX__)
// Implementation for avx capable archs.
#include <immintrin.h>
#include <stdint.h>
#include "dotproductavx.h"
#include "host.h"
namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel AVX intrinsics to access the SIMD instruction set.
double DotProductAVX(const double* u, const double* v, int n) {
int max_offset = n - 4;
int offset = 0;
// Accumulate a set of 4 sums in sum, by loading pairs of 4 values from u and
// v, and multiplying them together in parallel.
__m256d sum = _mm256_setzero_pd();
if (offset <= max_offset) {
offset = 4;
// Aligned load is reputedly faster but requires 32 byte aligned input.
if ((reinterpret_cast<const uintptr_t>(u) & 31) == 0 &&
(reinterpret_cast<const uintptr_t>(v) & 31) == 0) {
// Use aligned load.
__m256d floats1 = _mm256_load_pd(u);
__m256d floats2 = _mm256_load_pd(v);
// Multiply.
sum = _mm256_mul_pd(floats1, floats2);
while (offset <= max_offset) {
floats1 = _mm256_load_pd(u + offset);
floats2 = _mm256_load_pd(v + offset);
offset += 4;
__m256d product = _mm256_mul_pd(floats1, floats2);
sum = _mm256_add_pd(sum, product);
}
} else {
// Use unaligned load.
__m256d floats1 = _mm256_loadu_pd(u);
__m256d floats2 = _mm256_loadu_pd(v);
// Multiply.
sum = _mm256_mul_pd(floats1, floats2);
while (offset <= max_offset) {
floats1 = _mm256_loadu_pd(u + offset);
floats2 = _mm256_loadu_pd(v + offset);
offset += 4;
__m256d product = _mm256_mul_pd(floats1, floats2);
sum = _mm256_add_pd(sum, product);
}
}
}
// Add the 4 product sums together horizontally. Not so easy as with sse, as
// there is no add across the upper/lower 128 bit boundary, so permute to
// move the upper 128 bits to lower in another register.
__m256d sum2 = _mm256_permute2f128_pd(sum, sum, 1);
sum = _mm256_hadd_pd(sum, sum2);
sum = _mm256_hadd_pd(sum, sum);
double result;
// _mm256_extract_f64 doesn't exist, but resist the temptation to use an sse
// instruction, as that introduces a 70 cycle delay. All this casting is to
// fool the intrinsics into thinking we are extracting the bottom int64.
*(reinterpret_cast<inT64*>(&result)) =
_mm256_extract_epi64(_mm256_castpd_si256(sum), 0);
while (offset < n) {
result += u[offset] * v[offset];
++offset;
}
return result;
}
} // namespace tesseract.
#endif  // !defined(__AVX__)

arch/dotproductavx.h Normal file
View File

@ -0,0 +1,30 @@
///////////////////////////////////////////////////////////////////////
// File: dotproductavx.h
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:51:05 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_ARCH_DOTPRODUCTAVX_H_
#define TESSERACT_ARCH_DOTPRODUCTAVX_H_
namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel AVX intrinsics to access the SIMD instruction set.
double DotProductAVX(const double* u, const double* v, int n);
} // namespace tesseract.
#endif // TESSERACT_ARCH_DOTPRODUCTAVX_H_

arch/dotproductsse.cpp Normal file
View File

@ -0,0 +1,141 @@
///////////////////////////////////////////////////////////////////////
// File: dotproductsse.cpp
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:57:45 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#if !defined(__SSE4_1__)
// Built without "-msse4.1", so provide dummy stubs.
#include "dotproductsse.h"
#include <stdio.h>
#include <stdlib.h>
namespace tesseract {
double DotProductSSE(const double* u, const double* v, int n) {
fprintf(stderr, "DotProductSSE can't be used on Android\n");
abort();
}
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) {
fprintf(stderr, "IntDotProductSSE can't be used on Android\n");
abort();
}
} // namespace tesseract
#else // !defined(__SSE4_1__)
// SSE4.1-capable code here
#include <emmintrin.h>
#include <smmintrin.h>
#include <stdint.h>
#include "dotproductsse.h"
#include "host.h"
namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
double DotProductSSE(const double* u, const double* v, int n) {
int max_offset = n - 2;
int offset = 0;
// Accumulate a set of 2 sums in sum, by loading pairs of 2 values from u and
// v, and multiplying them together in parallel.
__m128d sum = _mm_setzero_pd();
if (offset <= max_offset) {
offset = 2;
// Aligned load is reputedly faster but requires 16 byte aligned input.
if ((reinterpret_cast<const uintptr_t>(u) & 15) == 0 &&
(reinterpret_cast<const uintptr_t>(v) & 15) == 0) {
// Use aligned load.
sum = _mm_load_pd(u);
__m128d floats2 = _mm_load_pd(v);
// Multiply.
sum = _mm_mul_pd(sum, floats2);
while (offset <= max_offset) {
__m128d floats1 = _mm_load_pd(u + offset);
floats2 = _mm_load_pd(v + offset);
offset += 2;
floats1 = _mm_mul_pd(floats1, floats2);
sum = _mm_add_pd(sum, floats1);
}
} else {
// Use unaligned load.
sum = _mm_loadu_pd(u);
__m128d floats2 = _mm_loadu_pd(v);
// Multiply.
sum = _mm_mul_pd(sum, floats2);
while (offset <= max_offset) {
__m128d floats1 = _mm_loadu_pd(u + offset);
floats2 = _mm_loadu_pd(v + offset);
offset += 2;
floats1 = _mm_mul_pd(floats1, floats2);
sum = _mm_add_pd(sum, floats1);
}
}
}
// Add the 2 sums in sum horizontally.
sum = _mm_hadd_pd(sum, sum);
// Extract the low result.
double result = _mm_cvtsd_f64(sum);
// Add on any left-over products.
while (offset < n) {
result += u[offset] * v[offset];
++offset;
}
return result;
}
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) {
int max_offset = n - 8;
int offset = 0;
// Accumulate a set of 4 32-bit sums in sum, by loading 8 pairs of 8-bit
// values, extending to 16 bit, multiplying to make 32 bit results.
__m128i sum = _mm_setzero_si128();
if (offset <= max_offset) {
offset = 8;
__m128i packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u));
__m128i packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v));
sum = _mm_cvtepi8_epi16(packed1);
packed2 = _mm_cvtepi8_epi16(packed2);
// The magic _mm_madd_epi16 is perfect here. It multiplies 8 pairs of 16 bit
// ints to make 32 bit results, which are then horizontally added in pairs
// to make 4 32 bit results that still fit in a 128 bit register.
sum = _mm_madd_epi16(sum, packed2);
while (offset <= max_offset) {
packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u + offset));
packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v + offset));
offset += 8;
packed1 = _mm_cvtepi8_epi16(packed1);
packed2 = _mm_cvtepi8_epi16(packed2);
packed1 = _mm_madd_epi16(packed1, packed2);
sum = _mm_add_epi32(sum, packed1);
}
}
// Sum the 4 packed 32 bit sums and extract the low result.
sum = _mm_hadd_epi32(sum, sum);
sum = _mm_hadd_epi32(sum, sum);
inT32 result = _mm_cvtsi128_si32(sum);
while (offset < n) {
result += u[offset] * v[offset];
++offset;
}
return result;
}
} // namespace tesseract.
#endif  // !defined(__SSE4_1__)

arch/dotproductsse.h Normal file
View File

@ -0,0 +1,35 @@
///////////////////////////////////////////////////////////////////////
// File: dotproductsse.h
// Description: Architecture-specific dot-product function.
// Author: Ray Smith
// Created: Wed Jul 22 10:57:05 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_ARCH_DOTPRODUCTSSE_H_
#define TESSERACT_ARCH_DOTPRODUCTSSE_H_
#include "host.h"
namespace tesseract {
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
double DotProductSSE(const double* u, const double* v, int n);
// Computes and returns the dot product of the n-vectors u and v.
// Uses Intel SSE intrinsics to access the SIMD instruction set.
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n);
} // namespace tesseract.
#endif // TESSERACT_ARCH_DOTPRODUCTSSE_H_
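Both kernels must agree with plain summation, which makes a scalar reference a convenient sanity check; an illustrative test harness, not part of the commit, assuming an SSE4.1-capable build:

#include <cstdio>
#include "dotproductsse.h"

// Scalar reference implementation for comparison.
static double DotProductNaive(const double* u, const double* v, int n) {
  double total = 0.0;
  for (int i = 0; i < n; ++i) total += u[i] * v[i];
  return total;
}

int main() {
  // Five elements exercise both the 2-wide SIMD loop and the left-over tail.
  double u[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
  double v[5] = {0.5, 0.25, 2.0, 1.0, -1.0};
  printf("sse=%g naive=%g\n", tesseract::DotProductSSE(u, v, 5),
         DotProductNaive(u, v, 5));  // Both should print 6.
  return 0;
}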

View File

@ -1,6 +1,7 @@
AM_CPPFLAGS += \
-DUSE_STD_NAMESPACE \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \
-I$(top_srcdir)/arch -I$(top_srcdir)/lstm \
-I$(top_srcdir)/viewer \
-I$(top_srcdir)/classify -I$(top_srcdir)/dict \
-I$(top_srcdir)/wordrec -I$(top_srcdir)/cutil \
@ -33,6 +34,9 @@ libtesseract_main_la_LIBADD = \
../ccstruct/libtesseract_ccstruct.la \
../viewer/libtesseract_viewer.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../classify/libtesseract_classify.la \
../cutil/libtesseract_cutil.la \
../opencl/libtesseract_opencl.la
@ -44,7 +48,7 @@ endif
libtesseract_main_la_SOURCES = \
adaptions.cpp applybox.cpp control.cpp \
docqual.cpp equationdetect.cpp fixspace.cpp fixxht.cpp \
ltrresultiterator.cpp \
linerec.cpp ltrresultiterator.cpp \
osdetect.cpp output.cpp pageiterator.cpp pagesegmain.cpp \
pagewalk.cpp par_control.cpp paragraphs.cpp paramsd.cpp pgedit.cpp recogtraining.cpp \
reject.cpp resultiterator.cpp superscript.cpp \

View File

@ -84,7 +84,12 @@ BOOL8 Tesseract::recog_interactive(PAGE_RES_IT* pr_it) {
WordData word_data(*pr_it);
SetupWordPassN(2, &word_data);
// LSTM doesn't run on pass2, but we want to run pass2 for tesseract.
if (lstm_recognizer_ == NULL) {
classify_word_and_language(2, pr_it, &word_data);
} else {
classify_word_and_language(1, pr_it, &word_data);
}
if (tessedit_debug_quality_metrics) {
WERD_RES* word_res = pr_it->word();
word_char_quality(word_res, pr_it->row()->row, &char_qual, &good_char_qual);
@ -219,15 +224,13 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor,
monitor->progress = 70 * w / words->size();
if (monitor->progress_callback != NULL) {
TBOX box = pr_it->word()->word->bounding_box();
(*monitor->progress_callback)(monitor->progress,
box.left(), box.right(),
box.top(), box.bottom());
(*monitor->progress_callback)(monitor->progress, box.left(),
box.right(), box.top(), box.bottom());
}
} else {
monitor->progress = 70 + 30 * w / words->size();
if (monitor->progress_callback!=NULL) {
(*monitor->progress_callback)(monitor->progress,
0, 0, 0, 0);
if (monitor->progress_callback != NULL) {
(*monitor->progress_callback)(monitor->progress, 0, 0, 0, 0);
}
}
if (monitor->deadline_exceeded() ||
@ -252,7 +255,8 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor,
pr_it->forward();
ASSERT_HOST(pr_it->word() != NULL);
bool make_next_word_fuzzy = false;
if (ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
if (!AnyLSTMLang() &&
ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
// Needs to be setup again to see the new outlines in the chopped_word.
SetupWordPassN(pass_n, word);
}
@ -297,6 +301,16 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res,
const TBOX* target_word_box,
const char* word_config,
int dopasses) {
// PSM_RAW_LINE is a special-case mode in which the layout analysis is
// completely ignored and LSTM is run on the raw image. There is no hope
// of running normal tesseract in this situation or of integrating output.
#ifndef ANDROID_BUILD
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY &&
tessedit_pageseg_mode == PSM_RAW_LINE) {
RecogRawLine(page_res);
return true;
}
#endif
PAGE_RES_IT page_res_it(page_res);
if (tessedit_minimal_rej_pass1) {
@ -385,7 +399,7 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res,
// The next passes can only be run if tesseract has been used, as cube
// doesn't set all the necessary outputs in WERD_RES.
if (AnyTessLang()) {
if (AnyTessLang() && !AnyLSTMLang()) {
// ****************** Pass 3 *******************
// Fix fuzzy spaces.
set_global_loc_code(LOC_FUZZY_SPACE);
@ -1362,6 +1376,19 @@ void Tesseract::classify_word_pass1(const WordData& word_data,
cube_word_pass1(block, row, *in_word);
return;
}
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
if (!(*in_word)->odd_size) {
LSTMRecognizeWord(*block, row, *in_word, out_words);
if (!out_words->empty())
return; // Successful lstm recognition.
}
// Fall back to tesseract for failed words or odd words.
(*in_word)->SetupForRecognition(unicharset, this, BestPix(),
OEM_TESSERACT_ONLY, NULL,
classify_bln_numeric_mode,
textord_use_cjk_fp_model,
poly_allow_detailed_fx, row, block);
}
#endif
WERD_RES* word = *in_word;
match_word_pass_n(1, word, row, block);
@ -1496,10 +1523,6 @@ void Tesseract::classify_word_pass2(const WordData& word_data,
WERD_RES** in_word,
PointerVector<WERD_RES>* out_words) {
// Return if we do not want to run Tesseract.
if (tessedit_ocr_engine_mode != OEM_TESSERACT_ONLY &&
tessedit_ocr_engine_mode != OEM_TESSERACT_CUBE_COMBINED &&
word_data.word->best_choice != NULL)
return;
if (tessedit_ocr_engine_mode == OEM_CUBE_ONLY) {
return;
}

ccmain/linerec.cpp Normal file
View File

@ -0,0 +1,332 @@
///////////////////////////////////////////////////////////////////////
// File: linerec.cpp
// Description: Top-level line-based recognition module for Tesseract.
// Author: Ray Smith
// Created: Thu May 02 09:47:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "tesseractclass.h"
#include "allheaders.h"
#include "boxread.h"
#include "imagedata.h"
#ifndef ANDROID_BUILD
#include "lstmrecognizer.h"
#include "recodebeam.h"
#endif
#include "ndminx.h"
#include "pageres.h"
#include "tprintf.h"
namespace tesseract {
// Arbitrary penalty for non-dictionary words.
// TODO(rays) How to learn this?
const float kNonDictionaryPenalty = 5.0f;
// Scale factor to make certainty more comparable to Tesseract.
const float kCertaintyScale = 7.0f;
// Worst acceptable certainty for a dictionary word.
const float kWorstDictCertainty = -25.0f;
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the page into lines, according to the boxes, and writes them to a
// serialized DocumentData based on output_basename.
void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
const STRING& output_basename,
BLOCK_LIST *block_list) {
STRING lstmf_name = output_basename + ".lstmf";
DocumentData images(lstmf_name);
if (applybox_page > 0) {
// Load existing document for the previous pages.
if (!images.LoadDocument(lstmf_name.string(), "eng", 0, 0, NULL)) {
tprintf("Failed to read training data from %s!\n", lstmf_name.string());
return;
}
}
GenericVector<TBOX> boxes;
GenericVector<STRING> texts;
// Get the boxes for this page, if there are any.
if (!ReadAllBoxes(applybox_page, false, input_imagename, &boxes, &texts, NULL,
NULL) ||
boxes.empty()) {
tprintf("Failed to read boxes from %s\n", input_imagename.string());
return;
}
TrainFromBoxes(boxes, texts, block_list, &images);
if (!images.SaveDocument(lstmf_name.string(), NULL)) {
tprintf("Failed to write training data to %s!\n", lstmf_name.string());
}
}
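This entry point is reached from TessBaseAPI::Recognize when tessedit_train_line_recognizer is set (see the baseapi.cpp hunk above). A hedged sketch of driving it through the public API; the file names are hypothetical and the matching .box file is assumed to exist:

// Illustrative: generating page.lstmf training data via the public API.
tesseract::TessBaseAPI api;
api.Init(NULL, "eng");
api.SetVariable("tessedit_train_line_recognizer", "1");
api.SetInputName("page.tif");  // The matching .box file supplies line boxes.
api.SetOutputName("page");     // TrainLineRecognizer writes page.lstmf.
Pix* pix = pixRead("page.tif");
api.SetImage(pix);
api.Recognize(NULL);           // Takes the training branch instead of OCR.
pixDestroy(&pix);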
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the boxes into lines, normalizes them, converts to ImageData and
// appends them to the given training_data.
void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
const GenericVector<STRING>& texts,
BLOCK_LIST *block_list,
DocumentData* training_data) {
int box_count = boxes.size();
// Process all the text lines in this page, as defined by the boxes.
int end_box = 0;
for (int start_box = 0; start_box < box_count; start_box = end_box) {
// Find the textline of boxes starting at start and their bounding box.
TBOX line_box = boxes[start_box];
STRING line_str = texts[start_box];
for (end_box = start_box + 1; end_box < box_count && texts[end_box] != "\t";
++end_box) {
line_box += boxes[end_box];
line_str += texts[end_box];
}
// Find the most overlapping block.
BLOCK* best_block = NULL;
int best_overlap = 0;
BLOCK_IT b_it(block_list);
for (b_it.mark_cycle_pt(); !b_it.cycled_list(); b_it.forward()) {
BLOCK* block = b_it.data();
if (block->poly_block() != NULL && !block->poly_block()->IsText())
continue; // Not a text block.
TBOX block_box = block->bounding_box();
block_box.rotate(block->re_rotation());
if (block_box.major_overlap(line_box)) {
TBOX overlap_box = line_box.intersection(block_box);
if (overlap_box.area() > best_overlap) {
best_overlap = overlap_box.area();
best_block = block;
}
}
}
ImageData* imagedata = NULL;
if (best_block == NULL) {
tprintf("No block overlapping textline: %s\n", line_str.string());
} else {
imagedata = GetLineData(line_box, boxes, texts, start_box, end_box,
*best_block);
}
if (imagedata != NULL)
training_data->AddPageToDocument(imagedata);
if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
}
}
// Returns an ImageData containing the image of the given box,
// and ground truth boxes/truth text if available in the input.
// The image is not normalized in any way.
ImageData* Tesseract::GetLineData(const TBOX& line_box,
const GenericVector<TBOX>& boxes,
const GenericVector<STRING>& texts,
int start_box, int end_box,
const BLOCK& block) {
TBOX revised_box;
ImageData* image_data = GetRectImage(line_box, block, kImagePadding,
&revised_box);
if (image_data == NULL) return NULL;
image_data->set_page_number(applybox_page);
// Copy the boxes and shift them so they are relative to the image.
FCOORD block_rotation(block.re_rotation().x(), -block.re_rotation().y());
ICOORD shift = -revised_box.botleft();
GenericVector<TBOX> line_boxes;
GenericVector<STRING> line_texts;
for (int b = start_box; b < end_box; ++b) {
TBOX box = boxes[b];
box.rotate(block_rotation);
box.move(shift);
line_boxes.push_back(box);
line_texts.push_back(texts[b]);
}
GenericVector<int> page_numbers;
page_numbers.init_to_size(line_boxes.size(), applybox_page);
image_data->AddBoxes(line_boxes, line_texts, page_numbers);
return image_data;
}
// Helper gets the image of a rectangle, using the block.re_rotation() if
// needed to get to the image, and rotating the result back to horizontal
// layout. (CJK characters will be on their left sides.) The vertical text flag
// is set in the returned ImageData if the text was originally vertical, which
// can be used to invoke a different CJK recognition engine. The revised_box
// is also returned to enable calculation of output bounding boxes.
ImageData* Tesseract::GetRectImage(const TBOX& box, const BLOCK& block,
int padding, TBOX* revised_box) const {
TBOX wbox = box;
wbox.pad(padding, padding);
*revised_box = wbox;
// Number of clockwise 90 degree rotations needed to get back to tesseract
// coords from the clipped image.
int num_rotations = 0;
if (block.re_rotation().y() > 0.0f)
num_rotations = 1;
else if (block.re_rotation().x() < 0.0f)
num_rotations = 2;
else if (block.re_rotation().y() < 0.0f)
num_rotations = 3;
// Handle two cases automatically: (1) the box came from the block, (2) the box
// came from a box file, and refers to the image, which the block may not.
if (block.bounding_box().major_overlap(*revised_box))
revised_box->rotate(block.re_rotation());
// Now revised_box always refers to the image.
// BestPix is never colormapped, but may be of any depth.
Pix* pix = BestPix();
int width = pixGetWidth(pix);
int height = pixGetHeight(pix);
TBOX image_box(0, 0, width, height);
// Clip to image bounds.
*revised_box &= image_box;
if (revised_box->null_box()) return NULL;
Box* clip_box = boxCreate(revised_box->left(), height - revised_box->top(),
revised_box->width(), revised_box->height());
Pix* box_pix = pixClipRectangle(pix, clip_box, NULL);
if (box_pix == NULL) return NULL;
boxDestroy(&clip_box);
if (num_rotations > 0) {
Pix* rot_pix = pixRotateOrth(box_pix, num_rotations);
pixDestroy(&box_pix);
box_pix = rot_pix;
}
// Convert sub-8-bit images to 8 bit.
int depth = pixGetDepth(box_pix);
if (depth < 8) {
Pix* grey;
grey = pixConvertTo8(box_pix, false);
pixDestroy(&box_pix);
box_pix = grey;
}
bool vertical_text = false;
if (num_rotations > 0) {
// Rotate the clipped revised box back to internal coordinates.
FCOORD rotation(block.re_rotation().x(), -block.re_rotation().y());
revised_box->rotate(rotation);
if (num_rotations != 2)
vertical_text = true;
}
return new ImageData(vertical_text, box_pix);
}
#ifndef ANDROID_BUILD
// Top-level function recognizes a single raw line.
void Tesseract::RecogRawLine(PAGE_RES* page_res) {
PAGE_RES_IT it(page_res);
PointerVector<WERD_RES> words;
LSTMRecognizeWord(*it.block()->block, it.row()->row, it.word(), &words);
if (getDict().stopper_debug_level >= 1) {
for (int w = 0; w < words.size(); ++w) {
words[w]->DebugWordChoices(true, NULL);
}
}
it.ReplaceCurrentWord(&words);
}
// Recognizes a word or group of words, converting to WERD_RES in *words.
// Analogous to classify_word_pass1, but can handle a group of words as well.
void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
PointerVector<WERD_RES>* words) {
TBOX word_box = word->word->bounding_box();
// Get the word image - no frills.
if (tessedit_pageseg_mode == PSM_SINGLE_WORD ||
tessedit_pageseg_mode == PSM_RAW_LINE) {
// In single word mode, use the whole image without any other row/word
// interpretation.
word_box = TBOX(0, 0, ImageWidth(), ImageHeight());
} else {
float baseline = row->base_line((word_box.left() + word_box.right()) / 2);
if (baseline + row->descenders() < word_box.bottom())
word_box.set_bottom(baseline + row->descenders());
if (baseline + row->x_height() + row->ascenders() > word_box.top())
word_box.set_top(baseline + row->x_height() + row->ascenders());
}
ImageData* im_data = GetRectImage(word_box, block, kImagePadding, &word_box);
if (im_data == NULL) return;
lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0,
kWorstDictCertainty / kCertaintyScale,
lstm_use_matrix, &unicharset, word_box, 2.0,
false, words);
delete im_data;
SearchWords(words);
}
// Apply segmentation search to the given set of words, within the constraints
// of the existing ratings matrix. If there is already a best_choice on a word,
// leaves it untouched and just sets the done/accepted etc. flags.
void Tesseract::SearchWords(PointerVector<WERD_RES>* words) {
// Run the segmentation search on the network outputs and make a BoxWord
// for each of the output words.
// If we drop a word as junk, then there is always a space in front of the
// next.
bool deleted_prev = false;
for (int w = 0; w < words->size(); ++w) {
WERD_RES* word = (*words)[w];
if (word->best_choice == NULL) {
// If we are using the beam search, the unicharset had better match!
word->SetupWordScript(unicharset);
WordSearch(word);
} else if (word->best_choice->unicharset() == &unicharset &&
!lstm_recognizer_->IsRecoding()) {
// We set up the word without using the dictionary, so set the permuter
// now, but we can only do it because the unicharsets match.
word->best_choice->set_permuter(
getDict().valid_word(*word->best_choice, true));
}
if (word->best_choice == NULL) {
// It is a dud.
words->remove(w);
--w;
deleted_prev = true;
} else {
// Set the best state.
for (int i = 0; i < word->best_choice->length(); ++i) {
int length = word->best_choice->state(i);
word->best_state.push_back(length);
}
word->tess_failed = false;
word->tess_accepted = true;
word->tess_would_adapt = false;
word->done = true;
word->tesseract = this;
float word_certainty = MIN(word->space_certainty,
word->best_choice->certainty());
word_certainty *= kCertaintyScale;
// Arbitrary ding factor for non-dictionary words.
if (!lstm_recognizer_->IsRecoding() &&
!Dict::valid_word_permuter(word->best_choice->permuter(), true))
word_certainty -= kNonDictionaryPenalty;
if (getDict().stopper_debug_level >= 1) {
tprintf("Best choice certainty=%g, space=%g, scaled=%g, final=%g\n",
word->best_choice->certainty(), word->space_certainty,
MIN(word->space_certainty, word->best_choice->certainty()) *
kCertaintyScale,
word_certainty);
word->best_choice->print();
}
// Discard words that are impossibly bad, but allow a bit more for
// dictionary words.
if (word_certainty >= RecodeBeamSearch::kMinCertainty ||
(word_certainty >= kWorstDictCertainty &&
Dict::valid_word_permuter(word->best_choice->permuter(), true))) {
word->best_choice->set_certainty(word_certainty);
if (deleted_prev) word->word->set_blanks(1);
} else {
if (getDict().stopper_debug_level >= 1) {
tprintf("Deleting word with certainty %g\n", word_certainty);
word->best_choice->print();
}
// It is a dud.
words->remove(w);
--w;
deleted_prev = true;
}
}
}
}
#endif // ANDROID_BUILD
} // namespace tesseract.

View File

@ -220,6 +220,12 @@ bool LTRResultIterator::WordIsFromDictionary() const {
permuter == USER_DAWG_PERM;
}
// Returns the number of blanks before the current word.
int LTRResultIterator::BlanksBeforeWord() const {
if (it_->word() == NULL) return 1;
return it_->word()->word->space();
}
// Returns true if the current word is numeric.
bool LTRResultIterator::WordIsNumeric() const {
if (it_->word() == NULL) return false; // Already at the end!

View File

@ -124,6 +124,9 @@ class TESS_API LTRResultIterator : public PageIterator {
// Returns true if the current word was found in a dictionary.
bool WordIsFromDictionary() const;
// Returns the number of blanks before the current word.
int BlanksBeforeWord() const;
// Returns true if the current word is numeric.
bool WordIsNumeric() const;

View File

@ -40,6 +40,9 @@
#include "efio.h"
#include "danerror.h"
#include "globals.h"
#ifndef ANDROID_BUILD
#include "lstmrecognizer.h"
#endif
#include "tesseractclass.h"
#include "params.h"
@ -214,6 +217,18 @@ bool Tesseract::init_tesseract_lang_data(
ASSERT_HOST(init_cube_objects(true, &tessdata_manager));
if (tessdata_manager_debug_level)
tprintf("Loaded Cube with combiner\n");
} else if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
if (tessdata_manager.SeekToStart(TESSDATA_LSTM)) {
lstm_recognizer_ = new LSTMRecognizer;
TFile fp;
fp.Open(tessdata_manager.GetDataFilePtr(), -1);
ASSERT_HOST(lstm_recognizer_->DeSerialize(tessdata_manager.swap(), &fp));
if (lstm_use_matrix)
lstm_recognizer_->LoadDictionary(tessdata_path.string(), language);
} else {
tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n");
tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY);
}
}
#endif
// Init ParamsModel.
@ -409,8 +424,7 @@ int Tesseract::init_tesseract_internal(
// If only Cube will be used, skip loading Tesseract classifier's
// pre-trained templates.
bool init_tesseract_classifier =
(tessedit_ocr_engine_mode == OEM_TESSERACT_ONLY ||
tessedit_ocr_engine_mode == OEM_TESSERACT_CUBE_COMBINED);
tessedit_ocr_engine_mode != OEM_CUBE_ONLY;
// If only Cube will be used and if it has its own Unicharset,
// skip initializing permuter and loading Tesseract Dawgs.
bool init_dict =
@ -468,7 +482,9 @@ int Tesseract::init_tesseract_lm(const char *arg0,
if (!init_tesseract_lang_data(arg0, textbase, language, OEM_TESSERACT_ONLY,
NULL, 0, NULL, NULL, false))
return -1;
getDict().Load(Dict::GlobalDawgCache());
getDict().SetupForLoad(Dict::GlobalDawgCache());
getDict().Load(tessdata_manager.GetDataFileName().string(), lang);
getDict().FinishLoad();
tessdata_manager.End();
return 0;
}

View File

@ -49,6 +49,7 @@
#include "equationdetect.h"
#include "globals.h"
#ifndef NO_CUBE_BUILD
#include "lstmrecognizer.h"
#include "tesseract_cube_combiner.h"
#endif
@ -65,6 +66,9 @@ Tesseract::Tesseract()
"Generate training data from boxed chars", this->params()),
BOOL_MEMBER(tessedit_make_boxes_from_boxes, false,
"Generate more boxes from boxed chars", this->params()),
BOOL_MEMBER(tessedit_train_line_recognizer, false,
"Break input into lines and remap boxes if present",
this->params()),
BOOL_MEMBER(tessedit_dump_pageseg_images, false,
"Dump intermediate images made during page segmentation",
this->params()),
@ -222,6 +226,8 @@ Tesseract::Tesseract()
"(more accurate)",
this->params()),
INT_MEMBER(cube_debug_level, 0, "Print cube debug info.", this->params()),
BOOL_MEMBER(lstm_use_matrix, 1,
"Use ratings matrix/beam search with lstm", this->params()),
STRING_MEMBER(outlines_odd, "%| ", "Non standard number of outlines",
this->params()),
STRING_MEMBER(outlines_2, "ij!?%\":;", "Non standard number of outlines",
@ -605,6 +611,7 @@ Tesseract::Tesseract()
pix_binary_(NULL),
cube_binary_(NULL),
pix_grey_(NULL),
pix_original_(NULL),
pix_thresholds_(NULL),
source_resolution_(0),
textord_(this),
@ -619,11 +626,16 @@ Tesseract::Tesseract()
cube_cntxt_(NULL),
tess_cube_combiner_(NULL),
#endif
equ_detect_(NULL) {
equ_detect_(NULL),
#ifndef ANDROID_BUILD
lstm_recognizer_(NULL),
#endif
train_line_page_num_(0) {
}
Tesseract::~Tesseract() {
Clear();
pixDestroy(&pix_original_);
end_tesseract();
sub_langs_.delete_data_pointers();
#ifndef NO_CUBE_BUILD
@ -636,6 +648,8 @@ Tesseract::~Tesseract() {
delete tess_cube_combiner_;
tess_cube_combiner_ = NULL;
}
delete lstm_recognizer_;
lstm_recognizer_ = NULL;
#endif
}

View File

@ -102,7 +102,10 @@ class CubeLineObject;
class CubeObject;
class CubeRecoContext;
#endif
class DocumentData;
class EquationDetect;
class ImageData;
class LSTMRecognizer;
class Tesseract;
#ifndef NO_CUBE_BUILD
class TesseractCubeCombiner;
@ -189,7 +192,7 @@ class Tesseract : public Wordrec {
}
// Destroy any existing pix and return a pointer to the pointer.
Pix** mutable_pix_binary() {
Clear();
pixDestroy(&pix_binary_);
return &pix_binary_;
}
Pix* pix_binary() const {
@ -202,16 +205,20 @@ class Tesseract : public Wordrec {
pixDestroy(&pix_grey_);
pix_grey_ = grey_pix;
}
// Returns a pointer to a Pix representing the best available image of the
// page. The image will be 8-bit grey if the input was grey or color. Note
// that in grey 0 is black and 255 is white. If the input was binary, then
// the returned Pix will be binary. Note that here black is 1 and white is 0.
// To tell the difference pixGetDepth() will return 8 or 1.
// In either case, the return value is a borrowed Pix, and should not be
// deleted or pixDestroyed.
Pix* BestPix() const {
return pix_grey_ != NULL ? pix_grey_ : pix_binary_;
Pix* pix_original() const { return pix_original_; }
// Takes ownership of the given original_pix.
void set_pix_original(Pix* original_pix) {
pixDestroy(&pix_original_);
pix_original_ = original_pix;
}
// Returns a pointer to a Pix representing the best available (original) image
// of the page. Can be of any bit depth, but never color-mapped, as that has
// always been dealt with. Note that in grey and color, 0 is black and 255 is
// white. If the input was binary, then black is 1 and white is 0.
// To tell the difference pixGetDepth() will return 32, 8 or 1.
// In any case, the return value is a borrowed Pix, and should not be
// deleted or pixDestroyed.
Pix* BestPix() const { return pix_original_; }
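A caller honouring the polarity rules above might branch like this (hedged sketch; tess stands for any Tesseract instance):

// Illustrative: interpreting BestPix() under the new contract.
Pix* best = tess->BestPix();  // Borrowed pointer: do not pixDestroy it.
if (pixGetDepth(best) == 1) {
  // Binary input: pixel value 1 is black (ink), 0 is white.
} else {
  // 8-bit grey or 32-bit colour: 0 is black, 255 is white per channel.
}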
void set_pix_thresholds(Pix* thresholds) {
pixDestroy(&pix_thresholds_);
pix_thresholds_ = thresholds;
@ -263,6 +270,15 @@ class Tesseract : public Wordrec {
}
return false;
}
// Returns true if any language uses the LSTM.
bool AnyLSTMLang() const {
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) return true;
for (int i = 0; i < sub_langs_.size(); ++i) {
if (sub_langs_[i]->tessedit_ocr_engine_mode == OEM_LSTM_ONLY)
return true;
}
return false;
}
void SetBlackAndWhitelist();
@ -293,6 +309,48 @@ class Tesseract : public Wordrec {
// par_control.cpp
void PrerecAllWordsPar(const GenericVector<WordData>& words);
//// linerec.cpp
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the page into lines, according to the boxes, and writes them to a
// serialized DocumentData based on output_basename.
void TrainLineRecognizer(const STRING& input_imagename,
const STRING& output_basename,
BLOCK_LIST *block_list);
// Generates training data for training a line recognizer, eg LSTM.
// Breaks the boxes into lines, normalizes them, converts to ImageData and
// appends them to the given training_data.
void TrainFromBoxes(const GenericVector<TBOX>& boxes,
const GenericVector<STRING>& texts,
BLOCK_LIST *block_list,
DocumentData* training_data);
// Returns an ImageData containing the image of the given textline,
// and ground truth boxes/truth text if available in the input.
// The image is not normalized in any way.
ImageData* GetLineData(const TBOX& line_box,
const GenericVector<TBOX>& boxes,
const GenericVector<STRING>& texts,
int start_box, int end_box,
const BLOCK& block);
// Helper gets the image of a rectangle, using the block.re_rotation() if
// needed to get to the image, and rotating the result back to horizontal
// layout. (CJK characters will be on their left sides.) The vertical text flag
// is set in the returned ImageData if the text was originally vertical, which
// can be used to invoke a different CJK recognition engine. The revised_box
// is also returned to enable calculation of output bounding boxes.
ImageData* GetRectImage(const TBOX& box, const BLOCK& block, int padding,
TBOX* revised_box) const;
// Top-level function recognizes a single raw line.
void RecogRawLine(PAGE_RES* page_res);
// Recognizes a word or group of words, converting to WERD_RES in *words.
// Analogous to classify_word_pass1, but can handle a group of words as well.
void LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
PointerVector<WERD_RES>* words);
// Apply segmentation search to the given set of words, within the constraints
// of the existing ratings matrix. If there is already a best_choice on a word,
// leaves it untouched and just sets the done/accepted etc. flags.
void SearchWords(PointerVector<WERD_RES>* words);
//// control.h /////////////////////////////////////////////////////////
bool ProcessTargetWord(const TBOX& word_box, const TBOX& target_word_box,
const char* word_config, int pass);
@ -783,6 +841,8 @@ class Tesseract : public Wordrec {
"Generate training data from boxed chars");
BOOL_VAR_H(tessedit_make_boxes_from_boxes, false,
"Generate more boxes from boxed chars");
BOOL_VAR_H(tessedit_train_line_recognizer, false,
"Break input into lines and remap boxes if present");
BOOL_VAR_H(tessedit_dump_pageseg_images, false,
"Dump intermediate images made during page segmentation");
INT_VAR_H(tessedit_pageseg_mode, PSM_SINGLE_BLOCK,
@ -891,6 +951,7 @@ class Tesseract : public Wordrec {
"Run paragraph detection on the post-text-recognition "
"(more accurate)");
INT_VAR_H(cube_debug_level, 1, "Print cube debug info.");
BOOL_VAR_H(lstm_use_matrix, 1, "Use ratings matrix/beam search with lstm");
STRING_VAR_H(outlines_odd, "%| ", "Non standard number of outlines");
STRING_VAR_H(outlines_2, "ij!?%\":;", "Non standard number of outlines");
BOOL_VAR_H(docqual_excuse_outline_errs, false,
@ -1174,6 +1235,8 @@ class Tesseract : public Wordrec {
Pix* cube_binary_;
// Grey-level input image if the input was not binary, otherwise NULL.
Pix* pix_grey_;
// Original input image. Color if the input was color.
Pix* pix_original_;
// Thresholds that were used to generate the thresholded image from grey.
Pix* pix_thresholds_;
// Input image resolution after any scaling. The resolution is not well
@ -1205,6 +1268,10 @@ class Tesseract : public Wordrec {
#endif
// Equation detector. Note: this pointer is NOT owned by the class.
EquationDetect* equ_detect_;
// LSTM recognizer, if available.
LSTMRecognizer* lstm_recognizer_;
// Output "page" number (actually line number) using TrainLineRecognizer.
int train_line_page_num_;
};
} // namespace tesseract

View File

@ -152,19 +152,27 @@ void ImageThresholder::SetImage(const Pix* pix) {
int depth;
pixGetDimensions(src, &image_width_, &image_height_, &depth);
// Convert the image as necessary so it is one of binary, plain RGB, or
// 8 bit with no colormap.
// 8 bit with no colormap. Guarantee that we always end up with our own copy,
// not just a clone of the input.
if (pixGetColormap(src)) {
Pix* tmp = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC);
depth = pixGetDepth(tmp);
if (depth > 1 && depth < 8) {
pix_ = pixConvertTo8(src, false);
} else if (pixGetColormap(src)) {
pix_ = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC);
pix_ = pixConvertTo8(tmp, false);
pixDestroy(&tmp);
} else {
pix_ = pixClone(src);
pix_ = tmp;
}
} else if (depth > 1 && depth < 8) {
pix_ = pixConvertTo8(src, false);
} else {
pix_ = pixCopy(NULL, src);
}
depth = pixGetDepth(pix_);
pix_channels_ = depth / 8;
pix_wpl_ = pixGetWpl(pix_);
scale_ = 1;
estimated_res_ = yres_ = pixGetYRes(src);
estimated_res_ = yres_ = pixGetYRes(pix_);
Init();
}

View File

@ -24,12 +24,18 @@
#include "imagedata.h"
#include <unistd.h>
#include "allheaders.h"
#include "boxread.h"
#include "callcpp.h"
#include "helpers.h"
#include "tprintf.h"
// Number of documents to read ahead while training. Doesn't need to be very
// large.
const int kMaxReadAhead = 8;
namespace tesseract {
WordFeature::WordFeature() : x_(0), y_(0), dir_(0) {
@ -182,6 +188,19 @@ bool ImageData::DeSerialize(bool swap, TFile* fp) {
return true;
}
// As DeSerialize, but only seeks past the data - hence a static method.
bool ImageData::SkipDeSerialize(bool swap, TFile* fp) {
if (!STRING::SkipDeSerialize(swap, fp)) return false;
inT32 page_number;
if (fp->FRead(&page_number, sizeof(page_number), 1) != 1) return false;
if (!GenericVector<char>::SkipDeSerialize(swap, fp)) return false;
if (!STRING::SkipDeSerialize(swap, fp)) return false;
if (!GenericVector<TBOX>::SkipDeSerialize(swap, fp)) return false;
if (!GenericVector<STRING>::SkipDeSerializeClasses(swap, fp)) return false;
inT8 vertical = 0;
return fp->FRead(&vertical, sizeof(vertical), 1) == 1;
}
// Saves the given Pix as a PNG-encoded string and destroys it.
void ImageData::SetPix(Pix* pix) {
SetPixInternal(pix, &image_data_);
@ -195,9 +214,10 @@ Pix* ImageData::GetPix() const {
// Gets anything and everything with a non-NULL pointer, prescaled to a
// given target_height (if 0, then the original image height), and aligned.
// Also returns (if not NULL) the width and height of the scaled image.
// The return value is the scale factor that was applied to the image to
// achieve the target_height.
float ImageData::PreScale(int target_height, Pix** pix,
// The return value is the scaled Pix, which must be pixDestroyed after use,
// and scale_factor (if not NULL) is set to the scale factor that was applied
// to the image to achieve the target_height.
Pix* ImageData::PreScale(int target_height, float* scale_factor,
int* scaled_width, int* scaled_height,
GenericVector<TBOX>* boxes) const {
int input_width = 0;
@ -213,19 +233,14 @@ float ImageData::PreScale(int target_height, Pix** pix,
*scaled_width = IntCastRounded(im_factor * input_width);
if (scaled_height != NULL)
*scaled_height = target_height;
if (pix != NULL) {
// Get the scaled image.
pixDestroy(pix);
*pix = pixScale(src_pix, im_factor, im_factor);
if (*pix == NULL) {
Pix* pix = pixScale(src_pix, im_factor, im_factor);
if (pix == NULL) {
tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n",
input_width, input_height, im_factor);
}
if (scaled_width != NULL)
*scaled_width = pixGetWidth(*pix);
if (scaled_height != NULL)
*scaled_height = pixGetHeight(*pix);
}
if (scaled_width != NULL) *scaled_width = pixGetWidth(pix);
if (scaled_height != NULL) *scaled_height = pixGetHeight(pix);
pixDestroy(&src_pix);
if (boxes != NULL) {
// Get the boxes.
@ -241,7 +256,8 @@ float ImageData::PreScale(int target_height, Pix** pix,
boxes->push_back(box);
}
}
return im_factor;
if (scale_factor != NULL) *scale_factor = im_factor;
return pix;
}
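Under the revised contract the caller owns the returned Pix and reads the scale through the out-parameter; a hedged caller sketch with hypothetical names:

// Illustrative caller of the new PreScale() signature.
float scale = 1.0f;
int width = 0, height = 0;
Pix* scaled = image_data->PreScale(48, &scale, &width, &height, NULL);
if (scaled != NULL) {
  // ... feed the scaled line image to the recognizer ...
  pixDestroy(&scaled);  // Caller-owned under the new contract.
}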
int ImageData::MemoryUsed() const {
@ -266,19 +282,20 @@ void ImageData::Display() const {
// Draw the boxes.
win->Pen(ScrollView::RED);
win->Brush(ScrollView::NONE);
win->TextAttributes("Arial", kTextSize, false, false, false);
int text_size = kTextSize;
if (!boxes_.empty() && boxes_[0].height() * 2 < text_size)
text_size = boxes_[0].height() * 2;
win->TextAttributes("Arial", text_size, false, false, false);
if (!boxes_.empty()) {
for (int b = 0; b < boxes_.size(); ++b) {
boxes_[b].plot(win);
win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string());
TBOX scaled(boxes_[b]);
scaled.scale(256.0 / height);
scaled.plot(win);
}
} else {
// The full transcription.
win->Pen(ScrollView::CYAN);
win->Text(0, height + kTextSize * 2, transcription_.string());
// Add the features.
win->Pen(ScrollView::GREEN);
}
win->Update();
window_wait(win);
#endif
@ -340,27 +357,51 @@ bool ImageData::AddBoxes(const char* box_text) {
return false;
}
DocumentData::DocumentData(const STRING& name)
: document_name_(name), pages_offset_(0), total_pages_(0),
memory_used_(0), max_memory_(0), reader_(NULL) {}
// Thread function to call ReCachePages.
void* ReCachePagesFunc(void* data) {
DocumentData* document_data = reinterpret_cast<DocumentData*>(data);
document_data->ReCachePages();
return NULL;
}
DocumentData::~DocumentData() {}
DocumentData::DocumentData(const STRING& name)
: document_name_(name),
pages_offset_(-1),
total_pages_(-1),
memory_used_(0),
max_memory_(0),
reader_(NULL) {}
DocumentData::~DocumentData() {
SVAutoLock lock_p(&pages_mutex_);
SVAutoLock lock_g(&general_mutex_);
}
// Reads all the pages in the given lstmf filename to the cache. The reader
// is used to read the file.
bool DocumentData::LoadDocument(const char* filename, const char* lang,
int start_page, inT64 max_memory,
FileReader reader) {
SetDocument(filename, lang, max_memory, reader);
pages_offset_ = start_page;
return ReCachePages();
}
// Sets up the document, without actually loading it.
void DocumentData::SetDocument(const char* filename, const char* lang,
inT64 max_memory, FileReader reader) {
SVAutoLock lock_p(&pages_mutex_);
SVAutoLock lock(&general_mutex_);
document_name_ = filename;
lang_ = lang;
pages_offset_ = start_page;
pages_offset_ = -1;
max_memory_ = max_memory;
reader_ = reader;
return ReCachePages();
}
// Writes all the pages to the given filename. Returns false on error.
bool DocumentData::SaveDocument(const char* filename, FileWriter writer) {
SVAutoLock lock(&pages_mutex_);
TFile fp;
fp.OpenWrite(NULL);
if (!pages_.Serialize(&fp) || !fp.CloseWrite(filename, writer)) {
@ -370,112 +411,166 @@ bool DocumentData::SaveDocument(const char* filename, FileWriter writer) {
return true;
}
bool DocumentData::SaveToBuffer(GenericVector<char>* buffer) {
SVAutoLock lock(&pages_mutex_);
TFile fp;
fp.OpenWrite(buffer);
return pages_.Serialize(&fp);
}
// Returns a pointer to the page with the given index, modulo the total
// number of pages, recaching if needed.
const ImageData* DocumentData::GetPage(int index) {
index = Modulo(index, total_pages_);
if (index < pages_offset_ || index >= pages_offset_ + pages_.size()) {
pages_offset_ = index;
if (!ReCachePages()) return NULL;
}
return pages_[index - pages_offset_];
// Adds the given page data to this document, counting up memory.
void DocumentData::AddPageToDocument(ImageData* page) {
SVAutoLock lock(&pages_mutex_);
pages_.push_back(page);
set_memory_used(memory_used() + page->MemoryUsed());
}
// Loads as many pages as can fit in max_memory_ starting at index pages_offset_.
// If the given index is not currently loaded, loads it using a separate
// thread.
void DocumentData::LoadPageInBackground(int index) {
ImageData* page = NULL;
if (IsPageAvailable(index, &page)) return;
SVAutoLock lock(&pages_mutex_);
if (pages_offset_ == index) return;
pages_offset_ = index;
pages_.clear();
SVSync::StartThread(ReCachePagesFunc, this);
}
// Returns a pointer to the page with the given index, modulo the total
// number of pages. Blocks until the background load is completed.
const ImageData* DocumentData::GetPage(int index) {
ImageData* page = NULL;
while (!IsPageAvailable(index, &page)) {
// If there is no background load scheduled, schedule one now.
pages_mutex_.Lock();
bool needs_loading = pages_offset_ != index;
pages_mutex_.Unlock();
if (needs_loading) LoadPageInBackground(index);
// We can't directly load the page, or the background load will delete it
// while the caller is using it, so give it a chance to work.
sleep(1);
}
return page;
}
// Returns true if the requested page is available, and provides a pointer,
// which may be NULL if the document is empty. May block, even though it
// doesn't guarantee to return true.
bool DocumentData::IsPageAvailable(int index, ImageData** page) {
SVAutoLock lock(&pages_mutex_);
int num_pages = NumPages();
if (num_pages == 0 || index < 0) {
*page = NULL; // Empty Document.
return true;
}
if (num_pages > 0) {
index = Modulo(index, num_pages);
if (pages_offset_ <= index && index < pages_offset_ + pages_.size()) {
*page = pages_[index - pages_offset_]; // Page is available already.
return true;
}
}
return false;
}
// Removes all pages from memory and frees the memory, but does not forget
// the document metadata.
inT64 DocumentData::UnCache() {
SVAutoLock lock(&pages_mutex_);
inT64 memory_saved = memory_used();
pages_.clear();
pages_offset_ = -1;
set_total_pages(-1);
set_memory_used(0);
tprintf("Unloaded document %s, saving %d memory\n", document_name_.string(),
memory_saved);
return memory_saved;
}
// Locks the pages_mutex_ and loads as many pages as can fit in max_memory_
// starting at index pages_offset_.
bool DocumentData::ReCachePages() {
SVAutoLock lock(&pages_mutex_);
// Read the file.
TFile fp;
if (!fp.Open(document_name_, reader_)) return false;
memory_used_ = 0;
if (!pages_.DeSerialize(false, &fp)) {
tprintf("Deserialize failed: %s\n", document_name_.string());
set_total_pages(0);
set_memory_used(0);
int loaded_pages = 0;
pages_.truncate(0);
TFile fp;
if (!fp.Open(document_name_, reader_) ||
!PointerVector<ImageData>::DeSerializeSize(false, &fp, &loaded_pages) ||
loaded_pages <= 0) {
tprintf("Deserialize header failed: %s\n", document_name_.string());
return false;
}
total_pages_ = pages_.size();
pages_offset_ %= total_pages_;
// Delete pages before the first one we want, and relocate the rest.
pages_offset_ %= loaded_pages;
// Skip pages before the first one we want, and load the rest until max
// memory and skip the rest after that.
int page;
for (page = 0; page < pages_.size(); ++page) {
if (page < pages_offset_) {
delete pages_[page];
pages_[page] = NULL;
for (page = 0; page < loaded_pages; ++page) {
if (page < pages_offset_ ||
(max_memory_ > 0 && memory_used() > max_memory_)) {
if (!PointerVector<ImageData>::DeSerializeSkip(false, &fp)) break;
} else {
ImageData* image_data = pages_[page];
if (max_memory_ > 0 && page > pages_offset_ &&
memory_used_ + image_data->MemoryUsed() > max_memory_)
break; // Don't go over memory quota unless the first image.
if (!pages_.DeSerializeElement(false, &fp)) break;
ImageData* image_data = pages_.back();
if (image_data->imagefilename().length() == 0) {
image_data->set_imagefilename(document_name_);
image_data->set_page_number(page);
}
image_data->set_language(lang_);
memory_used_ += image_data->MemoryUsed();
if (pages_offset_ != 0) {
pages_[page - pages_offset_] = image_data;
pages_[page] = NULL;
set_memory_used(memory_used() + image_data->MemoryUsed());
}
}
if (page < loaded_pages) {
tprintf("Deserialize failed: %s read %d/%d pages\n",
document_name_.string(), page, loaded_pages);
pages_.truncate(0);
} else {
tprintf("Loaded %d/%d pages (%d-%d) of document %s\n", pages_.size(),
loaded_pages, pages_offset_, pages_offset_ + pages_.size(),
document_name_.string());
}
pages_.truncate(page - pages_offset_);
tprintf("Loaded %d/%d pages (%d-%d) of document %s\n",
pages_.size(), total_pages_, pages_offset_,
pages_offset_ + pages_.size(), document_name_.string());
set_total_pages(loaded_pages);
return !pages_.empty();
}
// Adds the given page data to this document, counting up memory.
void DocumentData::AddPageToDocument(ImageData* page) {
pages_.push_back(page);
memory_used_ += page->MemoryUsed();
}
// A collection of DocumentData that knows roughly how much memory it is using.
DocumentCache::DocumentCache(inT64 max_memory)
: total_pages_(0), memory_used_(0), max_memory_(max_memory) {}
: num_pages_per_doc_(0), max_memory_(max_memory) {}
DocumentCache::~DocumentCache() {}
// Adds all the documents in the list of filenames, counting memory.
// The reader is used to read the files.
bool DocumentCache::LoadDocuments(const GenericVector<STRING>& filenames,
const char* lang, FileReader reader) {
inT64 fair_share_memory = max_memory_ / filenames.size();
const char* lang,
CachingStrategy cache_strategy,
FileReader reader) {
cache_strategy_ = cache_strategy;
inT64 fair_share_memory = 0;
// In the round-robin case, each DocumentData handles restricting its content
// to its fair share of memory. In the sequential case, DocumentCache
// determines which DocumentDatas are held entirely in memory.
if (cache_strategy_ == CS_ROUND_ROBIN)
fair_share_memory = max_memory_ / filenames.size();
for (int arg = 0; arg < filenames.size(); ++arg) {
STRING filename = filenames[arg];
DocumentData* document = new DocumentData(filename);
if (document->LoadDocument(filename.string(), lang, 0,
fair_share_memory, reader)) {
document->SetDocument(filename.string(), lang, fair_share_memory, reader);
AddToCache(document);
} else {
tprintf("Failed to load image %s!\n", filename.string());
delete document;
}
if (!documents_.empty()) {
// Try to get the first page now to verify the list of filenames.
if (GetPageBySerial(0) != NULL) return true;
tprintf("Load of page 0 failed!\n");
}
tprintf("Loaded %d pages, total %gMB\n",
total_pages_, memory_used_ / 1048576.0);
return total_pages_ > 0;
return false;
}
// Adds document to the cache, throwing out other documents if needed.
// Adds document to the cache.
bool DocumentCache::AddToCache(DocumentData* data) {
inT64 new_memory = data->memory_used();
memory_used_ += new_memory;
documents_.push_back(data);
total_pages_ += data->NumPages();
// Delete the first item in the array, and other pages of the same name
// while memory is full.
while (memory_used_ >= max_memory_ && max_memory_ > 0) {
tprintf("Memory used=%lld vs max=%lld, discarding doc of size %lld\n",
memory_used_ , max_memory_, documents_[0]->memory_used());
memory_used_ -= documents_[0]->memory_used();
total_pages_ -= documents_[0]->NumPages();
documents_.remove(0);
}
return true;
}
@ -488,11 +583,104 @@ DocumentData* DocumentCache::FindDocument(const STRING& document_name) const {
return NULL;
}
// Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache
// strategy, could take a long time.
int DocumentCache::TotalPages() {
if (cache_strategy_ == CS_SEQUENTIAL) {
// In sequential mode, we assume each doc has the same number of pages
// whether it is true or not.
if (num_pages_per_doc_ == 0) GetPageSequential(0);
return num_pages_per_doc_ * documents_.size();
}
int total_pages = 0;
int num_docs = documents_.size();
for (int d = 0; d < num_docs; ++d) {
// We have to load a page to make NumPages() valid.
documents_[d]->GetPage(0);
total_pages += documents_[d]->NumPages();
}
return total_pages;
}
// Returns a page by serial number, selecting them in a round-robin fashion
// from all the documents.
const ImageData* DocumentCache::GetPageBySerial(int serial) {
int document_index = serial % documents_.size();
return documents_[document_index]->GetPage(serial / documents_.size());
// from all the documents. Highly disk-intensive, but doesn't need samples
// to be shuffled between files to begin with.
const ImageData* DocumentCache::GetPageRoundRobin(int serial) {
int num_docs = documents_.size();
int doc_index = serial % num_docs;
const ImageData* doc = documents_[doc_index]->GetPage(serial / num_docs);
for (int offset = 1; offset <= kMaxReadAhead && offset < num_docs; ++offset) {
doc_index = (serial + offset) % num_docs;
int page = (serial + offset) / num_docs;
documents_[doc_index]->LoadPageInBackground(page);
}
return doc;
}
// Returns a page by serial number, selecting them in sequence from each file.
// Requires the samples to be shuffled between the files to give a random or
// uniform distribution of data. Less disk-intensive than GetPageRoundRobin.
const ImageData* DocumentCache::GetPageSequential(int serial) {
int num_docs = documents_.size();
ASSERT_HOST(num_docs > 0);
if (num_pages_per_doc_ == 0) {
// Use the pages in the first doc as the number of pages in each doc.
documents_[0]->GetPage(0);
num_pages_per_doc_ = documents_[0]->NumPages();
if (num_pages_per_doc_ == 0) {
tprintf("First document cannot be empty!!\n");
ASSERT_HOST(num_pages_per_doc_ > 0);
}
// Get rid of zero now if we don't need it.
if (serial / num_pages_per_doc_ % num_docs > 0) documents_[0]->UnCache();
}
int doc_index = serial / num_pages_per_doc_ % num_docs;
const ImageData* doc =
documents_[doc_index]->GetPage(serial % num_pages_per_doc_);
// Count up total memory. Background loading makes it more complicated to
// keep a running count.
inT64 total_memory = 0;
for (int d = 0; d < num_docs; ++d) {
total_memory += documents_[d]->memory_used();
}
if (total_memory >= max_memory_) {
// Find something to un-cache.
// If there are more than 3 in front, then serial is from the back reader
// of a pair of readers. If we un-cache from in-front-2 to 2-ahead, then
// we create a hole between them and then un-caching the backmost occupied
// will work for both.
int num_in_front = CountNeighbourDocs(doc_index, 1);
for (int offset = num_in_front - 2;
offset > 1 && total_memory >= max_memory_; --offset) {
int next_index = (doc_index + offset) % num_docs;
total_memory -= documents_[next_index]->UnCache();
}
// If that didn't work, the best solution is to un-cache from the back. If
// we take away the document that a 2nd reader is using, it will put it
// back and make a hole between.
int num_behind = CountNeighbourDocs(doc_index, -1);
for (int offset = num_behind; offset < 0 && total_memory >= max_memory_;
++offset) {
int next_index = (doc_index + offset + num_docs) % num_docs;
total_memory -= documents_[next_index]->UnCache();
}
}
int next_index = (doc_index + 1) % num_docs;
if (!documents_[next_index]->IsCached() && total_memory < max_memory_) {
documents_[next_index]->LoadPageInBackground(0);
}
return doc;
}
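For concreteness, a worked trace of the sequential mapping above, using assumed sizes rather than anything from a real run:
// Assumed: num_docs = 3, num_pages_per_doc_ = 10.
// serial 25 -> doc_index = 25 / 10 % 3 = 2, page = 25 % 10 = 5.
// serial 30 -> doc_index = 30 / 10 % 3 = 0, page = 0 (wraps to the first doc).
// An epoch therefore spans num_docs * num_pages_per_doc_ = 30 serials.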
// Helper counts the number of adjacent cached neighbours of index looking in
// direction dir, ie index+dir, index+2*dir etc.
int DocumentCache::CountNeighbourDocs(int index, int dir) {
int num_docs = documents_.size();
for (int offset = dir; abs(offset) < num_docs; offset += dir) {
int offset_index = (index + offset + num_docs) % num_docs;
if (!documents_[offset_index]->IsCached()) return offset - dir;
}
return num_docs;
}
} // namespace tesseract.

View File

@ -25,6 +25,7 @@
#include "normalis.h"
#include "rect.h"
#include "strngs.h"
#include "svutil.h"
struct Pix;
@ -34,8 +35,22 @@ namespace tesseract {
const int kFeaturePadding = 2;
// Number of pixels to pad around text boxes.
const int kImagePadding = 4;
// Number of training images to combine into a mini-batch for training.
const int kNumPagesPerMiniBatch = 100;
// Enum to determine the caching and data sequencing strategy.
enum CachingStrategy {
// Reads all of one file before moving on to the next. Requires samples to be
// shuffled across files. Uses the count of samples in the first file as
// the count in all the files to achieve high-speed random access. As a
// consequence, if subsequent files are smaller, they get entries used more
// than once, and if subsequent files are larger, some entries are not used.
// Best for larger data sets that don't fit in memory.
CS_SEQUENTIAL,
// Reads one sample from each file in rotation. Does not require shuffled
// samples, but is extremely disk-intensive. Samples in smaller files also
// get used more often than samples in larger files.
// Best for smaller data sets that mostly fit in memory.
CS_ROUND_ROBIN,
};
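To make the choice concrete, a minimal sketch of wiring up the DocumentCache class below with a strategy (the file names and memory budget are made up, and a NULL reader is assumed to fall back to plain file reads):
  DocumentCache cache(1LL << 30);  // 1GB budget.
  GenericVector<STRING> filenames;
  filenames.push_back(STRING("train.0.lstmf"));  // Hypothetical .lstmf files.
  filenames.push_back(STRING("train.1.lstmf"));
  // CS_SEQUENTIAL suits big shuffled sets; CS_ROUND_ROBIN suits small ones.
  if (cache.LoadDocuments(filenames, "eng", CS_SEQUENTIAL, NULL)) {
    const ImageData* sample = cache.GetPageBySerial(25);
  }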
class WordFeature {
public:
@ -103,6 +118,8 @@ class ImageData {
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, TFile* fp);
// As DeSerialize, but only seeks past the data - hence a static method.
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
// Other accessors.
const STRING& imagefilename() const {
@ -145,11 +162,11 @@ class ImageData {
// Gets anything and everything with a non-NULL pointer, prescaled to a
// given target_height (if 0, then the original image height), and aligned.
// Also returns (if not NULL) the width and height of the scaled image.
// The return value is the scale factor that was applied to the image to
// achieve the target_height.
float PreScale(int target_height, Pix** pix,
int* scaled_width, int* scaled_height,
GenericVector<TBOX>* boxes) const;
// The return value is the scaled Pix, which must be pixDestroyed after use,
// and scale_factor (if not NULL) is set to the scale factor that was applied
// to the image to achieve the target_height.
Pix* PreScale(int target_height, float* scale_factor, int* scaled_width,
int* scaled_height, GenericVector<TBOX>* boxes) const;
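A minimal sketch of the new calling convention (image_data and the target height of 48 are illustrative; the caller owns the returned Pix):
  float scale = 1.0f;
  int scaled_w = 0, scaled_h = 0;
  Pix* scaled = image_data.PreScale(48, &scale, &scaled_w, &scaled_h, NULL);
  if (scaled != NULL) {
    // ... use the scaled image at scaled_w x scaled_h ...
    pixDestroy(&scaled);
  }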
int MemoryUsed() const;
@ -184,6 +201,8 @@ class ImageData {
// A collection of ImageData that knows roughly how much memory it is using.
class DocumentData {
friend void* ReCachePagesFunc(void* data);
public:
explicit DocumentData(const STRING& name);
~DocumentData();
@ -192,6 +211,9 @@ class DocumentData {
// is used to read the file.
bool LoadDocument(const char* filename, const char* lang, int start_page,
inT64 max_memory, FileReader reader);
// Sets up the document, without actually loading it.
void SetDocument(const char* filename, const char* lang, inT64 max_memory,
FileReader reader);
// Writes all the pages to the given filename. Returns false on error.
bool SaveDocument(const char* filename, FileWriter writer);
bool SaveToBuffer(GenericVector<char>* buffer);
@ -200,26 +222,62 @@ class DocumentData {
void AddPageToDocument(ImageData* page);
const STRING& document_name() const {
SVAutoLock lock(&general_mutex_);
return document_name_;
}
int NumPages() const {
SVAutoLock lock(&general_mutex_);
return total_pages_;
}
inT64 memory_used() const {
SVAutoLock lock(&general_mutex_);
return memory_used_;
}
// If the given index is not currently loaded, loads it using a separate
// thread. Note: there are 4 cases:
// Document uncached: IsCached() returns false, total_pages_ < 0.
// Required page is available: IsPageAvailable returns true. In this case,
// total_pages_ > 0 and
// pages_offset_ <= index%total_pages_ <= pages_offset_+pages_.size()
// Pages are loaded, but the required one is not.
// The requested page is being loaded by LoadPageInBackground. In this case,
// index == pages_offset_. Once the loading starts, the pages lock is held
// until it completes, at which point IsPageAvailable will unblock and return
// true.
void LoadPageInBackground(int index);
// Returns a pointer to the page with the given index, modulo the total
// number of pages, recaching if needed.
// number of pages. Blocks until the background load is completed.
const ImageData* GetPage(int index);
// Returns true if the requested page is available, and provides a pointer,
// which may be NULL if the document is empty. May block, even though it
// doesn't guarantee to return true.
bool IsPageAvailable(int index, ImageData** page);
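A typical single-consumer read-ahead pattern implied by the cases above (sketch; doc and next_index are illustrative):
  doc.LoadPageInBackground(next_index);  // Start the load early.
  // ... do other work ...
  const ImageData* page = doc.GetPage(next_index);  // Blocks until loaded.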
// Takes ownership of the given page index. The page is made NULL in *this.
ImageData* TakePage(int index) {
SVAutoLock lock(&pages_mutex_);
ImageData* page = pages_[index];
pages_[index] = NULL;
return page;
}
// Returns true if the document is currently loaded or in the process of
// loading.
bool IsCached() const { return NumPages() >= 0; }
// Removes all pages from memory and frees the memory, but does not forget
// the document metadata. Returns the memory saved.
inT64 UnCache();
private:
// Loads as many pages as can fit in max_memory_ starting at index pages_offset_.
// Sets the value of total_pages_ behind a mutex.
void set_total_pages(int total) {
SVAutoLock lock(&general_mutex_);
total_pages_ = total;
}
void set_memory_used(inT64 memory_used) {
SVAutoLock lock(&general_mutex_);
memory_used_ = memory_used;
}
// Locks the pages_mutex_ and loads as many pages as can fit in max_memory_
// starting at index pages_offset_.
bool ReCachePages();
private:
@ -239,43 +297,77 @@ class DocumentData {
inT64 max_memory_;
// Saved reader from LoadDocument to allow re-caching.
FileReader reader_;
// Mutex that protects pages_ and pages_offset_ against multiple parallel
// loads, and provides a wait for page.
SVMutex pages_mutex_;
// Mutex that protects other data members that callers want to access without
// waiting for a load operation.
mutable SVMutex general_mutex_;
};
// A collection of DocumentData that knows roughly how much memory it is using.
// Note that while it supports background read-ahead, it assumes that a single
// thread is accessing documents, ie it is not safe for multiple threads to
// access different documents in parallel, as one may de-cache the other's
// content.
class DocumentCache {
public:
explicit DocumentCache(inT64 max_memory);
~DocumentCache();
// Deletes all existing documents from the cache.
void Clear() {
documents_.clear();
num_pages_per_doc_ = 0;
}
// Adds all the documents in the list of filenames, counting memory.
// The reader is used to read the files.
bool LoadDocuments(const GenericVector<STRING>& filenames, const char* lang,
FileReader reader);
CachingStrategy cache_strategy, FileReader reader);
// Adds document to the cache, throwing out other documents if needed.
// Adds document to the cache.
bool AddToCache(DocumentData* data);
// Finds and returns a document by name.
DocumentData* FindDocument(const STRING& document_name) const;
// Returns a page by serial number, selecting them in a round-robin fashion
// from all the documents.
const ImageData* GetPageBySerial(int serial);
// Returns a page by serial number using the current cache_strategy_ to
// determine the mapping from serial number to page.
const ImageData* GetPageBySerial(int serial) {
if (cache_strategy_ == CS_SEQUENTIAL)
return GetPageSequential(serial);
else
return GetPageRoundRobin(serial);
}
const PointerVector<DocumentData>& documents() const {
return documents_;
}
int total_pages() const {
return total_pages_;
}
// Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache
// strategy, could take a long time.
int TotalPages();
private:
// Returns a page by serial number, selecting them in a round-robin fashion
// from all the documents. Highly disk-intensive, but doesn't need samples
// to be shuffled between files to begin with.
const ImageData* GetPageRoundRobin(int serial);
// Returns a page by serial number, selecting them in sequence from each file.
// Requires the samples to be shuffled between the files to give a random or
// uniform distribution of data. Less disk-intensive than GetPageRoundRobin.
const ImageData* GetPageSequential(int serial);
// Helper counts the number of adjacent cached neighbour documents_ of index
// looking in direction dir, ie index+dir, index+2*dir etc.
int CountNeighbourDocs(int index, int dir);
// A group of pages that corresponds in some loose way to a document.
PointerVector<DocumentData> documents_;
// Total of all pages.
int total_pages_;
// Total of all memory used by the cache.
inT64 memory_used_;
// Strategy to use for caching and serializing data samples.
CachingStrategy cache_strategy_;
// Number of pages in the first document, used as a divisor in
// GetPageSequential to determine the document index.
int num_pages_per_doc_;
// Max memory allowed in this cache.
inT64 max_memory_;
};

View File

@ -1,8 +1,12 @@
/* -*-C-*-
******************************************************************************
*
* File: matrix.h (Formerly matrix.h)
* Description: Ratings matrix code. (Used by associator)
* Description: Generic 2-d array/matrix and banded triangular matrix class.
* Author: Ray Smith
* TODO(rays) Separate from ratings matrix, which it also contains:
*
* Description: Ratings matrix class (specialization of banded matrix).
* Segmentation search matrix of lists of BLOB_CHOICE.
* Author: Mark Seaman, OCR Technology
* Created: Wed May 16 13:22:06 1990
* Modified: Tue Mar 19 16:00:20 1991 (Mark Seaman) marks@hpgrlt
@ -25,9 +29,13 @@
#ifndef TESSERACT_CCSTRUCT_MATRIX_H__
#define TESSERACT_CCSTRUCT_MATRIX_H__
#include <math.h>
#include "kdpair.h"
#include "points.h"
#include "serialis.h"
#include "unicharset.h"
class BLOB_CHOICE;
class BLOB_CHOICE_LIST;
#define NOT_CLASSIFIED reinterpret_cast<BLOB_CHOICE_LIST*>(0)
@ -44,34 +52,60 @@ class GENERIC_2D_ARRAY {
// either pass the memory in, or allocate after by calling Resize().
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty, T* array)
: empty_(empty), dim1_(dim1), dim2_(dim2), array_(array) {
size_allocated_ = dim1 * dim2;
}
// Original constructor for a full rectangular matrix DOES allocate memory
// and initialize it to empty.
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty)
: empty_(empty), dim1_(dim1), dim2_(dim2) {
array_ = new T[dim1_ * dim2_];
for (int x = 0; x < dim1_; x++)
for (int y = 0; y < dim2_; y++)
this->put(x, y, empty_);
int new_size = dim1 * dim2;
array_ = new T[new_size];
size_allocated_ = new_size;
for (int i = 0; i < size_allocated_; ++i)
array_[i] = empty_;
}
// Default constructor for array allocation. Use Resize to set the size.
GENERIC_2D_ARRAY()
: array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0),
size_allocated_(0) {
}
GENERIC_2D_ARRAY(const GENERIC_2D_ARRAY<T>& src)
: array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0),
size_allocated_(0) {
*this = src;
}
virtual ~GENERIC_2D_ARRAY() { delete[] array_; }
void operator=(const GENERIC_2D_ARRAY<T>& src) {
ResizeNoInit(src.dim1(), src.dim2());
memcpy(array_, src.array_, num_elements() * sizeof(array_[0]));
}
// Reallocate the array to the given size. Does not keep old data, but does
// not initialize the array either.
void ResizeNoInit(int size1, int size2) {
int new_size = size1 * size2;
if (new_size > size_allocated_) {
delete [] array_;
array_ = new T[new_size];
size_allocated_ = new_size;
}
dim1_ = size1;
dim2_ = size2;
}
// Reallocate the array to the given size. Does not keep old data.
void Resize(int size1, int size2, const T& empty) {
empty_ = empty;
if (size1 != dim1_ || size2 != dim2_) {
dim1_ = size1;
dim2_ = size2;
delete [] array_;
array_ = new T[dim1_ * dim2_];
}
ResizeNoInit(size1, size2);
Clear();
}
// Reallocate the array to the given size, keeping old data.
void ResizeWithCopy(int size1, int size2) {
if (size1 != dim1_ || size2 != dim2_) {
T* new_array = new T[size1 * size2];
int new_size = size1 * size2;
T* new_array = new T[new_size];
for (int col = 0; col < size1; ++col) {
for (int row = 0; row < size2; ++row) {
int old_index = col * dim2() + row;
@ -87,6 +121,7 @@ class GENERIC_2D_ARRAY {
array_ = new_array;
dim1_ = size1;
dim2_ = size2;
size_allocated_ = new_size;
}
}
@ -106,9 +141,16 @@ class GENERIC_2D_ARRAY {
if (fwrite(array_, sizeof(*array_), size, fp) != size) return false;
return true;
}
bool Serialize(tesseract::TFile* fp) const {
if (!SerializeSize(fp)) return false;
if (fp->FWrite(&empty_, sizeof(empty_), 1) != 1) return false;
int size = num_elements();
if (fp->FWrite(array_, sizeof(*array_), size) != size) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// Only works with bitwise-serializeable typ
// Only works with bitwise-serializeable types!
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, FILE* fp) {
if (!DeSerializeSize(swap, fp)) return false;
@ -122,6 +164,18 @@ class GENERIC_2D_ARRAY {
}
return true;
}
bool DeSerialize(bool swap, tesseract::TFile* fp) {
if (!DeSerializeSize(swap, fp)) return false;
if (fp->FRead(&empty_, sizeof(empty_), 1) != 1) return false;
if (swap) ReverseN(&empty_, sizeof(empty_));
int size = num_elements();
if (fp->FRead(array_, sizeof(*array_), size) != size) return false;
if (swap) {
for (int i = 0; i < size; ++i)
ReverseN(&array_[i], sizeof(array_[i]));
}
return true;
}
// Writes to the given file. Returns false in case of error.
// Assumes a T::Serialize(FILE*) const function.
@ -163,11 +217,17 @@ class GENERIC_2D_ARRAY {
}
// Put a list element into the matrix at a specific location.
void put(ICOORD pos, const T& thing) {
array_[this->index(pos.x(), pos.y())] = thing;
}
void put(int column, int row, const T& thing) {
array_[this->index(column, row)] = thing;
}
// Get the item at a specified location from the matrix.
T get(ICOORD pos) const {
return array_[this->index(pos.x(), pos.y())];
}
T get(int column, int row) const {
return array_[this->index(column, row)];
}
@ -187,6 +247,207 @@ class GENERIC_2D_ARRAY {
return &array_[this->index(column, 0)];
}
// Adds addend to *this, element-by-element.
void operator+=(const GENERIC_2D_ARRAY<T>& addend) {
if (dim2_ == addend.dim2_) {
// Faster if equal size in the major dimension.
int size = MIN(num_elements(), addend.num_elements());
for (int i = 0; i < size; ++i) {
array_[i] += addend.array_[i];
}
} else {
for (int x = 0; x < dim1_; x++) {
for (int y = 0; y < dim2_; y++) {
(*this)(x, y) += addend(x, y);
}
}
}
}
// Subtracts minuend from *this, element-by-element.
void operator-=(const GENERIC_2D_ARRAY<T>& minuend) {
if (dim2_ == minuend.dim2_) {
// Faster if equal size in the major dimension.
int size = MIN(num_elements(), minuend.num_elements());
for (int i = 0; i < size; ++i) {
array_[i] -= minuend.array_[i];
}
} else {
for (int x = 0; x < dim1_; x++) {
for (int y = 0; y < dim2_; y++) {
(*this)(x, y) -= minuend(x, y);
}
}
}
}
// Adds addend to all elements.
void operator+=(const T& addend) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] += addend;
}
}
// Multiplies *this by factor, element-by-element.
void operator*=(const T& factor) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] *= factor;
}
}
// Clips *this to the given range.
void Clip(const T& rangemin, const T& rangemax) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] = ClipToRange(array_[i], rangemin, rangemax);
}
}
// Returns true if all elements of *this are within the given range.
// Only uses operator<
bool WithinBounds(const T& rangemin, const T& rangemax) const {
int size = num_elements();
for (int i = 0; i < size; ++i) {
const T& value = array_[i];
if (value < rangemin || rangemax < value)
return false;
}
return true;
}
// Normalizes the whole array to zero mean and unit standard deviation,
// returning the standard deviation.
double Normalize() {
int size = num_elements();
if (size <= 0) return 0.0;
// Compute the mean.
double mean = 0.0;
for (int i = 0; i < size; ++i) {
mean += array_[i];
}
mean /= size;
// Subtract the mean and compute the standard deviation.
double sd = 0.0;
for (int i = 0; i < size; ++i) {
double normed = array_[i] - mean;
array_[i] = normed;
sd += normed * normed;
}
sd = sqrt(sd / size);
if (sd > 0.0) {
// Divide by the sd.
for (int i = 0; i < size; ++i) {
array_[i] /= sd;
}
}
return sd;
}
// Returns the maximum value of the array.
T Max() const {
int size = num_elements();
if (size <= 0) return empty_;
// Compute the max.
T max_value = array_[0];
for (int i = 1; i < size; ++i) {
const T& value = array_[i];
if (value > max_value) max_value = value;
}
return max_value;
}
// Returns the maximum absolute value of the array.
T MaxAbs() const {
int size = num_elements();
if (size <= 0) return empty_;
// Compute the max.
T max_abs = static_cast<T>(0);
for (int i = 0; i < size; ++i) {
T value = static_cast<T>(fabs(array_[i]));
if (value > max_abs) max_abs = value;
}
return max_abs;
}
// Accumulates the element-wise sums of squares of src into *this.
void SumSquares(const GENERIC_2D_ARRAY<T>& src) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] += src.array_[i] * src.array_[i];
}
}
// Scales each element using the ada-grad algorithm, ie array_[i] by
// sqrt(num_samples/max(1,sqsum[i])).
void AdaGradScaling(const GENERIC_2D_ARRAY<T>& sqsum, int num_samples) {
int size = num_elements();
for (int i = 0; i < size; ++i) {
array_[i] *= sqrt(num_samples / MAX(1.0, sqsum.array_[i]));
}
}
void AssertFinite() const {
int size = num_elements();
for (int i = 0; i < size; ++i) {
ASSERT_HOST(isfinite(array_[i]));
}
}
// REGARDLESS OF THE CURRENT DIMENSIONS, treats the data as a
// num_dims-dimensional array/tensor with dimensions given by dims, (ordered
// from most significant to least significant, the same as standard C arrays)
// and moves src_dim to dest_dim, with the initial dest_dim and any dimensions
// in between shifted towards the hole left by src_dim. Example:
// Current data content: array_=[0, 1, 2, ....119]
// perhaps *this may be of dim[40, 3], with values [[0, 1, 2][3, 4, 5]...
// but the current dimensions are irrelevant.
// num_dims = 4, dims=[5, 4, 3, 2]
// src_dim=3, dest_dim=1
// tensor=[[[[0, 1][2, 3][4, 5]]
// [[6, 7][8, 9][10, 11]]
// [[12, 13][14, 15][16, 17]]
// [[18, 19][20, 21][22, 23]]]
// [[[24, 25]...
// output dims =[5, 2, 4, 3]
// output tensor=[[[[0, 2, 4][6, 8, 10][12, 14, 16][18, 20, 22]]
// [[1, 3, 5][7, 9, 11][13, 15, 17][19, 21, 23]]]
// [[[24, 26, 28]...
// which is stored in the array_ as:
// [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 1, 3, 5, 7, 9, 11, 13...]
// NOTE: the 2 stored matrix dimensions are simply copied from *this. To
// change the dimensions after the transpose, use ResizeNoInit.
// Higher dimensions above 2 are strictly the responsibility of the caller.
void RotatingTranspose(const int* dims, int num_dims, int src_dim,
int dest_dim, GENERIC_2D_ARRAY<T>* result) const {
int max_d = MAX(src_dim, dest_dim);
int min_d = MIN(src_dim, dest_dim);
// In a tensor of shape [d0, d1... min_d, ... max_d, ... dn-2, dn-1], the
// ends outside of min_d and max_d are unaffected, with [max_d +1, dn-1]
// being contiguous blocks of data that will move together, and
// [d0, min_d -1] being replicas of the transpose operation.
// num_replicas represents the large dimensions unchanged by the operation.
// move_size represents the small dimensions unchanged by the operation.
// src_step represents the stride in the src between each adjacent group
// in the destination.
int num_replicas = 1, move_size = 1, src_step = 1;
for (int d = 0; d < min_d; ++d) num_replicas *= dims[d];
for (int d = max_d + 1; d < num_dims; ++d) move_size *= dims[d];
for (int d = src_dim + 1; d < num_dims; ++d) src_step *= dims[d];
if (src_dim > dest_dim) src_step *= dims[src_dim];
// wrap_size is the size of a single replica, being the amount that is
// handled num_replicas times.
int wrap_size = move_size;
for (int d = min_d; d <= max_d; ++d) wrap_size *= dims[d];
result->ResizeNoInit(dim1_, dim2_);
result->empty_ = empty_;
const T* src = array_;
T* dest = result->array_;
for (int replica = 0; replica < num_replicas; ++replica) {
for (int start = 0; start < src_step; start += move_size) {
for (int pos = start; pos < wrap_size; pos += src_step) {
memcpy(dest, src + pos, sizeof(*dest) * move_size);
dest += move_size;
}
}
src += wrap_size;
}
}
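A minimal sketch reproducing the worked example in the comment above: store 0..119 in a 40x3 array, view it as a [5, 4, 3, 2] tensor, and move src_dim 3 to dest_dim 1:
  GENERIC_2D_ARRAY<int> src(40, 3, 0);
  for (int i = 0; i < 120; ++i) src.put(i / 3, i % 3, i);
  const int dims[] = {5, 4, 3, 2};
  GENERIC_2D_ARRAY<int> dest;
  src.RotatingTranspose(dims, 4, 3, 1, &dest);
  // dest's underlying storage now begins [0, 2, 4, 6, 8, 10, ...] as documented.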
// Delete objects pointed to by array_[i].
void delete_matrix_pointers() {
int size = num_elements();
@ -206,6 +467,13 @@ class GENERIC_2D_ARRAY {
if (fwrite(&size, sizeof(size), 1, fp) != 1) return false;
return true;
}
bool SerializeSize(tesseract::TFile* fp) const {
inT32 size = dim1_;
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
size = dim2_;
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
return true;
}
// Factored helper to deserialize the size.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerializeSize(bool swap, FILE* fp) {
@ -219,11 +487,26 @@ class GENERIC_2D_ARRAY {
Resize(size1, size2, empty_);
return true;
}
bool DeSerializeSize(bool swap, tesseract::TFile* fp) {
inT32 size1, size2;
if (fp->FRead(&size1, sizeof(size1), 1) != 1) return false;
if (fp->FRead(&size2, sizeof(size2), 1) != 1) return false;
if (swap) {
ReverseN(&size1, sizeof(size1));
ReverseN(&size2, sizeof(size2));
}
Resize(size1, size2, empty_);
return true;
}
T* array_;
T empty_; // The unused cell.
int dim1_; // Size of the 1st dimension in indexing functions.
int dim2_; // Size of the 2nd dimension in indexing functions.
// The total size to which the array can be expanded before a realloc is
// needed. If Resize is used, memory is retained so it can be re-expanded
// without a further alloc, and this stores the allocated size.
int size_allocated_;
};
// A generic class to store a banded triangular matrix with entries of type T.

View File

@ -304,6 +304,7 @@ bool WERD_RES::SetupForRecognition(const UNICHARSET& unicharset_in,
tesseract = tess;
POLY_BLOCK* pb = block != NULL ? block->poly_block() : NULL;
if ((norm_mode_hint != tesseract::OEM_CUBE_ONLY &&
norm_mode_hint != tesseract::OEM_LSTM_ONLY &&
word->cblob_list()->empty()) || (pb != NULL && !pb->IsText())) {
// Empty words occur when all the blobs have been moved to the rej_blobs
// list, which seems to occur frequently in junk.
@ -882,17 +883,17 @@ void WERD_RES::FakeClassifyWord(int blob_count, BLOB_CHOICE** choices) {
choice_it.add_after_then_move(choices[c]);
ratings->put(c, c, choice_list);
}
FakeWordFromRatings();
FakeWordFromRatings(TOP_CHOICE_PERM);
reject_map.initialise(blob_count);
done = true;
}
// Creates a WERD_CHOICE for the word using the top choices from the leading
// diagonal of the ratings matrix.
void WERD_RES::FakeWordFromRatings() {
void WERD_RES::FakeWordFromRatings(PermuterType permuter) {
int num_blobs = ratings->dimension();
WERD_CHOICE* word_choice = new WERD_CHOICE(uch_set, num_blobs);
word_choice->set_permuter(TOP_CHOICE_PERM);
word_choice->set_permuter(permuter);
for (int b = 0; b < num_blobs; ++b) {
UNICHAR_ID unichar_id = UNICHAR_SPACE;
float rating = MAX_INT32;
@ -1105,6 +1106,7 @@ void WERD_RES::InitNonPointers() {
x_height = 0.0;
caps_height = 0.0;
baseline_shift = 0.0f;
space_certainty = 0.0f;
guessed_x_ht = TRUE;
guessed_caps_ht = TRUE;
combination = FALSE;

View File

@ -295,6 +295,9 @@ class WERD_RES : public ELIST_LINK {
float x_height; // post match estimate
float caps_height; // post match estimate
float baseline_shift; // post match estimate.
// Certainty score for the spaces on either side of this word (LSTM mode).
// Take the MIN of this value and the actual word certainty.
float space_certainty;
/*
To deal with fuzzy spaces we need to be able to combine "words" to form
@ -590,7 +593,7 @@ class WERD_RES : public ELIST_LINK {
// Creates a WERD_CHOICE for the word using the top choices from the leading
// diagonal of the ratings matrix.
void FakeWordFromRatings();
void FakeWordFromRatings(PermuterType permuter);
// Copies the best_choice strings to the correct_text for adaption/training.
void BestChoiceToCorrectText();

View File

@ -257,13 +257,21 @@ enum OcrEngineMode {
OEM_TESSERACT_ONLY, // Run Tesseract only - fastest
OEM_CUBE_ONLY, // Run Cube only - better accuracy, but slower
OEM_TESSERACT_CUBE_COMBINED, // Run both and combine results - best accuracy
OEM_DEFAULT // Specify this mode when calling init_*(),
OEM_DEFAULT, // Specify this mode when calling init_*(),
// to indicate that any of the above modes
// should be automatically inferred from the
// variables in the language-specific config,
// command-line configs, or if not specified
// in any of the above should be set to the
// default OEM_TESSERACT_ONLY.
// OEM_LSTM_ONLY will fall back (with a warning) to OEM_TESSERACT_ONLY where
// there is no network model available. This allows use of a mix of languages,
// some of which contain a network model, and some of which do not. Since the
// tesseract model is required for the LSTM to fall back to for "difficult"
// words anyway, this seems like a reasonable approach, but leaves the danger
// of not noticing that it is using the wrong engine if the warning is
// ignored.
OEM_LSTM_ONLY, // Run just the LSTM line recognizer.
};
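A minimal sketch of requesting the new engine (the datapath and language are illustrative, assuming the three-argument TessBaseAPI::Init overload):
  tesseract::TessBaseAPI api;
  // Falls back to OEM_TESSERACT_ONLY, with a warning, if "eng" has no network.
  if (api.Init("tessdata", "eng", tesseract::OEM_LSTM_ONLY) != 0) {
    fprintf(stderr, "Could not initialize tesseract.\n");
  }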
} // namespace tesseract.

View File

@ -14,7 +14,7 @@ endif
include_HEADERS = \
basedir.h errcode.h fileerr.h genericvector.h helpers.h host.h memry.h \
ndminx.h params.h ocrclass.h platform.h serialis.h strngs.h \
tesscallback.h unichar.h unicharmap.h unicharset.h
tesscallback.h unichar.h unicharcompress.h unicharmap.h unicharset.h
noinst_HEADERS = \
ambigs.h bits16.h bitvector.h ccutil.h clst.h doubleptr.h elst2.h \
@ -38,7 +38,7 @@ libtesseract_ccutil_la_SOURCES = \
mainblk.cpp memry.cpp \
serialis.cpp strngs.cpp scanutils.cpp \
tessdatamanager.cpp tprintf.cpp \
unichar.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \
unichar.cpp unicharcompress.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \
params.cpp universalambigs.cpp
if T_WIN

View File

@ -108,6 +108,8 @@ class GenericHeap {
const Pair& PeekTop() const {
return heap_[0];
}
// Get the value of the worst (largest, as defined by operator<) element.
const Pair& PeekWorst() const { return heap_[IndexOfWorst()]; }
// Removes the top element of the heap. If entry is not NULL, the element
// is copied into *entry, otherwise it is discarded.
@ -136,22 +138,12 @@ class GenericHeap {
// not NULL, the element is copied into *entry, otherwise it is discarded.
// Time = O(n). Returns false if the heap was already empty.
bool PopWorst(Pair* entry) {
int heap_size = heap_.size();
if (heap_size == 0) return false; // It cannot be empty!
// Find the maximum element. Its index is guaranteed to be greater than
// the index of the parent of the last element, since by the heap invariant
// the parent must be less than or equal to the children.
int worst_index = heap_size - 1;
int end_parent = ParentNode(worst_index);
for (int i = worst_index - 1; i > end_parent; --i) {
if (heap_[worst_index] < heap_[i])
worst_index = i;
}
int worst_index = IndexOfWorst();
if (worst_index < 0) return false; // It cannot be empty!
// Extract the worst element from the heap, leaving a hole at worst_index.
if (entry != NULL)
*entry = heap_[worst_index];
--heap_size;
int heap_size = heap_.size() - 1;
if (heap_size > 0) {
// Sift the hole upwards to match the last element of the heap_
Pair hole_pair = heap_[heap_size];
@ -162,6 +154,22 @@ class GenericHeap {
return true;
}
// Returns the index of the worst element. Time = O(n/2).
int IndexOfWorst() const {
int heap_size = heap_.size();
if (heap_size == 0) return -1; // It cannot be empty!
// Find the maximum element. Its index is guaranteed to be greater than
// the index of the parent of the last element, since by the heap invariant
// the parent must be less than or equal to the children.
int worst_index = heap_size - 1;
int end_parent = ParentNode(worst_index);
for (int i = worst_index - 1; i > end_parent; --i) {
if (heap_[worst_index] < heap_[i]) worst_index = i;
}
return worst_index;
}
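To see why the scan can start just past the last element's parent, a small example with assumed contents:
  // heap_ = [1, 3, 2, 7, 5, 4] is a valid min-heap; the last index is 5 and
  // ParentNode(5) == 2, so only indices 3..5 can hold the maximum.
  // Scanning them finds the worst element, 7, at index 3.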
// The pointed-to Pair has changed its key value, so the location of pair
// is reshuffled to maintain the heap invariant.
// Must be a valid pointer to an element of the heap_!

View File

@ -174,6 +174,8 @@ class GenericVector {
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, FILE* fp);
bool DeSerialize(bool swap, tesseract::TFile* fp);
// Skips the deserialization of the vector.
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
// Writes a vector of classes to the given file. Assumes the existence of
// bool T::Serialize(FILE* fp) const that returns false in case of error.
// Returns false in case of error.
@ -186,6 +188,8 @@ class GenericVector {
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerializeClasses(bool swap, FILE* fp);
bool DeSerializeClasses(bool swap, tesseract::TFile* fp);
// Calls SkipDeSerialize on the elements of the vector.
static bool SkipDeSerializeClasses(bool swap, tesseract::TFile* fp);
// Allocates a new array of double the current_size, copies over the
// information from data to the new location, deletes data and returns
@ -238,14 +242,13 @@ class GenericVector {
int binary_search(const T& target) const {
int bottom = 0;
int top = size_used_;
do {
while (top - bottom > 1) {
int middle = (bottom + top) / 2;
if (data_[middle] > target)
top = middle;
else
bottom = middle;
}
while (top - bottom > 1);
return bottom;
}
@ -361,7 +364,7 @@ inline bool LoadDataFromFile(const STRING& filename,
size_t size = ftell(fp);
fseek(fp, 0, SEEK_SET);
// Pad with a 0, just in case we treat the result as a string.
data->init_to_size((int)size + 1, 0);
data->init_to_size(static_cast<int>(size) + 1, 0);
bool result = fread(&(*data)[0], 1, size, fp) == size;
fclose(fp);
return result;
@ -556,11 +559,25 @@ class PointerVector : public GenericVector<T*> {
}
bool DeSerialize(bool swap, TFile* fp) {
inT32 reserved;
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
if (swap) Reverse32(&reserved);
if (!DeSerializeSize(swap, fp, &reserved)) return false;
GenericVector<T*>::reserve(reserved);
truncate(0);
for (int i = 0; i < reserved; ++i) {
if (!DeSerializeElement(swap, fp)) return false;
}
return true;
}
// Enables deserialization of a selection of elements. Note that in order to
// retain the integrity of the stream, the caller must call some combination
// of DeSerializeElement and DeSerializeSkip of the exact number returned in
// *size, assuming a true return.
static bool DeSerializeSize(bool swap, TFile* fp, inT32* size) {
if (fp->FRead(size, sizeof(*size), 1) != 1) return false;
if (swap) Reverse32(size);
return true;
}
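A sketch of the contract described above, mirroring DocumentData::ReCachePages (wanted() is a hypothetical predicate): each of the *size elements must be consumed by exactly one DeSerializeElement or DeSerializeSkip call:
  inT32 size;
  if (PointerVector<ImageData>::DeSerializeSize(false, fp, &size)) {
    for (int i = 0; i < size; ++i) {
      if (wanted(i)) {
        pages.DeSerializeElement(false, fp);  // Load this element.
      } else {
        PointerVector<ImageData>::DeSerializeSkip(false, fp);  // Seek past it.
      }
    }
  }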
// Reads and appends to the vector the next element of the serialization.
bool DeSerializeElement(bool swap, TFile* fp) {
inT8 non_null;
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
T* item = NULL;
@ -575,15 +592,21 @@ class PointerVector : public GenericVector<T*> {
// Null elements should keep their place in the vector.
this->push_back(NULL);
}
return true;
}
// Skips the next element of the serialization.
static bool DeSerializeSkip(bool swap, TFile* fp) {
inT8 non_null;
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
if (non_null) {
if (!T::SkipDeSerialize(swap, fp)) return false;
}
return true;
}
// Sorts the items pointed to by the members of this vector using
// t::operator<().
void sort() {
sort(&sort_ptr_cmp<T>);
}
void sort() { this->GenericVector<T*>::sort(&sort_ptr_cmp<T>); }
};
} // namespace tesseract
@ -926,6 +949,13 @@ bool GenericVector<T>::DeSerialize(bool swap, tesseract::TFile* fp) {
}
return true;
}
template <typename T>
bool GenericVector<T>::SkipDeSerialize(bool swap, tesseract::TFile* fp) {
inT32 reserved;
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
if (swap) Reverse32(&reserved);
return fp->FRead(NULL, sizeof(T), reserved) == reserved;
}
// Writes a vector of classes to the given file. Assumes the existence of
// bool T::Serialize(FILE* fp) const that returns false in case of error.
@ -976,6 +1006,16 @@ bool GenericVector<T>::DeSerializeClasses(bool swap, tesseract::TFile* fp) {
}
return true;
}
template <typename T>
bool GenericVector<T>::SkipDeSerializeClasses(bool swap, tesseract::TFile* fp) {
uinT32 reserved;
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
if (swap) Reverse32(&reserved);
for (int i = 0; i < reserved; ++i) {
if (!T::SkipDeSerialize(swap, fp)) return false;
}
return true;
}
// This method clear the current object, then, does a shallow copy of
// its argument, and finally invalidates its argument.

View File

@ -95,7 +95,7 @@ int TFile::FRead(void* buffer, int size, int count) {
char* char_buffer = reinterpret_cast<char*>(buffer);
if (data_->size() - offset_ < required_size)
required_size = data_->size() - offset_;
if (required_size > 0)
if (required_size > 0 && char_buffer != NULL)
memcpy(char_buffer, &(*data_)[offset_], required_size);
offset_ += required_size;
return required_size / size;
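With a NULL buffer, FRead now acts as a pure seek: offset_ advances but nothing is copied. The new SkipDeSerialize methods in STRING and GenericVector depend on exactly this behavior, e.g.:
  fp->FRead(NULL, sizeof(inT32), 1);  // Skip one serialized inT32 in place.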

View File

@ -181,6 +181,14 @@ bool STRING::DeSerialize(bool swap, TFile* fp) {
return true;
}
// As DeSerialize, but only seeks past the data - hence a static method.
bool STRING::SkipDeSerialize(bool swap, tesseract::TFile* fp) {
inT32 len;
if (fp->FRead(&len, sizeof(len), 1) != 1) return false;
if (swap) ReverseN(&len, sizeof(len));
return fp->FRead(NULL, 1, len) == len;
}
BOOL8 STRING::contains(const char c) const {
return (c != '\0') && (strchr (GetCStr(), c) != NULL);
}

View File

@ -60,6 +60,8 @@ class TESS_API STRING
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, tesseract::TFile* fp);
// As DeSerialize, but only seeks past the data - hence a static method.
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
BOOL8 contains(const char c) const;
inT32 length() const;

View File

@ -47,6 +47,10 @@ static const char kShapeTableFileSuffix[] = "shapetable";
static const char kBigramDawgFileSuffix[] = "bigram-dawg";
static const char kUnambigDawgFileSuffix[] = "unambig-dawg";
static const char kParamsModelFileSuffix[] = "params-model";
static const char kLSTMModelFileSuffix[] = "lstm";
static const char kLSTMPuncDawgFileSuffix[] = "lstm-punc-dawg";
static const char kLSTMSystemDawgFileSuffix[] = "lstm-word-dawg";
static const char kLSTMNumberDawgFileSuffix[] = "lstm-number-dawg";
namespace tesseract {
@ -68,6 +72,10 @@ enum TessdataType {
TESSDATA_BIGRAM_DAWG, // 14
TESSDATA_UNAMBIG_DAWG, // 15
TESSDATA_PARAMS_MODEL, // 16
TESSDATA_LSTM, // 17
TESSDATA_LSTM_PUNC_DAWG, // 18
TESSDATA_LSTM_SYSTEM_DAWG, // 19
TESSDATA_LSTM_NUMBER_DAWG, // 20
TESSDATA_NUM_ENTRIES
};
@ -76,7 +84,7 @@ enum TessdataType {
* kTessdataFileSuffixes[i] indicates the file suffix for
* tessdata of type i (from TessdataType enum).
*/
static const char * const kTessdataFileSuffixes[] = {
static const char *const kTessdataFileSuffixes[] = {
kLangConfigFileSuffix, // 0
kUnicharsetFileSuffix, // 1
kAmbigsFileSuffix, // 2
@ -94,6 +102,10 @@ static const char * const kTessdataFileSuffixes[] = {
kBigramDawgFileSuffix, // 14
kUnambigDawgFileSuffix, // 15
kParamsModelFileSuffix, // 16
kLSTMModelFileSuffix, // 17
kLSTMPuncDawgFileSuffix, // 18
kLSTMSystemDawgFileSuffix, // 19
kLSTMNumberDawgFileSuffix, // 20
};
/**
@ -118,6 +130,10 @@ static const bool kTessdataFileIsText[] = {
false, // 14
false, // 15
true, // 16
false, // 17
false, // 18
false, // 19
false, // 20
};
/**

439
ccutil/unicharcompress.cpp Normal file
View File

@ -0,0 +1,439 @@
///////////////////////////////////////////////////////////////////////
// File: unicharcompress.cpp
// Description: Unicode re-encoding using a sequence of smaller numbers in
// place of a single large code for CJK, similarly for Indic,
// and dissection of ligatures for other scripts.
// Author: Ray Smith
// Created: Wed Mar 04 14:45:01 PST 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#include "unicharcompress.h"
#include "tprintf.h"
namespace tesseract {
// String used to represent the null_id in direct_set.
const char* kNullChar = "<nul>";
// Local struct used only for processing the radical-stroke table.
struct RadicalStroke {
RadicalStroke() : num_strokes(0) {}
RadicalStroke(const STRING& r, int s) : radical(r), num_strokes(s) {}
bool operator==(const RadicalStroke& other) const {
return radical == other.radical && num_strokes == other.num_strokes;
}
// The radical is encoded as a string because its format is of an int with
// an optional ' mark to indicate a simplified shape. To treat these as
// distinct, we use a string and a UNICHARSET to do the integer mapping.
STRING radical;
// The number of strokes we treat as dense and just take the face value from
// the table.
int num_strokes;
};
// Hash functor for RadicalStroke.
struct RadicalStrokedHash {
size_t operator()(const RadicalStroke& rs) const {
size_t result = rs.num_strokes;
for (int i = 0; i < rs.radical.length(); ++i) {
result ^= rs.radical[i] << (6 * i + 8);
}
return result;
}
};
// A hash map to convert unicodes to radical,stroke pair.
typedef TessHashMap<int, RadicalStroke> RSMap;
// A hash map to count occurrences of each radical,stroke pair.
typedef TessHashMap<RadicalStroke, int, RadicalStrokedHash> RSCounts;
// Helper function builds the RSMap from the radical-stroke file, which has
// already been read into a STRING. Returns false on error.
// The radical_stroke_table is non-const because it gets split and the caller
// is unlikely to want to use it again.
static bool DecodeRadicalStrokeTable(STRING* radical_stroke_table,
RSMap* radical_map) {
GenericVector<STRING> lines;
radical_stroke_table->split('\n', &lines);
for (int i = 0; i < lines.size(); ++i) {
if (lines[i].length() == 0 || lines[i][0] == '#') continue;
int unicode, radical, strokes;
STRING str_radical;
if (sscanf(lines[i].string(), "%x\t%d.%d", &unicode, &radical, &strokes) ==
3) {
str_radical.add_str_int("", radical);
} else if (sscanf(lines[i].string(), "%x\t%d'.%d", &unicode, &radical,
&strokes) == 3) {
str_radical.add_str_int("'", radical);
} else {
tprintf("Invalid format in radical stroke table at line %d: %s\n", i,
lines[i].string());
return false;
}
(*radical_map)[unicode] = RadicalStroke(str_radical, strokes);
}
return true;
}
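Illustrative input lines matching the two sscanf formats above (values made up, not quoted from the real radical-stroke.txt):
  // 4E00\t1.0   -> unicode 0x4E00, radical "1", 0 additional strokes
  // 4E9C\t7'.5  -> unicode 0x4E9C, simplified radical "7'", 5 strokes
  // Blank lines and lines starting with '#' are skipped.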
UnicharCompress::UnicharCompress() : code_range_(0) {}
UnicharCompress::UnicharCompress(const UnicharCompress& src) { *this = src; }
UnicharCompress::~UnicharCompress() { Cleanup(); }
UnicharCompress& UnicharCompress::operator=(const UnicharCompress& src) {
Cleanup();
encoder_ = src.encoder_;
code_range_ = src.code_range_;
SetupDecoder();
return *this;
}
// Computes the encoding for the given unicharset. It is a requirement that
// the file training/langdata/radical-stroke.txt have been read into the
// input string radical_stroke_table.
// Returns false if the encoding cannot be constructed.
bool UnicharCompress::ComputeEncoding(const UNICHARSET& unicharset, int null_id,
STRING* radical_stroke_table) {
RSMap radical_map;
if (!DecodeRadicalStrokeTable(radical_stroke_table, &radical_map))
return false;
encoder_.clear();
UNICHARSET direct_set;
UNICHARSET radicals;
// To avoid unused codes, clear the special codes from the unicharsets.
direct_set.clear();
radicals.clear();
// Always keep space as 0.
direct_set.unichar_insert(" ");
// Null char is next if we have one.
if (null_id >= 0) {
direct_set.unichar_insert(kNullChar);
}
RSCounts radical_counts;
// In the initial map, codes [0, unicharset.size()) are
// reserved for non-han/hangul sequences of 1 or more unicodes.
int hangul_offset = unicharset.size();
// Hangul takes the next range [hangul_offset, hangul_offset + kTotalJamos).
const int kTotalJamos = kLCount + kVCount + kTCount;
// Han takes the codes beyond hangul_offset + kTotalJamos. Since it is hard
// to measure the number of radicals and strokes, initially we use the same
// code range for all 3 Han code positions, and fix them after.
int han_offset = hangul_offset + kTotalJamos;
int max_num_strokes = -1;
for (int u = 0; u <= unicharset.size(); ++u) {
bool self_normalized = false;
// We special-case allow null_id to be equal to unicharset.size() in case
// there is no space in unicharset for it.
if (u == unicharset.size()) {
if (u == null_id) {
self_normalized = true;
} else {
break; // Finished.
}
} else {
self_normalized = strcmp(unicharset.id_to_unichar(u),
unicharset.get_normed_unichar(u)) == 0;
}
RecodedCharID code;
// Convert to unicodes.
GenericVector<int> unicodes;
if (u < unicharset.size() &&
UNICHAR::UTF8ToUnicode(unicharset.get_normed_unichar(u), &unicodes) &&
unicodes.size() == 1) {
// Check single unicodes for Hangul/Han and encode if so.
int unicode = unicodes[0];
int leading, vowel, trailing;
auto it = radical_map.find(unicode);
if (it != radical_map.end()) {
// This is Han. Convert to radical, stroke, index.
if (!radicals.contains_unichar(it->second.radical.string())) {
radicals.unichar_insert(it->second.radical.string());
}
int radical = radicals.unichar_to_id(it->second.radical.string());
int num_strokes = it->second.num_strokes;
int num_samples = radical_counts[it->second]++;
if (num_strokes > max_num_strokes) max_num_strokes = num_strokes;
code.Set3(radical + han_offset, num_strokes + han_offset,
num_samples + han_offset);
} else if (DecomposeHangul(unicode, &leading, &vowel, &trailing)) {
// This is Hangul. Since we know the exact size of each part at compile
// time, it gets the bottom set of codes.
code.Set3(leading + hangul_offset, vowel + kLCount + hangul_offset,
trailing + kLCount + kVCount + hangul_offset);
}
}
// If the code is still empty, it wasn't Han or Hangul.
if (code.length() == 0) {
// Special cases.
if (u == UNICHAR_SPACE) {
code.Set(0, 0); // Space.
} else if (u == null_id || (unicharset.has_special_codes() &&
u < SPECIAL_UNICHAR_CODES_COUNT)) {
code.Set(0, direct_set.unichar_to_id(kNullChar));
} else {
// Add the direct_set unichar-ids of the unicodes in sequence to the
// code.
for (int i = 0; i < unicodes.size(); ++i) {
int position = code.length();
if (position >= RecodedCharID::kMaxCodeLen) {
tprintf("Unichar %d=%s->%s is too long to encode!!\n", u,
unicharset.id_to_unichar(u),
unicharset.get_normed_unichar(u));
return false;
}
int uni = unicodes[i];
UNICHAR unichar(uni);
char* utf8 = unichar.utf8_str();
if (!direct_set.contains_unichar(utf8))
direct_set.unichar_insert(utf8);
code.Set(position, direct_set.unichar_to_id(utf8));
delete[] utf8;
if (direct_set.size() > unicharset.size()) {
// Code space got bigger!
tprintf("Code space expanded from original unicharset!!\n");
return false;
}
}
}
}
code.set_self_normalized(self_normalized);
encoder_.push_back(code);
}
// Now renumber Han to make all codes unique. We already added han_offset to
// all Han. Now separate out the radical, stroke, and count codes for Han.
// In the uniqued Han encoding, the 1st code uses the next radical_map.size()
// values, the 2nd code uses the next max_num_strokes+1 values, and the 3rd
// code uses the rest for the max number of duplicated radical/stroke combos.
int num_radicals = radicals.size();
for (int u = 0; u < unicharset.size(); ++u) {
RecodedCharID* code = &encoder_[u];
if ((*code)(0) >= han_offset) {
code->Set(1, (*code)(1) + num_radicals);
code->Set(2, (*code)(2) + num_radicals + max_num_strokes + 1);
}
}
DefragmentCodeValues(null_id >= 0 ? 1 : -1);
SetupDecoder();
return true;
}
// Sets up an encoder that doesn't change the unichars at all, so it just
// passes them through unchanged.
void UnicharCompress::SetupPassThrough(const UNICHARSET& unicharset) {
GenericVector<RecodedCharID> codes;
for (int u = 0; u < unicharset.size(); ++u) {
RecodedCharID code;
code.Set(0, u);
codes.push_back(code);
}
SetupDirect(codes);
}
// Sets up an encoder directly using the given encoding vector, which maps
// unichar_ids to the given codes.
void UnicharCompress::SetupDirect(const GenericVector<RecodedCharID>& codes) {
encoder_ = codes;
ComputeCodeRange();
SetupDecoder();
}
// Renumbers codes to eliminate unused values.
void UnicharCompress::DefragmentCodeValues(int encoded_null) {
// There may not be any Hangul, but even if there is, it is possible that not
// all codes are used. Likewise with the Han encoding, it is possible that not
// all numbers of strokes are used.
ComputeCodeRange();
GenericVector<int> offsets;
offsets.init_to_size(code_range_, 0);
// Find which codes are used
for (int c = 0; c < encoder_.size(); ++c) {
const RecodedCharID& code = encoder_[c];
for (int i = 0; i < code.length(); ++i) {
offsets[code(i)] = 1;
}
}
// Compute offsets based on code use.
int offset = 0;
for (int i = 0; i < offsets.size(); ++i) {
// If not used, decrement everything above here.
// We are moving encoded_null to the end, so it is not "used".
if (offsets[i] == 0 || i == encoded_null) {
--offset;
} else {
offsets[i] = offset;
}
}
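// Worked example (illustrative): if only values {0, 1, 3} appear in any
// code and there is no encoded_null, the loop above leaves
// offsets == {0, 0, 0, -1}, so the pass below renumbers 3 to 2, giving the
// compact code range [0, 3).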
if (encoded_null >= 0) {
// The encoded_null is moving to the end, for the benefit of TensorFlow;
// its new value will be offsets.size() + offsets.back().
offsets[encoded_null] = offsets.size() + offsets.back() - encoded_null;
}
// Now apply the offsets.
for (int c = 0; c < encoder_.size(); ++c) {
RecodedCharID* code = &encoder_[c];
for (int i = 0; i < code->length(); ++i) {
int value = (*code)(i);
code->Set(i, value + offsets[value]);
}
}
ComputeCodeRange();
}
// Encodes a single unichar_id. Returns the length of the code, or zero if
// the input is invalid, and the encoding itself in code.
int UnicharCompress::EncodeUnichar(int unichar_id, RecodedCharID* code) const {
if (unichar_id < 0 || unichar_id >= encoder_.size()) return 0;
*code = encoder_[unichar_id];
return code->length();
}
// Decodes code, returning the original unichar-id, or
// INVALID_UNICHAR_ID if the input is invalid.
int UnicharCompress::DecodeUnichar(const RecodedCharID& code) const {
int len = code.length();
if (len <= 0 || len > RecodedCharID::kMaxCodeLen) return INVALID_UNICHAR_ID;
auto it = decoder_.find(code);
if (it == decoder_.end()) return INVALID_UNICHAR_ID;
return it->second;
}
// Writes to the given file. Returns false in case of error.
bool UnicharCompress::Serialize(TFile* fp) const {
return encoder_.SerializeClasses(fp);
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool UnicharCompress::DeSerialize(bool swap, TFile* fp) {
if (!encoder_.DeSerializeClasses(swap, fp)) return false;
ComputeCodeRange();
SetupDecoder();
return true;
}
// Returns a STRING containing a text file that describes the encoding thus:
// <index>[,<index>]*<tab><UTF8-str><newline>
// In words, a comma-separated list of one or more indices, followed by a tab
// and the UTF-8 string that the code represents, one mapping per line. Most
// simple scripts
// will encode a single index to a UTF8-string, but Chinese, Japanese, Korean
// and the Indic scripts will contain a many-to-many mapping.
// See the class comment above for details.
STRING UnicharCompress::GetEncodingAsString(
const UNICHARSET& unicharset) const {
STRING encoding;
for (int c = 0; c < encoder_.size(); ++c) {
const RecodedCharID& code = encoder_[c];
if (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT && code == encoder_[c - 1]) {
// Don't show the duplicate entry.
continue;
}
encoding.add_str_int("", code(0));
for (int i = 1; i < code.length(); ++i) {
encoding.add_str_int(",", code(i));
}
encoding += "\t";
if (c >= unicharset.size() || (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT &&
unicharset.has_special_codes())) {
encoding += kNullChar;
} else {
encoding += unicharset.id_to_unichar(c);
}
encoding += "\n";
}
return encoding;
}
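// Example output from GetEncodingAsString (illustrative values only):
//   12<tab>a
//   7,93,7<tab>ff
// i.e. the letter a gets the single code 12, while the ff ligature is the
// 3-code sequence <f> <null> <f>.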
// Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing.
// Note that the returned values are 0-based indices, NOT unicode Jamo.
// Returns false if the input is not in the Hangul unicode range.
/* static */
bool UnicharCompress::DecomposeHangul(int unicode, int* leading, int* vowel,
int* trailing) {
if (unicode < kFirstHangul) return false;
int offset = unicode - kFirstHangul;
if (offset >= kNumHangul) return false;
const int kNCount = kVCount * kTCount;
*leading = offset / kNCount;
*vowel = (offset % kNCount) / kTCount;
*trailing = offset % kTCount;
return true;
}
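// Worked example for DecomposeHangul: U+AC01 (HANGUL SYLLABLE GAG) gives
// offset = 1 and kNCount = 21 * 28 = 588, so leading = 1 / 588 = 0,
// vowel = (1 % 588) / 28 = 0 and trailing = 1 % 28 = 1, all 0-based indices
// rather than Jamo codepoints.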
// Computes the value of code_range_ from the encoder_.
void UnicharCompress::ComputeCodeRange() {
code_range_ = -1;
for (int c = 0; c < encoder_.size(); ++c) {
const RecodedCharID& code = encoder_[c];
for (int i = 0; i < code.length(); ++i) {
if (code(i) > code_range_) code_range_ = code(i);
}
}
++code_range_;
}
// Initializes the decoding hash_map from the encoding array.
void UnicharCompress::SetupDecoder() {
Cleanup();
is_valid_start_.init_to_size(code_range_, false);
for (int c = 0; c < encoder_.size(); ++c) {
const RecodedCharID& code = encoder_[c];
if (code.self_normalized() || decoder_.find(code) == decoder_.end())
decoder_[code] = c;
is_valid_start_[code(0)] = true;
RecodedCharID prefix = code;
int len = code.length() - 1;
prefix.Truncate(len);
auto final_it = final_codes_.find(prefix);
if (final_it == final_codes_.end()) {
GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>;
code_list->push_back(code(len));
final_codes_[prefix] = code_list;
while (--len >= 0) {
prefix.Truncate(len);
auto next_it = next_codes_.find(prefix);
if (next_it == next_codes_.end()) {
GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>;
code_list->push_back(code(len));
next_codes_[prefix] = code_list;
} else {
// We still have to search the list as we may get here via multiple
// lengths of code.
if (!next_it->second->contains(code(len)))
next_it->second->push_back(code(len));
break; // This prefix has been processed.
}
}
} else {
if (!final_it->second->contains(code(len)))
final_it->second->push_back(code(len));
}
}
}
// Frees allocated memory.
void UnicharCompress::Cleanup() {
decoder_.clear();
is_valid_start_.clear();
for (auto it = next_codes_.begin(); it != next_codes_.end(); ++it) {
delete it->second;
}
for (auto it = final_codes_.begin(); it != final_codes_.end(); ++it) {
delete it->second;
}
next_codes_.clear();
final_codes_.clear();
}
} // namespace tesseract.

258
ccutil/unicharcompress.h Normal file
View File

@ -0,0 +1,258 @@
///////////////////////////////////////////////////////////////////////
// File: unicharcompress.h
// Description: Unicode re-encoding using a sequence of smaller numbers in
// place of a single large code for CJK, similarly for Indic,
// and dissection of ligatures for other scripts.
// Author: Ray Smith
// Created: Wed Mar 04 14:45:01 PST 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
#define TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
#include "hashfn.h"
#include "serialis.h"
#include "strngs.h"
#include "unicharset.h"
namespace tesseract {
// Trivial class to hold the code for a recoded unichar-id.
class RecodedCharID {
public:
// The maximum length of a code.
static const int kMaxCodeLen = 9;
RecodedCharID() : self_normalized_(0), length_(0) {
memset(code_, 0, sizeof(code_));
}
void Truncate(int length) { length_ = length; }
// Sets the code value at the given index in the code.
void Set(int index, int value) {
code_[index] = value;
if (length_ <= index) length_ = index + 1;
}
// Shorthand for setting codes of length 3, as all Hangul and Han codes are
// length 3.
void Set3(int code0, int code1, int code2) {
length_ = 3;
code_[0] = code0;
code_[1] = code1;
code_[2] = code2;
}
// Accessors
bool self_normalized() const { return self_normalized_ != 0; }
void set_self_normalized(bool value) { self_normalized_ = value; }
int length() const { return length_; }
int operator()(int index) const { return code_[index]; }
// Writes to the given file. Returns false in case of error.
bool Serialize(TFile* fp) const {
if (fp->FWrite(&self_normalized_, sizeof(self_normalized_), 1) != 1)
return false;
if (fp->FWrite(&length_, sizeof(length_), 1) != 1) return false;
if (fp->FWrite(code_, sizeof(code_[0]), length_) != length_) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, TFile* fp) {
if (fp->FRead(&self_normalized_, sizeof(self_normalized_), 1) != 1)
return false;
if (fp->FRead(&length_, sizeof(length_), 1) != 1) return false;
if (swap) ReverseN(&length_, sizeof(length_));
if (fp->FRead(code_, sizeof(code_[0]), length_) != length_) return false;
if (swap) {
for (int i = 0; i < length_; ++i) {
ReverseN(&code_[i], sizeof(code_[i]));
}
}
return true;
}
bool operator==(const RecodedCharID& other) const {
if (length_ != other.length_) return false;
for (int i = 0; i < length_; ++i) {
if (code_[i] != other.code_[i]) return false;
}
return true;
}
// Hash functor for RecodedCharID.
struct RecodedCharIDHash {
size_t operator()(const RecodedCharID& code) const {
size_t result = 0;
for (int i = 0; i < code.length_; ++i) {
result ^= code(i) << (7 * i);
}
return result;
}
};
private:
// True if this code is self-normalizing, i.e. is the master entry for indices
// that map to the same code. Has boolean value, but inT8 for serialization.
inT8 self_normalized_;
// The number of elements in use in code_.
inT32 length_;
// The re-encoded form of the unichar-id to which this RecodedCharID relates.
inT32 code_[kMaxCodeLen];
};
// Class holds a "compression" of a unicharset to simplify the learning problem
// for a neural-network-based classifier.
// Objectives:
// 1 (CJK): Ids of a unicharset with a large number of classes are expressed as
// a sequence of 3 codes with much fewer values.
// This is achieved using the Jamo coding for Hangul and the Unicode
// Radical-Stroke-index for Han.
// 2 (Indic): Instead of thousands of codes with one for each grapheme, re-code
// as the unicode sequence (but coded in a more compact space).
// 3 (the rest): Eliminate multi-path problems with ligatures and fold confusing
// and not significantly distinct shapes (quotes) together, i.e.
// represent the fi ligature as the f-i pair, and fold u+2019 and
// friends all onto ascii single '
// 4 The null character and mapping to target activations:
// To save horizontal coding space, the compressed codes are generally mapped
// to target network activations without intervening null characters, BUT
// in the case of ligatures, such as ff, null characters have to be included
// so existence of repeated codes is detected at codebook-building time, and
// null characters are embedded directly into the codes, so the rest of the
// system doesn't need to worry about the problem (much). There is still an
// effect on the range of ways in which the target activations can be
// generated.
//
// The computed code values are compact (no unused values), and, for CJK,
// unique (each code position uses a disjoint set of values from each other code
// position). For non-CJK, the same code value CAN be used in multiple
// positions, e.g. the ff ligature is converted to <f> <nullchar> <f>, where <f>
// is the same code as is used for the single f.
// NOTE that an intended consequence of using the normalized text from the
// unicharset is that the fancy quotes all map to a single code, so round-trip
// conversion doesn't work for all unichar-ids.
class UnicharCompress {
public:
UnicharCompress();
UnicharCompress(const UnicharCompress& src);
~UnicharCompress();
UnicharCompress& operator=(const UnicharCompress& src);
// The 1st Hangul unicode.
static const int kFirstHangul = 0xac00;
// The number of Hangul unicodes.
static const int kNumHangul = 11172;
// The number of Jamos for each of the 3 parts of a Hangul character, being
// the Leading consonant, Vowel and Trailing consonant.
static const int kLCount = 19;
static const int kVCount = 21;
static const int kTCount = 28;
// Computes the encoding for the given unicharset. It is a requirement that
// the file training/langdata/radical-stroke.txt have been read into the
// input string radical_stroke_table.
// Returns false if the encoding cannot be constructed.
bool ComputeEncoding(const UNICHARSET& unicharset, int null_id,
STRING* radical_stroke_table);
// Sets up an encoder that doesn't change the unichars at all, so it just
// passes them through unchanged.
void SetupPassThrough(const UNICHARSET& unicharset);
// Sets up an encoder directly using the given encoding vector, which maps
// unichar_ids to the given codes.
void SetupDirect(const GenericVector<RecodedCharID>& codes);
// Returns the number of different values that can be used in a code, i.e.
// 1 + the maximum value that will ever be used by a RecodedCharID code in
// any position in its array.
int code_range() const { return code_range_; }
// Encodes a single unichar_id. Returns the length of the code (or zero if
// the input is invalid), and the encoding itself in code.
int EncodeUnichar(int unichar_id, RecodedCharID* code) const;
// Decodes code, returning the original unichar-id, or
// INVALID_UNICHAR_ID if the input is invalid. Note that this is not a perfect
// inverse of EncodeUnichar, since the unichar-id of U+2019 (curly single
// quote), for example, will have the same encoding as the unichar-id of
// U+0027 (ascii '). The foldings are obtained from the input unicharset,
// which in turn obtains them from NormalizeUTF8String in normstrngs.cpp,
// and include NFKC normalization plus others like quote and dash folding.
int DecodeUnichar(const RecodedCharID& code) const;
// Returns true if the given code is a valid start or single code.
bool IsValidFirstCode(int code) const { return is_valid_start_[code]; }
// Returns a list of valid non-final next codes for a given prefix code,
// which may be empty.
const GenericVector<int>* GetNextCodes(const RecodedCharID& code) const {
auto it = next_codes_.find(code);
return it == next_codes_.end() ? NULL : it->second;
}
// Returns a list of valid final codes for a given prefix code, which may
// be empty.
const GenericVector<int>* GetFinalCodes(const RecodedCharID& code) const {
auto it = final_codes_.find(code);
return it == final_codes_.end() ? NULL : it->second;
}
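// For example (illustrative): if the ff ligature is coded as
// <f> <null> <f>, GetNextCodes on the prefix (<f>) includes <null>, and
// GetFinalCodes on the prefix (<f>, <null>) includes <f>.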
// Writes to the given file. Returns false in case of error.
bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, TFile* fp);
// Returns a STRING containing a text file that describes the encoding thus:
// <index>[,<index>]*<tab><UTF8-str><newline>
// In words, a comma-separated list of one or more indices, followed by a tab
// and the UTF-8 string that the code represents, one mapping per line. Most
// simple scripts
// will encode a single index to a UTF8-string, but Chinese, Japanese, Korean
// and the Indic scripts will contain a many-to-many mapping.
// See the class comment above for details.
STRING GetEncodingAsString(const UNICHARSET& unicharset) const;
// Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing.
// Note that the returned values are 0-based indices, NOT unicode Jamo.
// Returns false if the input is not in the Hangul unicode range.
static bool DecomposeHangul(int unicode, int* leading, int* vowel,
int* trailing);
private:
// Renumbers codes to eliminate unused values.
void DefragmentCodeValues(int encoded_null);
// Computes the value of code_range_ from the encoder_.
void ComputeCodeRange();
// Initializes the decoding hash_map from the encoder_ array.
void SetupDecoder();
// Frees allocated memory.
void Cleanup();
// The encoder that maps a unichar-id to a sequence of small codes.
// encoder_ is the only part that is serialized. The rest is computed on load.
GenericVector<RecodedCharID> encoder_;
// Decoder converts the output of encoder back to a unichar-id.
TessHashMap<RecodedCharID, int, RecodedCharID::RecodedCharIDHash> decoder_;
// True if the index is a valid single or start code.
GenericVector<bool> is_valid_start_;
// Maps a prefix code to a list of valid next codes.
// The map owns the vectors.
TessHashMap<RecodedCharID, GenericVectorEqEq<int>*,
RecodedCharID::RecodedCharIDHash>
next_codes_;
// Maps a prefix code to a list of valid final codes.
// The map owns the vectors.
TessHashMap<RecodedCharID, GenericVectorEqEq<int>*,
RecodedCharID::RecodedCharIDHash>
final_codes_;
// Max of any value in encoder_ + 1.
int code_range_;
};
} // namespace tesseract.
#endif // TESSERACT_CCUTIL_UNICHARCOMPRESS_H_

View File

@ -906,6 +906,8 @@ void UNICHARSET::post_load_setup() {
han_sid_ = get_script_id_from_name("Han");
hiragana_sid_ = get_script_id_from_name("Hiragana");
katakana_sid_ = get_script_id_from_name("Katakana");
thai_sid_ = get_script_id_from_name("Thai");
hangul_sid_ = get_script_id_from_name("Hangul");
// Compute default script. Use the highest-counting alpha script, that is
// not the common script, as that still contains some "alphas".

View File

@ -290,6 +290,8 @@ class UNICHARSET {
han_sid_ = 0;
hiragana_sid_ = 0;
katakana_sid_ = 0;
thai_sid_ = 0;
hangul_sid_ = 0;
}
// Return the size of the set (the number of different UNICHAR it holds).
@ -604,6 +606,16 @@ class UNICHARSET {
return unichars[unichar_id].properties.AnyRangeEmpty();
}
// Returns true if the script of the given id is space delimited.
// Returns false for Han, Thai, Hangul, Hiragana and Katakana scripts.
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const {
if (INVALID_UNICHAR_ID == unichar_id) return true;
int script_id = get_script(unichar_id);
return script_id != han_sid_ && script_id != thai_sid_ &&
script_id != hangul_sid_ && script_id != hiragana_sid_ &&
script_id != katakana_sid_;
}
// Return the script name of the given unichar.
// The returned pointer will always be the same for the same script, it's
// managed by unicharset and thus MUST NOT be deleted
@ -773,7 +785,7 @@ class UNICHARSET {
// Returns normalized version of unichar with the given unichar_id.
const char *get_normed_unichar(UNICHAR_ID unichar_id) const {
if (unichar_id == UNICHAR_SPACE && has_special_codes()) return " ";
if (unichar_id == UNICHAR_SPACE) return " ";
return unichars[unichar_id].properties.normed.string();
}
// Returns a vector of UNICHAR_IDs that represent the ids of the normalized
@ -835,6 +847,8 @@ class UNICHARSET {
int han_sid() const { return han_sid_; }
int hiragana_sid() const { return hiragana_sid_; }
int katakana_sid() const { return katakana_sid_; }
int thai_sid() const { return thai_sid_; }
int hangul_sid() const { return hangul_sid_; }
int default_sid() const { return default_sid_; }
// Returns true if the unicharset has the concept of upper/lower case.
@ -977,6 +991,8 @@ class UNICHARSET {
int han_sid_;
int hiragana_sid_;
int katakana_sid_;
int thai_sid_;
int hangul_sid_;
// The most frequently occurring script in the charset.
int default_sid_;
};

View File

@ -6,7 +6,7 @@
# Initialization
# ----------------------------------------
AC_PREREQ([2.50])
AC_INIT([tesseract], [3.05.00dev], [https://github.com/tesseract-ocr/tesseract/issues])
AC_INIT([tesseract], [4.00.00dev], [https://github.com/tesseract-ocr/tesseract/issues])
AC_PROG_CXX([g++ clang++])
AC_LANG([C++])
AC_LANG_COMPILER_REQUIRE
@ -18,8 +18,8 @@ AC_PREFIX_DEFAULT([/usr/local])
# Define date of package, etc. Could be useful in auto-generated
# documentation.
PACKAGE_YEAR=2015
PACKAGE_DATE="07/11"
PACKAGE_YEAR=2016
PACKAGE_DATE="11/11"
abs_top_srcdir=`AS_DIRNAME([$0])`
gitrev="`git --git-dir=${abs_top_srcdir}/.git --work-tree=${abs_top_srcdir} describe --always --tags`"
@ -42,8 +42,8 @@ AC_SUBST([PACKAGE_DATE])
GENERIC_LIBRARY_NAME=tesseract
# Release versioning
GENERIC_MAJOR_VERSION=3
GENERIC_MINOR_VERSION=4
GENERIC_MAJOR_VERSION=4
GENERIC_MINOR_VERSION=0
GENERIC_MICRO_VERSION=0
# API version (often = GENERIC_MAJOR_VERSION.GENERIC_MINOR_VERSION)
@ -520,6 +520,7 @@ fi
# Output files
AC_CONFIG_FILES([Makefile tesseract.pc])
AC_CONFIG_FILES([api/Makefile])
AC_CONFIG_FILES([arch/Makefile])
AC_CONFIG_FILES([ccmain/Makefile])
AC_CONFIG_FILES([opencl/Makefile])
AC_CONFIG_FILES([ccstruct/Makefile])
@ -528,6 +529,7 @@ AC_CONFIG_FILES([classify/Makefile])
AC_CONFIG_FILES([cube/Makefile])
AC_CONFIG_FILES([cutil/Makefile])
AC_CONFIG_FILES([dict/Makefile])
AC_CONFIG_FILES([lstm/Makefile])
AC_CONFIG_FILES([neural_networks/runtime/Makefile])
AC_CONFIG_FILES([textord/Makefile])
AC_CONFIG_FILES([viewer/Makefile])

View File

@ -401,7 +401,6 @@ LIST s_adjoin(LIST var_list, void *variable, int_compare compare) {
return (push_last (var_list, variable));
}
/**********************************************************************
* s e a r c h
*

View File

@ -69,14 +69,17 @@ Dawg *DawgLoader::Load() {
PermuterType perm_type;
switch (tessdata_dawg_type_) {
case TESSDATA_PUNC_DAWG:
case TESSDATA_LSTM_PUNC_DAWG:
dawg_type = DAWG_TYPE_PUNCTUATION;
perm_type = PUNC_PERM;
break;
case TESSDATA_SYSTEM_DAWG:
case TESSDATA_LSTM_SYSTEM_DAWG:
dawg_type = DAWG_TYPE_WORD;
perm_type = SYSTEM_DAWG_PERM;
break;
case TESSDATA_NUMBER_DAWG:
case TESSDATA_LSTM_NUMBER_DAWG:
dawg_type = DAWG_TYPE_NUMBER;
perm_type = NUMBER_PERM;
break;

View File

@ -202,10 +202,8 @@ DawgCache *Dict::GlobalDawgCache() {
return &cache;
}
void Dict::Load(DawgCache *dawg_cache) {
STRING name;
STRING &lang = getCCUtil()->lang;
// Sets everything up ready for a Load or LoadLSTM.
void Dict::SetupForLoad(DawgCache *dawg_cache) {
if (dawgs_.length() != 0) this->End();
apostrophe_unichar_id_ = getUnicharset().unichar_to_id(kApostropheSymbol);
@ -220,10 +218,10 @@ void Dict::Load(DawgCache *dawg_cache) {
dawg_cache_ = new DawgCache();
dawg_cache_is_ours_ = true;
}
}
TessdataManager &tessdata_manager = getCCUtil()->tessdata_manager;
const char *data_file_name = tessdata_manager.GetDataFileName().string();
// Loads the dawgs needed by Tesseract. Call FinishLoad() after.
void Dict::Load(const char *data_file_name, const STRING &lang) {
// Load dawgs_.
if (load_punc_dawg) {
punc_dawg_ = dawg_cache_->GetSquishedDawg(
@ -255,6 +253,7 @@ void Dict::Load(DawgCache *dawg_cache) {
if (unambig_dawg_) dawgs_ += unambig_dawg_;
}
STRING name;
if (((STRING &)user_words_suffix).length() > 0 ||
((STRING &)user_words_file).length() > 0) {
Trie *trie_ptr = new Trie(DAWG_TYPE_WORD, lang, USER_DAWG_PERM,
@ -300,8 +299,33 @@ void Dict::Load(DawgCache *dawg_cache) {
// This dawg is temporary and should not be searched by letter_is_ok.
pending_words_ = new Trie(DAWG_TYPE_WORD, lang, NO_PERM,
getUnicharset().size(), dawg_debug_level);
}
// Construct a list of corresponding successors for each dawg. Each entry i
// Loads the dawgs needed by the LSTM model. Call FinishLoad() after.
void Dict::LoadLSTM(const char *data_file_name, const STRING &lang) {
// Load dawgs_.
if (load_punc_dawg) {
punc_dawg_ = dawg_cache_->GetSquishedDawg(
lang, data_file_name, TESSDATA_LSTM_PUNC_DAWG, dawg_debug_level);
if (punc_dawg_) dawgs_ += punc_dawg_;
}
if (load_system_dawg) {
Dawg *system_dawg = dawg_cache_->GetSquishedDawg(
lang, data_file_name, TESSDATA_LSTM_SYSTEM_DAWG, dawg_debug_level);
if (system_dawg) dawgs_ += system_dawg;
}
if (load_number_dawg) {
Dawg *number_dawg = dawg_cache_->GetSquishedDawg(
lang, data_file_name, TESSDATA_LSTM_NUMBER_DAWG, dawg_debug_level);
if (number_dawg) dawgs_ += number_dawg;
}
}
// Completes the loading process after Load() and/or LoadLSTM().
// Returns false if no dictionaries were loaded.
bool Dict::FinishLoad() {
if (dawgs_.empty()) return false;
// Construct a list of corresponding successors for each dawg. Each entry, i,
// in the successors_ vector is a vector of integers that represent the
// indices into the dawgs_ vector of the successors for dawg i.
successors_.reserve(dawgs_.length());
@ -316,6 +340,7 @@ void Dict::Load(DawgCache *dawg_cache) {
}
successors_ += lst;
}
return true;
}
void Dict::End() {
@ -368,6 +393,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
// Initialization.
PermuterType curr_perm = NO_PERM;
dawg_args->updated_dawgs->clear();
dawg_args->valid_end = false;
// Go over the active_dawgs vector and insert DawgPosition records
// with the updated ref (an edge with the corresponding unichar id) into
@ -405,6 +431,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
dawg_debug_level > 0,
"Append transition from punc dawg to current dawgs: ");
if (sdawg->permuter() > curr_perm) curr_perm = sdawg->permuter();
if (sdawg->end_of_word(dawg_edge) &&
punc_dawg->end_of_word(punc_transition_edge))
dawg_args->valid_end = true;
}
}
}
@ -419,6 +448,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
dawg_debug_level > 0,
"Extend punctuation dawg: ");
if (PUNC_PERM > curr_perm) curr_perm = PUNC_PERM;
if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true;
}
continue;
}
@ -436,6 +466,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
dawg_debug_level > 0,
"Return to punctuation dawg: ");
if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter();
if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true;
}
}
@ -445,8 +476,8 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
// possible edges, not only for the exact unichar_id, but also
// for all its character classes (alpha, digit, etc).
if (dawg->type() == DAWG_TYPE_PATTERN) {
ProcessPatternEdges(dawg, pos, unichar_id, word_end,
dawg_args->updated_dawgs, &curr_perm);
ProcessPatternEdges(dawg, pos, unichar_id, word_end, dawg_args,
&curr_perm);
// There can't be any successors to dawg that is of type
// DAWG_TYPE_PATTERN, so we are done examining this DawgPosition.
continue;
@ -473,6 +504,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
continue;
}
if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter();
if (dawg->end_of_word(edge) &&
(punc_dawg == NULL || punc_dawg->end_of_word(pos.punc_ref)))
dawg_args->valid_end = true;
dawg_args->updated_dawgs->add_unique(
DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref,
false),
@ -497,7 +531,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos,
UNICHAR_ID unichar_id, bool word_end,
DawgPositionVector *updated_dawgs,
DawgArgs *dawg_args,
PermuterType *curr_perm) const {
NODE_REF node = GetStartingNode(dawg, pos.dawg_ref);
// Try to find the edge corresponding to the exact unichar_id and to all the
@ -520,7 +554,8 @@ void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos,
tprintf("Letter found in pattern dawg %d\n", pos.dawg_index);
}
if (dawg->permuter() > *curr_perm) *curr_perm = dawg->permuter();
updated_dawgs->add_unique(
if (dawg->end_of_word(edge)) dawg_args->valid_end = true;
dawg_args->updated_dawgs->add_unique(
DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref,
pos.back_to_punc),
dawg_debug_level > 0,
@ -816,5 +851,13 @@ bool Dict::valid_punctuation(const WERD_CHOICE &word) {
return false;
}
/// Returns true if the language is space-delimited (not Chinese, Japanese,
/// or Thai).
bool Dict::IsSpaceDelimitedLang() const {
const UNICHARSET &u_set = getUnicharset();
if (u_set.han_sid() > 0) return false;
if (u_set.katakana_sid() > 0) return false;
if (u_set.thai_sid() > 0) return false;
return true;
}
} // namespace tesseract

View File

@ -23,7 +23,6 @@
#include "dawg.h"
#include "dawg_cache.h"
#include "host.h"
#include "oldlist.h"
#include "ratngs.h"
#include "stopper.h"
#include "trie.h"
@ -76,11 +75,13 @@ enum XHeightConsistencyEnum {XH_GOOD, XH_SUBNORMAL, XH_INCONSISTENT};
struct DawgArgs {
DawgArgs(DawgPositionVector *d, DawgPositionVector *up, PermuterType p)
: active_dawgs(d), updated_dawgs(up), permuter(p) {}
: active_dawgs(d), updated_dawgs(up), permuter(p), valid_end(false) {}
DawgPositionVector *active_dawgs;
DawgPositionVector *updated_dawgs;
PermuterType permuter;
// True if the current position is a valid word end.
bool valid_end;
};
class Dict {
@ -294,7 +295,15 @@ class Dict {
/// Initialize Dict class - load dawgs from [lang].traineddata and
/// user-specified wordlist and parttern list.
static DawgCache *GlobalDawgCache();
void Load(DawgCache *dawg_cache);
// Sets everything up ready for a Load or LoadLSTM.
void SetupForLoad(DawgCache *dawg_cache);
// Loads the dawgs needed by Tesseract. Call FinishLoad() after.
void Load(const char *data_file_name, const STRING &lang);
// Loads the dawgs needed by the LSTM model. Call FinishLoad() after.
void LoadLSTM(const char *data_file_name, const STRING &lang);
// Completes the loading process after Load() and/or LoadLSTM().
// Returns false if no dictionaries were loaded.
bool FinishLoad();
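// Typical call sequence (sketch based on the declarations above):
//   dict->SetupForLoad(Dict::GlobalDawgCache());
//   dict->Load(data_file_name, lang);  // and/or dict->LoadLSTM(...)
//   if (!dict->FinishLoad()) { /* no dictionaries were loaded */ }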
void End();
// Resets the document dictionary analogous to ResetAdaptiveClassifier.
@ -397,9 +406,7 @@ class Dict {
}
inline void SetWildcardID(UNICHAR_ID id) { wildcard_unichar_id_ = id; }
inline UNICHAR_ID WildcardID() const {
return wildcard_unichar_id_;
}
inline UNICHAR_ID WildcardID() const { return wildcard_unichar_id_; }
/// Return the number of dawgs in the dawgs_ vector.
inline int NumDawgs() const { return dawgs_.size(); }
/// Return i-th dawg pointer recorded in the dawgs_ vector.
@ -436,7 +443,7 @@ class Dict {
/// edges were found.
void ProcessPatternEdges(const Dawg *dawg, const DawgPosition &info,
UNICHAR_ID unichar_id, bool word_end,
DawgPositionVector *updated_dawgs,
DawgArgs *dawg_args,
PermuterType *current_permuter) const;
/// Read/Write/Access special purpose dawgs which contain words
@ -483,6 +490,8 @@ class Dict {
inline void SetWordsegRatingAdjustFactor(float f) {
wordseg_rating_adjust_factor_ = f;
}
/// Returns true if the language is space-delimited (not Chinese, Japanese,
/// or Thai).
bool IsSpaceDelimitedLang() const;
private:
/** Private member variables. */

39
lstm/Makefile.am Normal file
View File

@ -0,0 +1,39 @@
AM_CPPFLAGS += \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/cutil -I$(top_srcdir)/ccstruct \
-I$(top_srcdir)/arch -I$(top_srcdir)/viewer -I$(top_srcdir)/classify \
-I$(top_srcdir)/dict -I$(top_srcdir)/lstm
AUTOMAKE_OPTIONS = subdir-objects
SUBDIRS =
AM_CXXFLAGS = -fopenmp
if !NO_TESSDATA_PREFIX
AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@/
endif
if VISIBILITY
AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
AM_CPPFLAGS += -DTESS_EXPORTS
endif
include_HEADERS = \
convolve.h ctc.h fullyconnected.h functions.h input.h \
lstm.h lstmrecognizer.h lstmtrainer.h maxpool.h \
networkbuilder.h network.h networkio.h networkscratch.h \
parallel.h plumbing.h recodebeam.h reconfig.h reversed.h \
series.h static_shape.h stridemap.h tfnetwork.h weightmatrix.h
noinst_HEADERS =
if !USING_MULTIPLELIBS
noinst_LTLIBRARIES = libtesseract_lstm.la
else
lib_LTLIBRARIES = libtesseract_lstm.la
libtesseract_lstm_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
endif
libtesseract_lstm_la_SOURCES = \
convolve.cpp ctc.cpp fullyconnected.cpp functions.cpp input.cpp \
lstm.cpp lstmrecognizer.cpp lstmtrainer.cpp maxpool.cpp \
networkbuilder.cpp network.cpp networkio.cpp \
parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \
series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp

124
lstm/convolve.cpp Normal file
View File

@ -0,0 +1,124 @@
///////////////////////////////////////////////////////////////////////
// File: convolve.cpp
// Description: Convolutional layer that stacks the inputs over its rectangle
// and pulls in random data to fill out-of-input inputs.
// Output is therefore same size as its input, but deeper.
// Author: Ray Smith
// Created: Tue Mar 18 16:56:06 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "convolve.h"
#include "networkscratch.h"
#include "serialis.h"
namespace tesseract {
Convolve::Convolve(const STRING& name, int ni, int half_x, int half_y)
: Network(NT_CONVOLVE, name, ni, ni * (2*half_x + 1) * (2*half_y + 1)),
half_x_(half_x), half_y_(half_y) {
}
Convolve::~Convolve() {
}
// Writes to the given file. Returns false in case of error.
bool Convolve::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&half_x_, sizeof(half_x_), 1) != 1) return false;
if (fp->FWrite(&half_y_, sizeof(half_y_), 1) != 1) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Convolve::DeSerialize(bool swap, TFile* fp) {
if (fp->FRead(&half_x_, sizeof(half_x_), 1) != 1) return false;
if (fp->FRead(&half_y_, sizeof(half_y_), 1) != 1) return false;
if (swap) {
ReverseN(&half_x_, sizeof(half_x_));
ReverseN(&half_y_, sizeof(half_y_));
}
no_ = ni_ * (2*half_x_ + 1) * (2*half_y_ + 1);
return true;
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void Convolve::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
output->Resize(input, no_);
int y_scale = 2 * half_y_ + 1;
StrideMap::Index dest_index(output->stride_map());
do {
// Stack (2 * half_x_ + 1) groups of y_scale * ni_ inputs together.
int t = dest_index.t();
int out_ix = 0;
for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) {
StrideMap::Index x_index(dest_index);
if (!x_index.AddOffset(x, FD_WIDTH)) {
// This x is outside the image.
output->Randomize(t, out_ix, y_scale * ni_, randomizer_);
} else {
int out_iy = out_ix;
for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) {
StrideMap::Index y_index(x_index);
if (!y_index.AddOffset(y, FD_HEIGHT)) {
// This y is outside the image.
output->Randomize(t, out_iy, ni_, randomizer_);
} else {
output->CopyTimeStepGeneral(t, out_iy, ni_, input, y_index.t(), 0);
}
}
}
}
} while (dest_index.Increment());
if (debug) DisplayForward(*output);
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
back_deltas->Resize(fwd_deltas, ni_);
NetworkScratch::IO delta_sum;
delta_sum.ResizeFloat(fwd_deltas, ni_, scratch);
delta_sum->Zero();
int y_scale = 2 * half_y_ + 1;
StrideMap::Index src_index(fwd_deltas.stride_map());
do {
// Accumulate deltas from the (2 * half_x_ + 1) groups of y_scale * ni_
// stacked inputs.
int t = src_index.t();
int out_ix = 0;
for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) {
StrideMap::Index x_index(src_index);
if (x_index.AddOffset(x, FD_WIDTH)) {
int out_iy = out_ix;
for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) {
StrideMap::Index y_index(x_index);
if (y_index.AddOffset(y, FD_HEIGHT)) {
fwd_deltas.AddTimeStepPart(t, out_iy, ni_,
delta_sum->f(y_index.t()));
}
}
}
}
} while (src_index.Increment());
back_deltas->CopyWithNormalization(*delta_sum, fwd_deltas);
return true;
}
} // namespace tesseract.

74
lstm/convolve.h Normal file
View File

@ -0,0 +1,74 @@
///////////////////////////////////////////////////////////////////////
// File: convolve.h
// Description: Convolutional layer that stacks the inputs over its rectangle
// and pulls in random data to fill out-of-input inputs.
// Output is therefore same size as its input, but deeper.
// Author: Ray Smith
// Created: Tue Mar 18 16:45:34 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_CONVOLVE_H_
#define TESSERACT_LSTM_CONVOLVE_H_
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
namespace tesseract {
// Makes each time-step deeper by stacking inputs over its rectangle. Does not
// affect the size of its input. Achieves this by bringing in random values in
// out-of-input areas.
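// For example (sketch): with half_x == half_y == 1 each output time-step
// stacks a 3 x 3 window of inputs, so input depth ni becomes output depth
// 9 * ni, while the width and height of the map are unchanged.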
class Convolve : public Network {
public:
// The area of convolution is 2*half_x + 1 by 2*half_y + 1, forcing it to
// always be odd, so the center is the current pixel.
Convolve(const STRING& name, int ni, int half_x, int half_y);
virtual ~Convolve();
virtual STRING spec() const {
STRING spec;
spec.add_str_int("C", half_x_ * 2 + 1);
spec.add_str_int(",", half_y_ * 2 + 1);
return spec;
}
// Writes to the given file. Returns false in case of error.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
protected:
// Serialized data.
inT32 half_x_;
inT32 half_y_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_CONVOLVE_H_

412
lstm/ctc.cpp Normal file
View File

@ -0,0 +1,412 @@
///////////////////////////////////////////////////////////////////////
// File: ctc.cpp
// Description: Slightly improved standard CTC to compute the targets.
// Author: Ray Smith
// Created: Wed Jul 13 15:50:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "ctc.h"
#include <memory>
#include "genericvector.h"
#include "host.h"
#include "matrix.h"
#include "networkio.h"
#include "network.h"
#include "scrollview.h"
namespace tesseract {
// Magic constants that keep CTC stable.
// Minimum probability limit for softmax input to ctc_loss.
const float CTC::kMinProb_ = 1e-12;
// Maximum absolute argument to exp().
const double CTC::kMaxExpArg_ = 80.0;
// Minimum probability for total prob in time normalization.
const double CTC::kMinTotalTimeProb_ = 1e-8;
// Minimum probability for total prob in final normalization.
const double CTC::kMinTotalFinalProb_ = 1e-6;
// Builds a target using CTC. Slightly improved as follows:
// Includes normalizations and clipping for stability.
// labels should be pre-padded with nulls wherever desired, but they don't
// have to be between all labels.
// labels can be longer than the time sequence, but the total number of
// essential labels (non-null plus nulls between equal labels) must not exceed
// the number of timesteps in outputs.
// outputs is the output of the network, and should have already been
// normalized with NormalizeProbs.
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
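// For example (illustrative): for truth "aa" the padded labels
// <null> a <null> a <null> contain 3 essential labels (a, the separating
// null, a), so at least 3 timesteps are required.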
/* static */
bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs,
NetworkIO* targets) {
std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs));
if (!ctc->ComputeLabelLimits()) {
return false; // Not enough time.
}
// Generate simple targets purely from the truth labels by spreading them
// evenly over time.
GENERIC_2D_ARRAY<float> simple_targets;
ctc->ComputeSimpleTargets(&simple_targets);
// Add the simple targets as a starter bias to the network outputs.
float bias_fraction = ctc->CalculateBiasFraction();
simple_targets *= bias_fraction;
ctc->outputs_ += simple_targets;
NormalizeProbs(&ctc->outputs_);
// Run regular CTC on the biased outputs.
// Run forward and backward
GENERIC_2D_ARRAY<double> log_alphas, log_betas;
ctc->Forward(&log_alphas);
ctc->Backward(&log_betas);
// Normalize and come out of log space with a clipped softmax over time.
log_alphas += log_betas;
ctc->NormalizeSequence(&log_alphas);
ctc->LabelsToClasses(log_alphas, targets);
NormalizeProbs(targets);
return true;
}
CTC::CTC(const GenericVector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs)
: labels_(labels), outputs_(outputs), null_char_(null_char) {
num_timesteps_ = outputs.dim1();
num_classes_ = outputs.dim2();
num_labels_ = labels_.size();
}
// Computes vectors of min and max label index for each timestep, based on
// whether skippability of nulls makes it possible to complete a valid path.
bool CTC::ComputeLabelLimits() {
min_labels_.init_to_size(num_timesteps_, 0);
max_labels_.init_to_size(num_timesteps_, 0);
int min_u = num_labels_ - 1;
if (labels_[min_u] == null_char_) --min_u;
for (int t = num_timesteps_ - 1; t >= 0; --t) {
min_labels_[t] = min_u;
if (min_u > 0) {
--min_u;
if (labels_[min_u] == null_char_ && min_u > 0 &&
labels_[min_u + 1] != labels_[min_u - 1]) {
--min_u;
}
}
}
int max_u = labels_[0] == null_char_;
for (int t = 0; t < num_timesteps_; ++t) {
max_labels_[t] = max_u;
if (max_labels_[t] < min_labels_[t]) return false; // Not enough room.
if (max_u + 1 < num_labels_) {
++max_u;
if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
labels_[max_u + 1] != labels_[max_u - 1]) {
++max_u;
}
}
}
return true;
}
// Computes targets based purely on the labels by spreading the labels evenly
// over the available timesteps.
void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
// Initialize all targets to zero.
targets->Resize(num_timesteps_, num_classes_, 0.0f);
GenericVector<float> half_widths;
GenericVector<int> means;
ComputeWidthsAndMeans(&half_widths, &means);
for (int l = 0; l < num_labels_; ++l) {
int label = labels_[l];
float left_half_width = half_widths[l];
float right_half_width = left_half_width;
int mean = means[l];
if (label == null_char_) {
if (!NeededNull(l)) {
if ((l > 0 && mean == means[l - 1]) ||
(l + 1 < num_labels_ && mean == means[l + 1])) {
continue; // Drop overlapping null.
}
}
// Make sure that no space is left unoccupied and that non-nulls always
// peak at 1 by stretching nulls to meet their neighbors.
if (l > 0) left_half_width = mean - means[l - 1];
if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
}
if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f);
for (int offset = 1; offset < left_half_width && mean >= offset; ++offset) {
float prob = 1.0f - offset / left_half_width;
if (mean - offset < num_timesteps_ &&
prob > targets->get(mean - offset, label)) {
targets->put(mean - offset, label, prob);
}
}
for (int offset = 1;
offset < right_half_width && mean + offset < num_timesteps_;
++offset) {
float prob = 1.0f - offset / right_half_width;
if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) {
targets->put(mean + offset, label, prob);
}
}
}
}
// Computes mean positions and half widths of the simple targets by spreading
// the labels evenly over the available timesteps.
void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
GenericVector<int>* means) const {
// Count the number of labels of each type, in regexp terms, counts plus
// (non-null or necessary null, which must occur at least once) and star
// (optional null).
int num_plus = 0, num_star = 0;
for (int i = 0; i < num_labels_; ++i) {
if (labels_[i] != null_char_ || NeededNull(i))
++num_plus;
else
++num_star;
}
// Compute the size for each type. If there is enough space for everything
// to have size>=1, then all are equal, otherwise plus_size=1 and star gets
// whatever is left-over.
float plus_size = 1.0f, star_size = 0.0f;
float total_floating = num_plus + num_star;
if (total_floating <= num_timesteps_) {
plus_size = star_size = num_timesteps_ / total_floating;
} else if (num_star > 0) {
star_size = static_cast<float>(num_timesteps_ - num_plus) / num_star;
}
// Set the width and compute the mean of each.
float mean_pos = 0.0f;
for (int i = 0; i < num_labels_; ++i) {
float half_width;
if (labels_[i] != null_char_ || NeededNull(i)) {
half_width = plus_size / 2.0f;
} else {
half_width = star_size / 2.0f;
}
mean_pos += half_width;
means->push_back(static_cast<int>(mean_pos));
mean_pos += half_width;
half_widths->push_back(half_width);
}
}
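// Worked example for ComputeWidthsAndMeans (illustrative): with 2 essential
// labels, 3 optional nulls and 10 timesteps, total_floating = 5 <= 10, so
// every label gets size 10 / 5 = 2 (half_width 1) and the means land at
// t = 1, 3, 5, 7, 9.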
// Helper returns the index of the highest probability label at timestep t.
static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) {
int result = 0;
int num_classes = outputs.dim2();
const float* outputs_t = outputs[t];
for (int c = 1; c < num_classes; ++c) {
if (outputs_t[c] > outputs_t[result]) result = c;
}
return result;
}
// Calculates and returns a suitable fraction of the simple targets to add
// to the network outputs.
float CTC::CalculateBiasFraction() {
// Compute output labels via basic decoding.
GenericVector<int> output_labels;
for (int t = 0; t < num_timesteps_; ++t) {
int label = BestLabel(outputs_, t);
while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
if (label != null_char_) output_labels.push_back(label);
}
// Simple bag of labels error calculation.
GenericVector<int> truth_counts(num_classes_, 0);
GenericVector<int> output_counts(num_classes_, 0);
for (int l = 0; l < num_labels_; ++l) {
++truth_counts[labels_[l]];
}
for (int l = 0; l < output_labels.size(); ++l) {
++output_counts[output_labels[l]];
}
// Count the number of true and false positive non-nulls and truth labels.
int true_pos = 0, false_pos = 0, total_labels = 0;
for (int c = 0; c < num_classes_; ++c) {
if (c == null_char_) continue;
int truth_count = truth_counts[c];
int ocr_count = output_counts[c];
if (truth_count > 0) {
total_labels += truth_count;
if (ocr_count > truth_count) {
true_pos += truth_count;
false_pos += ocr_count - truth_count;
} else {
true_pos += ocr_count;
}
}
// We don't need to count classes that don't exist in the truth as
// false positives, because they don't affect CTC at all.
}
if (total_labels == 0) return 0.0f;
return exp(MAX(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
}
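// Illustration of CalculateBiasFraction (with kMinProb_ == 1e-12): with 3
// truth labels, 2 true positives and no false positives the fraction is
// exp(2 * log(kMinProb_) / 3) ~= 1e-8, so a nearly-correct network gets
// almost no bias, while true_pos - false_pos <= 0 clips to 1 and yields the
// maximum bias exp(log(kMinProb_) / 3) ~= 1e-4.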
// Given ln(x) and ln(y), returns ln(x + y), using:
// ln(x + y) = ln(x) + ln(1 + exp(ln(y) - ln(x))), ensuring that ln(x) is the
// bigger number to maximize precision.
static double LogSumExp(double ln_x, double ln_y) {
if (ln_x >= ln_y) {
return ln_x + log1p(exp(ln_y - ln_x));
} else {
return ln_y + log1p(exp(ln_x - ln_y));
}
}
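// For example: LogSumExp(log(0.6), log(0.3)) == log(0.9), computed without
// leaving log space; keeping the larger argument outside log1p(exp(...))
// preserves precision when the two terms differ by many orders of magnitude.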
// Runs the forward CTC pass, filling in log_probs.
void CTC::Forward(GENERIC_2D_ARRAY<double>* log_probs) const {
log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
log_probs->put(0, 0, log(outputs_(0, labels_[0])));
if (labels_[0] == null_char_)
log_probs->put(0, 1, log(outputs_(0, labels_[1])));
for (int t = 1; t < num_timesteps_; ++t) {
const float* outputs_t = outputs_[t];
for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
// Continuing the same label.
double log_sum = log_probs->get(t - 1, u);
// Change from previous label.
if (u > 0) {
log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1));
}
// Skip the null if allowed.
if (u >= 2 && labels_[u - 1] == null_char_ &&
labels_[u] != labels_[u - 2]) {
log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2));
}
// Add in the log prob of the current label.
double label_prob = outputs_t[labels_[u]];
log_sum += log(label_prob);
log_probs->put(t, u, log_sum);
}
}
}
// Runs the backward CTC pass, filling in log_probs.
void CTC::Backward(GENERIC_2D_ARRAY<double>* log_probs) const {
log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
if (labels_[num_labels_ - 1] == null_char_)
log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
for (int t = num_timesteps_ - 2; t >= 0; --t) {
const float* outputs_tp1 = outputs_[t + 1];
for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
// Continuing the same label.
double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]);
// Change from previous label.
if (u + 1 < num_labels_) {
double prev_prob = outputs_tp1[labels_[u + 1]];
log_sum =
LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob));
}
// Skip the null if allowed.
if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
labels_[u] != labels_[u + 2]) {
double skip_prob = outputs_tp1[labels_[u + 2]];
log_sum =
LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob));
}
log_probs->put(t, u, log_sum);
}
}
}
// Normalizes and brings probs out of log space with a softmax over time.
void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const {
double max_logprob = probs->Max();
for (int u = 0; u < num_labels_; ++u) {
double total = 0.0;
for (int t = 0; t < num_timesteps_; ++t) {
// Separate impossible path from unlikely probs.
double prob = probs->get(t, u);
if (prob > -MAX_FLOAT32)
prob = ClippedExp(prob - max_logprob);
else
prob = 0.0;
total += prob;
probs->put(t, u, prob);
}
// Note that although this is a probability distribution over time and
// therefore should sum to 1, it is important to allow some labels to be
// all zero (or at least tiny), as it is necessary to skip some blanks.
if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
for (int t = 0; t < num_timesteps_; ++t)
probs->put(t, u, probs->get(t, u) / total);
}
}
// For each timestep computes the max prob for each class over all
// instances of the class in the labels_, and sets the targets to
// the max observed prob.
void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
NetworkIO* targets) const {
// For each timestep compute the max prob for each class over all
// instances of the class in the labels_.
GenericVector<double> class_probs;
for (int t = 0; t < num_timesteps_; ++t) {
float* targets_t = targets->f(t);
class_probs.init_to_size(num_classes_, 0.0);
for (int u = 0; u < num_labels_; ++u) {
double prob = probs(t, u);
// Note that although Graves specifies sum over all labels of the same
// class, we need to allow skipped blanks to go to zero, so they don't
// interfere with the non-blanks, so max is better than sum.
if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
// class_probs[labels_[u]] += prob;
}
int best_class = 0;
for (int c = 0; c < num_classes_; ++c) {
targets_t[c] = class_probs[c];
if (class_probs[c] > class_probs[best_class]) best_class = c;
}
}
}
// Normalizes the probabilities such that no target has a prob below min_prob,
// and, provided that the initial total is at least min_total_prob, then all
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
// probability is thus 1 - (num_classes-1)*min_prob.
/* static */
void CTC::NormalizeProbs(GENERIC_2D_ARRAY<float>* probs) {
int num_timesteps = probs->dim1();
int num_classes = probs->dim2();
for (int t = 0; t < num_timesteps; ++t) {
float* probs_t = (*probs)[t];
// Compute the total and clip that to prevent amplification of noise.
double total = 0.0;
for (int c = 0; c < num_classes; ++c) total += probs_t[c];
if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
// Compute the increased total as a result of clipping.
double increment = 0.0;
for (int c = 0; c < num_classes; ++c) {
double prob = probs_t[c] / total;
if (prob < kMinProb_) increment += kMinProb_ - prob;
}
// Now normalize with clipping. Any additional clipping is negligible.
total += increment;
for (int c = 0; c < num_classes; ++c) {
float prob = probs_t[c] / total;
probs_t[c] = MAX(prob, kMinProb_);
}
}
}
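// Worked example for NormalizeProbs (illustrative): for a row {0.7, 0.3, 0.0}
// the total is 1.0, the zero entry is clipped up to kMinProb_, and the
// increment of ~kMinProb_ added to the total keeps the row sum at 1 to within
// the clip.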
// Returns true if the label at index is a needed null.
bool CTC::NeededNull(int index) const {
return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ &&
labels_[index + 1] == labels_[index - 1];
}
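// For example, NeededNull is true for the middle null in <f> <null> <f> (the
// ff ligature), since without it CTC could not keep the two identical codes
// apart.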
} // namespace tesseract

130
lstm/ctc.h Normal file
View File

@ -0,0 +1,130 @@
///////////////////////////////////////////////////////////////////////
// File: ctc.h
// Description: Slightly improved standard CTC to compute the targets.
// Author: Ray Smith
// Created: Wed Jul 13 15:17:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_CTC_H_
#define TESSERACT_LSTM_CTC_H_
#include "genericvector.h"
#include "network.h"
#include "networkio.h"
#include "scrollview.h"
namespace tesseract {
// Class to encapsulate CTC and simple target generation.
class CTC {
public:
// Normalizes the probabilities such that no target has a prob below min_prob,
// and, provided that the initial total is at least min_total_prob, then all
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
// probability is thus 1 - (num_classes-1)*min_prob.
static void NormalizeProbs(NetworkIO* probs) {
NormalizeProbs(probs->mutable_float_array());
}
// Builds a target using CTC. Slightly improved as follows:
// Includes normalizations and clipping for stability.
// labels should be pre-padded with nulls wherever desired, but they don't
// have to be between all labels. Allows for multi-label codes with no
// nulls between.
// labels can be longer than the time sequence, but the total number of
// essential labels (non-null plus nulls between equal labels) must not exceed
// the number of timesteps in outputs.
// outputs is the output of the network, and should have already been
// normalized with NormalizeProbs.
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
static bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
int null_char,
const GENERIC_2D_ARRAY<float>& outputs,
NetworkIO* targets);
private:
// Constructor is private as the instance only holds information specific to
// the current labels, outputs etc, and is built by the static function.
CTC(const GenericVector<int>& labels, int null_char,
const GENERIC_2D_ARRAY<float>& outputs);
// Computes vectors of min and max label index for each timestep, based on
// whether skippability of nulls makes it possible to complete a valid path.
bool ComputeLabelLimits();
// Computes targets based purely on the labels by spreading the labels evenly
// over the available timesteps.
void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const;
// Computes mean positions and half widths of the simple targets by spreading
// the labels evenly over the available timesteps.
void ComputeWidthsAndMeans(GenericVector<float>* half_widths,
GenericVector<int>* means) const;
// Calculates and returns a suitable fraction of the simple targets to add
// to the network outputs.
float CalculateBiasFraction();
// Runs the forward CTC pass, filling in log_probs.
void Forward(GENERIC_2D_ARRAY<double>* log_probs) const;
// Runs the backward CTC pass, filling in log_probs.
void Backward(GENERIC_2D_ARRAY<double>* log_probs) const;
// Normalizes and brings probs out of log space with a softmax over time.
void NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const;
// For each timestep computes the max prob for each class over all
// instances of the class in the labels_, and sets the targets to
// the max observed prob.
void LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
NetworkIO* targets) const;
// Normalizes the probabilities so that no target has a prob below min_prob.
// Provided the initial total is at least min_total_prob, the probs will then
// sum to 1; otherwise they sum to total/min_total_prob. The maximum output
// probability is thus 1 - (num_classes-1)*min_prob.
static void NormalizeProbs(GENERIC_2D_ARRAY<float>* probs);
// Returns true if the label at index is a needed null.
bool NeededNull(int index) const;
// Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
// underflow.
static double ClippedExp(double x) {
if (x < -kMaxExpArg_) return exp(-kMaxExpArg_);
if (x > kMaxExpArg_) return exp(kMaxExpArg_);
return exp(x);
}
// Minimum probability limit for softmax input to ctc_loss.
static const float kMinProb_;
// Maximum absolute argument to exp().
static const double kMaxExpArg_;
// Minimum probability for total prob in time normalization.
static const double kMinTotalTimeProb_;
// Minimum probability for total prob in final normalization.
static const double kMinTotalFinalProb_;
// The truth label indices that are to be matched to outputs_.
const GenericVector<int>& labels_;
// The network outputs.
GENERIC_2D_ARRAY<float> outputs_;
// The null or "blank" label.
int null_char_;
// Number of timesteps in outputs_.
int num_timesteps_;
// Number of classes in outputs_.
int num_classes_;
// Number of labels in labels_.
int num_labels_;
// Min and max valid label indices for each timestep.
GenericVector<int> min_labels_;
GenericVector<int> max_labels_;
};
} // namespace tesseract
#endif // TESSERACT_LSTM_CTC_H_
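A minimal usage sketch of this API (assuming the caller already has the
network's softmax outputs for one line and the encoded truth labels; all
variable names here are illustrative):
// Sketch: generate CTC training targets for one text line.
NetworkIO outputs;                // filled by the network's Forward pass.
NetworkIO targets;
GenericVector<int> truth_labels;  // pre-padded with nulls as required.
int null_char = 0;                // index of the null class (assumed).
CTC::NormalizeProbs(&outputs);    // clip and renormalize before CTC.
if (!CTC::ComputeCTCTargets(truth_labels, null_char, outputs.float_array(),
                            &targets)) {
  tprintf("Not enough timesteps for the labels; line skipped.\n");
}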

lstm/fullyconnected.cpp Normal file
@@ -0,0 +1,285 @@
///////////////////////////////////////////////////////////////////////
// File: fullyconnected.cpp
// Description: Simple feed-forward layer with various non-linearities.
// Author: Ray Smith
// Created: Wed Feb 26 14:49:15 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "fullyconnected.h"
#ifdef _OPENMP
#include <omp.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include "functions.h"
#include "networkscratch.h"
// Number of threads to use for parallel calculation of Forward and Backward.
const int kNumThreads = 4;
namespace tesseract {
FullyConnected::FullyConnected(const STRING& name, int ni, int no,
NetworkType type)
: Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
}
FullyConnected::~FullyConnected() {
}
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
LossType loss_type = LT_NONE;
if (type_ == NT_SOFTMAX)
loss_type = LT_CTC;
else if (type_ == NT_SOFTMAX_NO_CTC)
loss_type = LT_SOFTMAX;
else if (type_ == NT_LOGISTIC)
loss_type = LT_LOGISTIC;
StaticShape result(input_shape);
result.set_depth(no_);
result.set_loss_type(loss_type);
return result;
}
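// Example (sketch): for an NT_SOFTMAX layer with no_ = 111, an input shape
// of [batch=1, height=1, width=W, depth=ni] maps to [1, 1, W, 111] with
// loss_type LT_CTC; only the depth and loss type change.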
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int FullyConnected::InitWeights(float range, TRand* randomizer) {
Network::SetRandomizer(randomizer);
num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADA_GRAD),
range, randomizer);
return num_weights_;
}
// Converts a float network to an int network.
void FullyConnected::ConvertToInt() {
weights_.ConvertToInt();
}
// Provides debug output on the weights.
void FullyConnected::DebugWeights() {
weights_.Debug2D(name_.string());
}
// Writes to the given file. Returns false in case of error.
bool FullyConnected::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (!weights_.Serialize(training_, fp)) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool FullyConnected::DeSerialize(bool swap, TFile* fp) {
if (!weights_.DeSerialize(training_, swap, fp)) return false;
return true;
}
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
void FullyConnected::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
int width = input.Width();
if (type_ == NT_SOFTMAX)
output->ResizeFloat(input, no_);
else
output->Resize(input, no_);
SetupForward(input, input_transpose);
GenericVector<NetworkScratch::FloatVec> temp_lines;
temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
GenericVector<NetworkScratch::FloatVec> curr_input;
curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
for (int i = 0; i < temp_lines.size(); ++i) {
temp_lines[i].Init(no_, scratch);
curr_input[i].Init(ni_, scratch);
}
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
for (int t = 0; t < width; ++t) {
// Thread-local pointer to temporary storage.
int thread_id = omp_get_thread_num();
#else
for (int t = 0; t < width; ++t) {
// Thread-local pointer to temporary storage.
int thread_id = 0;
#endif
double* temp_line = temp_lines[thread_id];
const double* d_input = NULL;
const inT8* i_input = NULL;
if (input.int_mode()) {
i_input = input.i(t);
} else {
input.ReadTimeStep(t, curr_input[thread_id]);
d_input = curr_input[thread_id];
}
ForwardTimeStep(d_input, i_input, t, temp_line);
output->WriteTimeStep(t, temp_line);
if (training() && type_ != NT_SOFTMAX) {
acts_.CopyTimeStepFrom(t, *output, t);
}
}
// Zero all the elements that are in the padding around images, which allows
// multiple different-sized images to exist in a single array.
// acts_ is only used if this is not a softmax op.
if (training() && type_ != NT_SOFTMAX) {
acts_.ZeroInvalidElements();
}
output->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
tprintf("F Output:%s\n", name_.string());
output->Print(10);
#endif
if (debug) DisplayForward(*output);
}
// Components of Forward so FullyConnected can be reused inside LSTM.
void FullyConnected::SetupForward(const NetworkIO& input,
const TransposedArray* input_transpose) {
// Softmax output is always float, so save the input type.
int_mode_ = input.int_mode();
if (training()) {
acts_.Resize(input, no_);
// source_t_ is a transposed copy of the input. It isn't needed if a
// transpose is already provided in input_transpose.
external_source_ = input_transpose;
if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width());
}
}
void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (training() && external_source_ == NULL && d_input != NULL)
source_t_.WriteStrided(t, d_input);
if (d_input != NULL)
weights_.MatrixDotVector(d_input, output_line);
else
weights_.MatrixDotVector(i_input, output_line);
if (type_ == NT_TANH) {
FuncInplace<GFunc>(no_, output_line);
} else if (type_ == NT_LOGISTIC) {
FuncInplace<FFunc>(no_, output_line);
} else if (type_ == NT_POSCLIP) {
FuncInplace<ClipFFunc>(no_, output_line);
} else if (type_ == NT_SYMCLIP) {
FuncInplace<ClipGFunc>(no_, output_line);
} else if (type_ == NT_RELU) {
FuncInplace<Relu>(no_, output_line);
} else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
SoftmaxInPlace(no_, output_line);
} else if (type_ != NT_LINEAR) {
ASSERT_HOST("Invalid fully-connected type!" == NULL);
}
}
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
if (debug) DisplayBackward(fwd_deltas);
back_deltas->Resize(fwd_deltas, ni_);
GenericVector<NetworkScratch::FloatVec> errors;
errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
for (int i = 0; i < errors.size(); ++i) errors[i].Init(no_, scratch);
GenericVector<NetworkScratch::FloatVec> temp_backprops;
if (needs_to_backprop_) {
temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
}
int width = fwd_deltas.Width();
NetworkScratch::GradientStore errors_t;
errors_t.Init(no_, width, scratch);
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
for (int t = 0; t < width; ++t) {
int thread_id = omp_get_thread_num();
#else
for (int t = 0; t < width; ++t) {
int thread_id = 0;
#endif
double* backprop = NULL;
if (needs_to_backprop_) backprop = temp_backprops[thread_id];
double* curr_errors = errors[thread_id];
BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
if (backprop != NULL) {
back_deltas->WriteTimeStep(t, backprop);
}
}
FinishBackward(*errors_t.get());
if (needs_to_backprop_) {
back_deltas->ZeroInvalidElements();
back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
#if DEBUG_DETAIL > 0
tprintf("F Backprop:%s\n", name_.string());
back_deltas->Print(10);
#endif
return true;
}
return false; // No point going further back.
}
void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t,
double* curr_errors,
TransposedArray* errors_t,
double* backprop) {
if (type_ == NT_TANH)
acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
else if (type_ == NT_LOGISTIC)
acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
else if (type_ == NT_POSCLIP)
acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
else if (type_ == NT_SYMCLIP)
acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
else if (type_ == NT_RELU)
acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
type_ == NT_LINEAR)
fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
else
ASSERT_HOST("Invalid fully-connected type!" == NULL);
// Generate backprop only if needed by the lower layer.
if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop);
errors_t->WriteStrided(t, curr_errors);
}
void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
if (external_source_ == NULL)
weights_.SumOuterTransposed(errors_t, source_t_, true);
else
weights_.SumOuterTransposed(errors_t, *external_source_, true);
}
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void FullyConnected::Update(float learning_rate, float momentum,
int num_samples) {
weights_.Update(learning_rate, momentum, num_samples);
}
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void FullyConnected::CountAlternators(const Network& other, double* same,
double* changed) const {
ASSERT_HOST(other.type() == type_);
const FullyConnected* fc = reinterpret_cast<const FullyConnected*>(&other);
weights_.CountAlternators(fc->weights_, same, changed);
}
} // namespace tesseract.

lstm/fullyconnected.h Normal file
@@ -0,0 +1,130 @@
///////////////////////////////////////////////////////////////////////
// File: fullyconnected.h
// Description: Simple feed-forward layer with various non-linearities.
// Author: Ray Smith
// Created: Wed Feb 26 14:46:06 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_FULLYCONNECTED_H_
#define TESSERACT_LSTM_FULLYCONNECTED_H_
#include "network.h"
#include "networkscratch.h"
namespace tesseract {
// C++ Implementation of the Softmax (output) class from lstm.py.
class FullyConnected : public Network {
public:
FullyConnected(const STRING& name, int ni, int no, NetworkType type);
virtual ~FullyConnected();
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
virtual STRING spec() const {
STRING spec;
if (type_ == NT_TANH)
spec.add_str_int("Ft", no_);
else if (type_ == NT_LOGISTIC)
spec.add_str_int("Fs", no_);
else if (type_ == NT_RELU)
spec.add_str_int("Fr", no_);
else if (type_ == NT_LINEAR)
spec.add_str_int("Fl", no_);
else if (type_ == NT_POSCLIP)
spec.add_str_int("Fp", no_);
else if (type_ == NT_SYMCLIP)
spec.add_str_int("Fs", no_);
else if (type_ == NT_SOFTMAX)
spec.add_str_int("Fc", no_);
else
spec.add_str_int("Fm", no_);
return spec;
}
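// e.g. a 256-output tanh layer yields "Ft256" and a 111-class softmax
// yields "Fc111" (values here are illustrative only).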
// Changes the type to the given type. Used to convert a softmax to a
// non-output type so that other networks can be added on top.
void ChangeType(NetworkType type) {
type_ = type;
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
virtual int InitWeights(float range, TRand* randomizer);
// Converts a float network to an int network.
virtual void ConvertToInt();
// Provides debug output on the weights.
virtual void DebugWeights();
// Writes to the given file. Returns false in case of error.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Components of Forward so FullyConnected can be reused inside LSTM.
void SetupForward(const NetworkIO& input,
const TransposedArray* input_transpose);
void ForwardTimeStep(const double* d_input, const inT8* i_input, int t,
double* output_line);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
// Components of Backward so FullyConnected can be reused inside LSTM.
void BackwardTimeStep(const NetworkIO& fwd_deltas, int t, double* curr_errors,
TransposedArray* errors_t, double* backprop);
void FinishBackward(const TransposedArray& errors_t);
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
virtual void Update(float learning_rate, float momentum, int num_samples);
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
virtual void CountAlternators(const Network& other, double* same,
double* changed) const;
protected:
// Weight arrays of size [no, ni + 1].
WeightMatrix weights_;
// Transposed copy of input used during training of size [ni, width].
TransposedArray source_t_;
// Pointer to transposed input stored elsewhere. If not null, this is used
// in preference to calculating the transpose and storing it in source_t_.
const TransposedArray* external_source_;
// Activations from forward pass of size [width, no].
NetworkIO acts_;
// Memory of the integer mode input to forward as softmax always outputs
// float, so the information is otherwise lost.
bool int_mode_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_FULLYCONNECTED_H_
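A minimal construction sketch (assuming TRand::set_seed takes an integer
seed and that input has been sized and filled elsewhere):
// Sketch: a 10-input, 4-class softmax output layer.
FullyConnected fc("output", 10, 4, NT_SOFTMAX);
TRand randomizer;
randomizer.set_seed(42);
fc.InitWeights(0.1f, &randomizer);  // returns the number of weights created.
NetworkIO input, output;            // input filled elsewhere.
NetworkScratch scratch;
fc.Forward(false, input, NULL, &scratch, &output);
tprintf("spec=%s\n", fc.spec().string());  // prints "Fc4".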

lstm/functions.cpp Normal file
@@ -0,0 +1,26 @@
///////////////////////////////////////////////////////////////////////
// File: functions.cpp
// Description: Static initialize-on-first-use non-linearity functions.
// Author: Ray Smith
// Created: Tue Jul 17 14:02:59 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "functions.h"
namespace tesseract {
double TanhTable[kTableSize];
double LogisticTable[kTableSize];
} // namespace tesseract.

lstm/functions.h Normal file
@@ -0,0 +1,249 @@
///////////////////////////////////////////////////////////////////////
// File: functions.h
// Description: Collection of function-objects used by the network layers.
// Author: Ray Smith
// Created: Fri Jun 20 10:45:37 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_FUNCTIONS_H_
#define TESSERACT_LSTM_FUNCTIONS_H_
#include <cmath>
#include "helpers.h"
#include "tprintf.h"
// Setting this to 1 or more causes massive dumps of debug data: weights,
// updates, internal calculations etc., and reduces the number of test
// iterations to a small number, so outputs can be diffed.
#define DEBUG_DETAIL 0
#if DEBUG_DETAIL > 0
#undef _OPENMP // Disable open mp to get the outputs in sync.
#endif
namespace tesseract {
// Size of static tables.
const int kTableSize = 4096;
// Scale factor for float arg to int index.
const double kScaleFactor = 256.0;
extern double TanhTable[];
extern double LogisticTable[];
// Non-linearity (sigmoid) functions with cache tables and clipping.
inline double Tanh(double x) {
if (x < 0.0) return -Tanh(-x);
if (x >= (kTableSize - 1) / kScaleFactor) return 1.0;
x *= kScaleFactor;
int index = static_cast<int>(floor(x));
if (TanhTable[index] == 0.0 && index > 0) {
// Generate the entry.
TanhTable[index] = tanh(index / kScaleFactor);
}
if (index == kTableSize - 1) return TanhTable[kTableSize - 1];
if (TanhTable[index + 1] == 0.0) {
// Generate the entry.
TanhTable[index + 1] = tanh((index + 1) / kScaleFactor);
}
double offset = x - index;
return TanhTable[index] * (1.0 - offset) + TanhTable[index + 1] * offset;
}
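// Worked example (sketch): Tanh(0.01) scales to x = 2.56, so index = 2 and
// offset = 0.56; the entries tanh(2/256) and tanh(3/256) are generated on
// first use and the result is their linear interpolation.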
inline double Logistic(double x) {
if (x < 0.0) return 1.0 - Logistic(-x);
if (x >= (kTableSize - 1) / kScaleFactor) return 1.0;
x *= kScaleFactor;
int index = static_cast<int>(floor(x));
if (LogisticTable[index] == 0.0) {
// Generate the entry.
LogisticTable[index] = 1.0 / (1.0 + exp(-index / kScaleFactor));
}
if (index == kTableSize - 1) return LogisticTable[kTableSize - 1];
if (LogisticTable[index + 1] == 0.0) {
// Generate the entry.
LogisticTable[index + 1] = 1.0 / (1.0 + exp(-(index + 1) / kScaleFactor));
}
double offset = x - index;
return LogisticTable[index] * (1.0 - offset) +
LogisticTable[index + 1] * offset;
}
// Non-linearity (sigmoid) functions and their derivatives.
struct FFunc {
inline double operator()(double x) const { return Logistic(x); }
};
struct FPrime {
inline double operator()(double y) const { return y * (1.0 - y); }
};
struct ClipFFunc {
inline double operator()(double x) const {
if (x <= 0.0) return 0.0;
if (x >= 1.0) return 1.0;
return x;
}
};
struct ClipFPrime {
inline double operator()(double y) const {
return 0.0 < y && y < 1.0 ? 1.0 : 0.0;
}
};
struct Relu {
inline double operator()(double x) const {
if (x <= 0.0) return 0.0;
return x;
}
};
struct ReluPrime {
inline double operator()(double y) const { return 0.0 < y ? 1.0 : 0.0; }
};
struct GFunc {
inline double operator()(double x) const { return Tanh(x); }
};
struct GPrime {
inline double operator()(double y) const { return 1.0 - y * y; }
};
struct ClipGFunc {
inline double operator()(double x) const {
if (x <= -1.0) return -1.0;
if (x >= 1.0) return 1.0;
return x;
}
};
struct ClipGPrime {
inline double operator()(double y) const {
return -1.0 < y && y < 1.0 ? 1.0 : 0.0;
}
};
struct HFunc {
inline double operator()(double x) const { return Tanh(x); }
};
struct HPrime {
inline double operator()(double y) const {
double u = Tanh(y);
return 1.0 - u * u;
}
};
struct UnityFunc {
inline double operator()(double x) const { return 1.0; }
};
struct IdentityFunc {
inline double operator()(double x) const { return x; }
};
// Applies Func in-place to inout, of size n.
template <class Func>
inline void FuncInplace(int n, double* inout) {
Func f;
for (int i = 0; i < n; ++i) {
inout[i] = f(inout[i]);
}
}
// Applies Func to u and multiplies the result by v component-wise,
// putting the product in out, all of size n.
template <class Func>
inline void FuncMultiply(const double* u, const double* v, int n, double* out) {
Func f;
for (int i = 0; i < n; ++i) {
out[i] = f(u[i]) * v[i];
}
}
// Applies the Softmax function in-place to inout, of size n.
template <typename T>
inline void SoftmaxInPlace(int n, T* inout) {
if (n <= 0) return;
// A limit on the negative range input to exp to guarantee non-zero output.
const T kMaxSoftmaxActivation = 86.0f;
T max_output = inout[0];
for (int i = 1; i < n; i++) {
T output = inout[i];
if (output > max_output) max_output = output;
}
T prob_total = 0.0;
for (int i = 0; i < n; i++) {
T prob = inout[i] - max_output;
prob = exp(ClipToRange(prob, -kMaxSoftmaxActivation, static_cast<T>(0)));
prob_total += prob;
inout[i] = prob;
}
if (prob_total > 0.0) {
for (int i = 0; i < n; i++) inout[i] /= prob_total;
}
}
// Copies n values of the given src vector to dest.
inline void CopyVector(int n, const double* src, double* dest) {
memcpy(dest, src, n * sizeof(dest[0]));
}
// Adds n values of the given src vector to dest.
inline void AccumulateVector(int n, const double* src, double* dest) {
for (int i = 0; i < n; ++i) dest[i] += src[i];
}
// Multiplies n values of inout in-place element-wise by the given src vector.
inline void MultiplyVectorsInPlace(int n, const double* src, double* inout) {
for (int i = 0; i < n; ++i) inout[i] *= src[i];
}
// Multiplies n values of u by v, element-wise, accumulating to out.
inline void MultiplyAccumulate(int n, const double* u, const double* v,
double* out) {
for (int i = 0; i < n; i++) {
out[i] += u[i] * v[i];
}
}
// Sums the given 5 n-vectors putting the result into sum.
inline void SumVectors(int n, const double* v1, const double* v2,
const double* v3, const double* v4, const double* v5,
double* sum) {
for (int i = 0; i < n; ++i) {
sum[i] = v1[i] + v2[i] + v3[i] + v4[i] + v5[i];
}
}
// Sets the given n-vector vec to 0.
template <typename T>
inline void ZeroVector(int n, T* vec) {
memset(vec, 0, n * sizeof(*vec));
}
// Clips the given vector vec, of size n to [lower, upper].
template <typename T>
inline void ClipVector(int n, T lower, T upper, T* vec) {
for (int i = 0; i < n; ++i) vec[i] = ClipToRange(vec[i], lower, upper);
}
// Converts the given n-vector to a binary encoding of the maximum value,
// encoded as vector of nf binary values.
inline void CodeInBinary(int n, int nf, double* vec) {
if (nf <= 0 || n < nf) return;
int index = 0;
double best_score = vec[0];
for (int i = 1; i < n; ++i) {
if (vec[i] > best_score) {
best_score = vec[i];
index = i;
}
}
int mask = 1;
for (int i = 0; i < nf; ++i, mask *= 2) {
vec[i] = (index & mask) ? 1.0 : 0.0;
}
}
} // namespace tesseract.
#endif // TESSERACT_LSTM_FUNCTIONS_H_
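A small worked example of SoftmaxInPlace and CodeInBinary together:
// Sketch: softmax an activation vector, then binary-encode its argmax.
double scores[4] = {1.0, 3.0, 2.0, 0.5};
SoftmaxInPlace(4, scores);   // scores now sum to 1; the max stays at index 1.
CodeInBinary(4, 2, scores);  // scores[0..1] become {1.0, 0.0}, the LSB-first
                             // binary code for index 1.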

lstm/input.cpp Normal file
@@ -0,0 +1,154 @@
///////////////////////////////////////////////////////////////////////
// File: input.cpp
// Description: Input layer class for neural network implementations.
// Author: Ray Smith
// Created: Thu Mar 13 09:10:34 PDT 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "input.h"
#include "allheaders.h"
#include "imagedata.h"
#include "pageres.h"
#include "scrollview.h"
namespace tesseract {
Input::Input(const STRING& name, int ni, int no)
: Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}
Input::Input(const STRING& name, const StaticShape& shape)
: Network(NT_INPUT, name, shape.height(), shape.depth()),
shape_(shape),
cached_x_scale_(1) {
if (shape.height() == 1) ni_ = shape.depth();
}
Input::~Input() {
}
// Writes to the given file. Returns false in case of error.
bool Input::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&shape_, sizeof(shape_), 1) != 1) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Input::DeSerialize(bool swap, TFile* fp) {
if (fp->FRead(&shape_, sizeof(shape_), 1) != 1) return false;
// TODO(rays) swaps!
return true;
}
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
int Input::XScaleFactor() const {
return 1;
}
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
void Input::CacheXScaleFactor(int factor) {
cached_x_scale_ = factor;
}
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
void Input::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
*output = input;
}
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
bool Input::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
tprintf("Input::Backward should not be called!!\n");
return false;
}
// Creates and returns a Pix of appropriate size for the network from the
// image_data. If non-null, *image_scale returns the image scale factor used.
// Returns nullptr on error.
/* static */
Pix* Input::PrepareLSTMInputs(const ImageData& image_data,
const Network* network, int min_width,
TRand* randomizer, float* image_scale) {
// Note that NumInputs() is defined as input image height.
int target_height = network->NumInputs();
int width, height;
Pix* pix =
image_data.PreScale(target_height, image_scale, &width, &height, nullptr);
if (pix == nullptr) {
tprintf("Bad pix from ImageData!\n");
return nullptr;
}
if (width <= min_width) {
tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width,
height, min_width);
pixDestroy(&pix);
return nullptr;
}
return pix;
}
// Converts the given pix to a NetworkIO of height and depth appropriate to the
// given StaticShape:
// If depth == 3, convert to 24 bit color, otherwise normalized grey.
// Scale to the shape's height if it is > 1, or to its depth if height == 1.
// If height == 0, no scaling is done.
// NOTE: It isn't safe for multiple threads to call this on the same pix.
/* static */
void Input::PreparePixInput(const StaticShape& shape, const Pix* pix,
TRand* randomizer, NetworkIO* input) {
bool color = shape.depth() == 3;
Pix* var_pix = const_cast<Pix*>(pix);
int depth = pixGetDepth(var_pix);
Pix* normed_pix = nullptr;
// On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
// colormap, so we just have to deal with depth conversion here.
if (color) {
// Force RGB.
if (depth == 32)
normed_pix = pixClone(var_pix);
else
normed_pix = pixConvertTo32(var_pix);
} else {
// Convert non-8-bit images to 8 bit.
if (depth == 8)
normed_pix = pixClone(var_pix);
else
normed_pix = pixConvertTo8(var_pix, false);
}
int width = pixGetWidth(normed_pix);
int height = pixGetHeight(normed_pix);
int target_height = shape.height();
if (target_height == 1) target_height = shape.depth();
if (target_height == 0) target_height = height;
float im_factor = static_cast<float>(target_height) / height;
if (im_factor != 1.0f) {
// Get the scaled image.
Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
pixDestroy(&normed_pix);
normed_pix = scaled_pix;
}
input->FromPix(shape, normed_pix, randomizer);
pixDestroy(&normed_pix);
}
} // namespace tesseract.
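Typical use during line recognition might look like this sketch
(image_data, network, and randomizer are assumed to exist; kMinLineWidth
is a hypothetical minimum-width threshold, and the network is assumed to
expose its InputShape()):
// Sketch: prescale a line image, then convert it to a NetworkIO.
float image_scale = 1.0f;
Pix* pix = Input::PrepareLSTMInputs(image_data, network, kMinLineWidth,
                                    &randomizer, &image_scale);
if (pix != nullptr) {
  NetworkIO inputs;
  Input::PreparePixInput(network->InputShape(), pix, &randomizer, &inputs);
  pixDestroy(&pix);
  // inputs is now ready to pass to the network's Forward.
}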

lstm/input.h Normal file
@@ -0,0 +1,107 @@
///////////////////////////////////////////////////////////////////////
// File: input.h
// Description: Input layer class for neural network implementations.
// Author: Ray Smith
// Created: Thu Mar 13 08:56:26 PDT 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_INPUT_H_
#define TESSERACT_LSTM_INPUT_H_
#include "network.h"
class ScrollView;
namespace tesseract {
class Input : public Network {
public:
Input(const STRING& name, int ni, int no);
Input(const STRING& name, const StaticShape& shape);
virtual ~Input();
virtual STRING spec() const {
STRING spec;
spec.add_str_int("", shape_.batch());
spec.add_str_int(",", shape_.height());
spec.add_str_int(",", shape_.width());
spec.add_str_int(",", shape_.depth());
return spec;
}
// Returns the required shape input to the network.
virtual StaticShape InputShape() const { return shape_; }
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const {
return shape_;
}
// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
virtual bool DeSerialize(bool swap, TFile* fp);
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
virtual int XScaleFactor() const;
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
virtual void CacheXScaleFactor(int factor);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
// Creates and returns a Pix of appropriate size for the network from the
// image_data. If non-null, *image_scale returns the image scale factor used.
// Returns nullptr on error.
/* static */
static Pix* PrepareLSTMInputs(const ImageData& image_data,
const Network* network, int min_width,
TRand* randomizer, float* image_scale);
// Converts the given pix to a NetworkIO of height and depth appropriate to
// the given StaticShape:
// If depth == 3, convert to 24 bit color, otherwise normalized grey.
// Scale to the shape's height if it is > 1, or to its depth if height == 1.
// If height == 0, no scaling is done.
// NOTE: It isn't safe for multiple threads to call this on the same pix.
static void PreparePixInput(const StaticShape& shape, const Pix* pix,
TRand* randomizer, NetworkIO* input);
private:
// Input shape determines how images are dealt with.
StaticShape shape_;
// Cached total network x scale factor for scaling bounding boxes.
int cached_x_scale_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_INPUT_H_

lstm/lstm.cpp Normal file
@@ -0,0 +1,710 @@
///////////////////////////////////////////////////////////////////////
// File: lstm.cpp
// Description: Long Short-Term Memory recurrent neural network.
// Author: Ray Smith
// Created: Wed May 01 17:43:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "lstm.h"
#ifdef _OPENMP
#include <omp.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"
// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads) \
PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
PRAGMA(omp sections nowait) { \
PRAGMA(omp section) {
#define SECTION_IF_OPENMP \
} \
PRAGMA(omp section) \
{
#define END_PARALLEL_IF_OPENMP \
} \
} /* end of sections */ \
} /* end of parallel section */
// Define the portable PRAGMA macro.
#ifdef _MSC_VER // Different _Pragma
#define PRAGMA(x) __pragma(x)
#else
#define PRAGMA(x) _Pragma(#x)
#endif // _MSC_VER
#else // _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
#define SECTION_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
#endif // _OPENMP
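// Usage sketch: the gate computations in Forward/Backward below are written
//   PARALLEL_IF_OPENMP(n) ...gate A... SECTION_IF_OPENMP ...gate B...
//   ... END_PARALLEL_IF_OPENMP
// which expands to an OpenMP parallel-sections construct when OpenMP is
// available, and to nothing (plain sequential code) otherwise.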
namespace tesseract {
// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const double kErrClip = 1.0;
LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
NetworkType type)
: Network(type, name, ni, no),
na_(ni + ns),
ns_(ns),
nf_(0),
is_2d_(two_dimensional),
softmax_(NULL) {
if (two_dimensional) na_ += ns_;
if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
nf_ = 0;
// networkbuilder ensures this is always true.
ASSERT_HOST(no == ns);
} else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_)));
softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
} else {
tprintf("%d is invalid type of LSTM!\n", type);
ASSERT_HOST(false);
}
na_ += nf_;
}
LSTM::~LSTM() { delete softmax_; }
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
StaticShape result = input_shape;
result.set_depth(no_);
if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
if (softmax_ != NULL) return softmax_->OutputShape(result);
return result;
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
Network::SetRandomizer(randomizer);
num_weights_ = 0;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
num_weights_ += gate_weights_[w].InitWeightsFloat(
ns_, na_ + 1, TestFlag(NF_ADA_GRAD), range, randomizer);
}
if (softmax_ != NULL) {
num_weights_ += softmax_->InitWeights(range, randomizer);
}
return num_weights_;
}
// Converts a float network to an int network.
void LSTM::ConvertToInt() {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].ConvertToInt();
}
if (softmax_ != NULL) {
softmax_->ConvertToInt();
}
}
// Provides debug output on the weights.
void LSTM::DebugWeights() {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
STRING msg = name_;
msg.add_str_int(" Gate weights ", w);
gate_weights_[w].Debug2D(msg.string());
}
if (softmax_ != NULL) {
softmax_->DebugWeights();
}
}
// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].Serialize(training_, fp)) return false;
}
if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool LSTM::DeSerialize(bool swap, TFile* fp) {
if (fp->FRead(&na_, sizeof(na_), 1) != 1) return false;
if (swap) ReverseN(&na_, sizeof(na_));
if (type_ == NT_LSTM_SOFTMAX) {
nf_ = no_;
} else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
nf_ = IntCastRounded(ceil(log2(no_)));
} else {
nf_ = 0;
}
is_2d_ = false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false;
if (w == CI) {
ns_ = gate_weights_[CI].NumOutputs();
is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
}
}
delete softmax_;
if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
softmax_ =
reinterpret_cast<FullyConnected*>(Network::CreateFromFile(swap, fp));
if (softmax_ == NULL) return false;
} else {
softmax_ = NULL;
}
return true;
}
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
input_map_ = input.stride_map();
input_width_ = input.Width();
if (softmax_ != NULL)
output->ResizeFloat(input, no_);
else if (type_ == NT_LSTM_SUMMARY)
output->ResizeXTo1(input, no_);
else
output->Resize(input, no_);
ResizeForward(input);
// Temporary storage of forward computation for each gate.
NetworkScratch::FloatVec temp_lines[WT_COUNT];
for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
// Single timestep buffers for the current/recurrent output and state.
NetworkScratch::FloatVec curr_state, curr_output;
curr_state.Init(ns_, scratch);
ZeroVector<double>(ns_, curr_state);
curr_output.Init(ns_, scratch);
ZeroVector<double>(ns_, curr_output);
// Rotating buffers of width buf_width allow storage of the state and output
// for the other dimension, used only when working in true 2D mode. The width
// is enough to hold an entire strip of the major direction.
int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
GenericVector<NetworkScratch::FloatVec> states, outputs;
if (Is2D()) {
states.init_to_size(buf_width, NetworkScratch::FloatVec());
outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
for (int i = 0; i < buf_width; ++i) {
states[i].Init(ns_, scratch);
ZeroVector<double>(ns_, states[i]);
outputs[i].Init(ns_, scratch);
ZeroVector<double>(ns_, outputs[i]);
}
}
// Used only if a softmax LSTM.
NetworkScratch::FloatVec softmax_output;
NetworkScratch::IO int_output;
if (softmax_ != NULL) {
softmax_output.Init(no_, scratch);
ZeroVector<double>(no_, softmax_output);
if (input.int_mode()) int_output.Resize2d(true, 1, ns_, scratch);
softmax_->SetupForward(input, NULL);
}
NetworkScratch::FloatVec curr_input;
curr_input.Init(na_, scratch);
StrideMap::Index src_index(input_map_);
// Used only by NT_LSTM_SUMMARY.
StrideMap::Index dest_index(output->stride_map());
do {
int t = src_index.t();
// True if there is a valid old state for the 2nd dimension.
bool valid_2d = Is2D();
if (valid_2d) {
StrideMap::Index dim_index(src_index);
if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
}
// Index of the 2-D revolving buffers (outputs, states).
int mod_t = Modulo(t, buf_width); // Current timestep.
// Setup the padded input in source.
source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
if (softmax_ != NULL) {
source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
}
source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
if (Is2D())
source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
// Matrix multiply the inputs with the source.
PARALLEL_IF_OPENMP(GFS)
// It looks inefficient to create the threads on each t iteration, but the
// alternative of putting the parallel outside the t loop, an omp single
// around the t-loop, and tasks in place of the sections is a *lot* slower.
// Cell inputs.
if (source_.int_mode())
gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
else
gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
FuncInplace<GFunc>(ns_, temp_lines[CI]);
SECTION_IF_OPENMP
// Input Gates.
if (source_.int_mode())
gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
else
gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
FuncInplace<FFunc>(ns_, temp_lines[GI]);
SECTION_IF_OPENMP
// 1-D forget gates.
if (source_.int_mode())
gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
else
gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
FuncInplace<FFunc>(ns_, temp_lines[GF1]);
// 2-D forget gates.
if (Is2D()) {
if (source_.int_mode())
gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
else
gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
FuncInplace<FFunc>(ns_, temp_lines[GFS]);
}
SECTION_IF_OPENMP
// Output gates.
if (source_.int_mode())
gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
else
gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
FuncInplace<FFunc>(ns_, temp_lines[GO]);
END_PARALLEL_IF_OPENMP
// Apply forget gate to state.
MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
if (Is2D()) {
// Max-pool the forget gates (in 2-d) instead of blindly adding.
inT8* which_fg_col = which_fg_[t];
memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
if (valid_2d) {
const double* stepped_state = states[mod_t];
for (int i = 0; i < ns_; ++i) {
if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
which_fg_col[i] = 2;
}
}
}
}
MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
// Clip curr_state to a sane range.
ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
if (training_) {
// Save the gate node values.
node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
}
FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
if (training_) state_.WriteTimeStep(t, curr_state);
if (softmax_ != NULL) {
if (input.int_mode()) {
int_output->WriteTimeStep(0, curr_output);
softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
} else {
softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
}
output->WriteTimeStep(t, softmax_output);
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
CodeInBinary(no_, nf_, softmax_output);
}
} else if (type_ == NT_LSTM_SUMMARY) {
// Output only at the end of a row.
if (src_index.IsLast(FD_WIDTH)) {
output->WriteTimeStep(dest_index.t(), curr_output);
dest_index.Increment();
}
} else {
output->WriteTimeStep(t, curr_output);
}
// Save states for use by the 2nd dimension only if needed.
if (Is2D()) {
CopyVector(ns_, curr_state, states[mod_t]);
CopyVector(ns_, curr_output, outputs[mod_t]);
}
// Always zero the states at the end of every row, but only for the major
// direction. The 2-D state remains intact.
if (src_index.IsLast(FD_WIDTH)) {
ZeroVector<double>(ns_, curr_state);
ZeroVector<double>(ns_, curr_output);
}
} while (src_index.Increment());
#if DEBUG_DETAIL > 0
tprintf("Source:%s\n", name_.string());
source_.Print(10);
tprintf("State:%s\n", name_.string());
state_.Print(10);
tprintf("Output:%s\n", name_.string());
output->Print(10);
#endif
if (debug) DisplayForward(*output);
}
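// In scalar form, the per-unit update computed above for the 1-D case is
// (a sketch; gate names follow the WeightType enum):
//   state[t]  = gf1[t] * state[t-1] + ci[t] * gi[t]   (then clipped)
//   output[t] = Tanh(state[t]) * go[t]
// In 2-D mode the larger of the two forget gates is applied to its own
// dimension's previous state (max-pooled) instead of summing both.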
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
if (debug) DisplayBackward(fwd_deltas);
back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
// ======Scratch space.======
// Output errors from deltas with recurrence from sourceerr.
NetworkScratch::FloatVec outputerr;
outputerr.Init(ns_, scratch);
// Recurrent error in the state/source.
NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
curr_stateerr.Init(ns_, scratch);
curr_sourceerr.Init(na_, scratch);
ZeroVector<double>(ns_, curr_stateerr);
ZeroVector<double>(na_, curr_sourceerr);
// Errors in the gates.
NetworkScratch::FloatVec gate_errors[WT_COUNT];
for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
// Rotating buffers of width buf_width allow storage of the recurrent time-
// steps used only for true 2-D. Stores one full strip of the major direction.
int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
if (Is2D()) {
stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
for (int t = 0; t < buf_width; ++t) {
stateerr[t].Init(ns_, scratch);
sourceerr[t].Init(na_, scratch);
ZeroVector<double>(ns_, stateerr[t]);
ZeroVector<double>(na_, sourceerr[t]);
}
}
// Parallel-generated sourceerr from each of the gates.
NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
for (int w = 0; w < WT_COUNT; ++w)
sourceerr_temps[w].Init(na_, scratch);
int width = input_width_;
// Transposed gate errors stored over all timesteps for sum outer.
NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
for (int w = 0; w < WT_COUNT; ++w) {
gate_errors_t[w].Init(ns_, width, scratch);
}
// Used only if softmax_ != NULL.
NetworkScratch::FloatVec softmax_errors;
NetworkScratch::GradientStore softmax_errors_t;
if (softmax_ != NULL) {
softmax_errors.Init(no_, scratch);
softmax_errors_t.Init(no_, width, scratch);
}
double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
tprintf("fwd_deltas:%s\n", name_.string());
fwd_deltas.Print(10);
#endif
StrideMap::Index dest_index(input_map_);
dest_index.InitToLast();
// Used only by NT_LSTM_SUMMARY.
StrideMap::Index src_index(fwd_deltas.stride_map());
src_index.InitToLast();
do {
int t = dest_index.t();
bool at_last_x = dest_index.IsLast(FD_WIDTH);
// up_pos is the 2-D back step and down_pos is the 2-D fwd step; each is only
// valid if >= 0, which holds only in 2-D mode away from the top/bottom edge.
int up_pos = -1;
int down_pos = -1;
if (Is2D()) {
if (dest_index.index(FD_HEIGHT) > 0) {
StrideMap::Index up_index(dest_index);
if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
}
if (!dest_index.IsLast(FD_HEIGHT)) {
StrideMap::Index down_index(dest_index);
if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
}
}
// Index of the 2-D revolving buffers (sourceerr, stateerr).
int mod_t = Modulo(t, buf_width); // Current timestep.
// Zero the state in the major direction only at the end of every row.
if (at_last_x) {
ZeroVector<double>(na_, curr_sourceerr);
ZeroVector<double>(ns_, curr_stateerr);
}
// Setup the outputerr.
if (type_ == NT_LSTM_SUMMARY) {
if (dest_index.IsLast(FD_WIDTH)) {
fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
src_index.Decrement();
} else {
ZeroVector<double>(ns_, outputerr);
}
} else if (softmax_ == NULL) {
fwd_deltas.ReadTimeStep(t, outputerr);
} else {
softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
softmax_errors_t.get(), outputerr);
}
if (!at_last_x)
AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
if (down_pos >= 0)
AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
// Apply the 1-d forget gates.
if (!at_last_x) {
const float* next_node_gf1 = node_values_[GF1].f(t + 1);
for (int i = 0; i < ns_; ++i) {
curr_stateerr[i] *= next_node_gf1[i];
}
}
if (Is2D() && t + 1 < width) {
for (int i = 0; i < ns_; ++i) {
if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
}
if (down_pos >= 0) {
const float* right_node_gfs = node_values_[GFS].f(down_pos);
const double* right_stateerr = stateerr[mod_t];
for (int i = 0; i < ns_; ++i) {
if (which_fg_[down_pos][i] == 2) {
curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
}
}
}
}
state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
curr_stateerr);
// Clip stateerr_ to a sane range.
ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
if (t + 10 > width) {
tprintf("t=%d, stateerr=", t);
for (int i = 0; i < ns_; ++i)
tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
curr_sourceerr[ni_ + nf_ + i]);
tprintf("\n");
}
#endif
// Matrix multiply to get the source errors.
PARALLEL_IF_OPENMP(GFS)
// Cell inputs.
node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
curr_stateerr, gate_errors[CI]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
SECTION_IF_OPENMP
// Input Gates.
node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
curr_stateerr, gate_errors[GI]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
SECTION_IF_OPENMP
// 1-D forget Gates.
if (t > 0) {
node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
gate_errors[GF1]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
sourceerr_temps[GF1]);
} else {
memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
}
gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
// 2-D forget Gates.
if (up_pos >= 0) {
node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
gate_errors[GFS]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
sourceerr_temps[GFS]);
} else {
memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
}
if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
SECTION_IF_OPENMP
// Output gates.
state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
gate_errors[GO]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
END_PARALLEL_IF_OPENMP
SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
curr_sourceerr);
back_deltas->WriteTimeStep(t, curr_sourceerr);
// Save states for use by the 2nd dimension only if needed.
if (Is2D()) {
CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
}
} while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
for (int w = 0; w < WT_COUNT; ++w) {
tprintf("%s gate errors[%d]\n", name_.string(), w);
gate_errors_t[w].get()->PrintUnTransposed(10);
}
#endif
// Transposed source_ used to speed-up SumOuter.
NetworkScratch::GradientStore source_t, state_t;
source_t.Init(na_, width, scratch);
source_.Transpose(source_t.get());
state_t.Init(ns_, width, scratch);
state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
}
if (softmax_ != NULL) {
softmax_->FinishBackward(*softmax_errors_t);
}
if (needs_to_backprop_) {
// Normalize the inputerr in back_deltas.
back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
return true;
}
return false;
}
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void LSTM::Update(float learning_rate, float momentum, int num_samples) {
#if DEBUG_DETAIL > 3
PrintW();
#endif
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].Update(learning_rate, momentum, num_samples);
}
if (softmax_ != NULL) {
softmax_->Update(learning_rate, momentum, num_samples);
}
#if DEBUG_DETAIL > 3
PrintDW();
#endif
}
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network& other, double* same,
double* changed) const {
ASSERT_HOST(other.type() == type_);
const LSTM* lstm = reinterpret_cast<const LSTM*>(&other);
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
}
if (softmax_ != NULL) {
softmax_->CountAlternators(*lstm->softmax_, same, changed);
}
}
// Prints the weights for debug purposes.
void LSTM::PrintW() {
tprintf("Weight state:%s\n", name_.string());
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
tprintf("Gate %d, inputs\n", w);
for (int i = 0; i < ni_; ++i) {
tprintf("Row %d:", i);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
tprintf("\n");
}
tprintf("Gate %d, outputs\n", w);
for (int i = ni_; i < ni_ + ns_; ++i) {
tprintf("Row %d:", i - ni_);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
tprintf("\n");
}
tprintf("Gate %d, bias\n", w);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
tprintf("\n");
}
}
// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
tprintf("Delta state:%s\n", name_.string());
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
tprintf("Gate %d, inputs\n", w);
for (int i = 0; i < ni_; ++i) {
tprintf("Row %d:", i);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetDW(s, i));
tprintf("\n");
}
tprintf("Gate %d, outputs\n", w);
for (int i = ni_; i < ni_ + ns_; ++i) {
tprintf("Row %d:", i - ni_);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetDW(s, i));
tprintf("\n");
}
tprintf("Gate %d, bias\n", w);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetDW(s, na_));
tprintf("\n");
}
}
// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO& input) {
source_.Resize(input, na_);
which_fg_.ResizeNoInit(input.Width(), ns_);
if (training_) {
state_.ResizeFloat(input, ns_);
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
node_values_[w].ResizeFloat(input, ns_);
}
}
}
} // namespace tesseract.

157
lstm/lstm.h Normal file

@ -0,0 +1,157 @@
///////////////////////////////////////////////////////////////////////
// File: lstm.h
// Description: Long Short-Term Memory recurrent neural network.
// Author: Ray Smith
// Created: Wed May 01 17:33:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_LSTM_H_
#define TESSERACT_LSTM_LSTM_H_
#include "network.h"
#include "fullyconnected.h"
namespace tesseract {
// C++ Implementation of the LSTM class from lstm.py.
class LSTM : public Network {
public:
// Enum for the different weights in LSTM, to reduce some of the I/O and
// setup code to loops. The elements of the enum correspond to elements of an
// array of WeightMatrix or a corresponding array of NetworkIO.
enum WeightType {
CI, // Cell Inputs.
GI, // Gate at the input.
GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
GO, // Gate at the output.
GFS, // Forget gate at the memory, looking back in the other dimension.
WT_COUNT // Number of WeightTypes.
};
// Constructor for NT_LSTM (regular 1 or 2-d LSTM), NT_LSTM_SOFTMAX (LSTM with
// additional softmax layer included and fed back into the input at the next
// timestep), or NT_LSTM_SOFTMAX_ENCODED (as LSTM_SOFTMAX, but the feedback
// is binary encoded instead of categorical) only.
// 2-d and bidi softmax LSTMs are not rejected, but are impossible to build
// in the conventional way, because feeding the output back both forwards
// and backwards in time is impossible.
LSTM(const STRING& name, int num_inputs, int num_states, int num_outputs,
bool two_dimensional, NetworkType type);
virtual ~LSTM();
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
virtual STRING spec() const {
STRING spec;
if (type_ == NT_LSTM)
spec.add_str_int("Lfx", ns_);
else if (type_ == NT_LSTM_SUMMARY)
spec.add_str_int("Lfxs", ns_);
else if (type_ == NT_LSTM_SOFTMAX)
spec.add_str_int("LS", ns_);
else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
spec.add_str_int("LE", ns_);
if (softmax_ != NULL) spec += softmax_->spec();
return spec;
}
// Sets up the network for training. Initializes the weights to random values
// of scale `range`, picked according to the random number generator
// `randomizer`.
virtual int InitWeights(float range, TRand* randomizer);
// Converts a float network to an int network.
virtual void ConvertToInt();
// Provides debug output on the weights.
virtual void DebugWeights();
// Writes to the given file. Returns false in case of error.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
virtual void Update(float learning_rate, float momentum, int num_samples);
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
virtual void CountAlternators(const Network& other, double* same,
double* changed) const;
// Prints the weights for debug purposes.
void PrintW();
// Prints the weight deltas for debug purposes.
void PrintDW();
// Returns true if this is a 2-d LSTM.
bool Is2D() const {
return is_2d_;
}
private:
// Resizes forward data to cope with an input image of the given width.
void ResizeForward(const NetworkIO& input);
private:
// Size of padded input to weight matrices = ni_ + no_ for 1-D operation
// and ni_ + 2 * no_ for 2-D operation. Note that there is a phantom 1 input
// for the bias that makes the weight matrices of size [na + 1][no].
inT32 na_;
// Number of internal states. Equal to no_ except for a softmax LSTM.
// ns_ is NOT serialized, but is calculated from gate_weights_.
inT32 ns_;
// Number of additional feedback states. The softmax types feed back
// additional output information on top of the ns_ internal states.
// In the case of a binary-coded (ENCODED) softmax, nf_ < no_.
inT32 nf_;
// Flag indicating 2-D operation.
bool is_2d_;
// Gate weight arrays of size [na + 1, no].
WeightMatrix gate_weights_[WT_COUNT];
// Used only if this is a softmax LSTM.
FullyConnected* softmax_;
// Input padded with previous output of size [width, na].
NetworkIO source_;
// Internal state used during forward operation, of size [width, ns].
NetworkIO state_;
// State of the 2-d maxpool, generated during forward, used during backward.
GENERIC_2D_ARRAY<inT8> which_fg_;
// Internal state saved from forward, but used only during backward.
NetworkIO node_values_[WT_COUNT];
// Preserved input stride_map used for Backward with NT_LSTM_SUMMARY.
StrideMap input_map_;
int input_width_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_LSTM_H_
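A minimal sketch of the per-timestep recurrence that the WeightType gates
above implement, for the 1-D case (the second-dimension GFS gate is ignored).
This is illustrative only: the function and buffer names are assumptions, and
the real implementation operates on NetworkIO/WeightMatrix rather than raw
arrays.

#include <cmath>

// One 1-D LSTM timestep over ns states. The gate_* buffers hold this
// timestep's weighted sums; state and output persist across timesteps.
static void ExampleLstmStep(const float* gate_ci, const float* gate_gi,
                            const float* gate_gf1, const float* gate_go,
                            float* state, float* output, int ns) {
  for (int s = 0; s < ns; ++s) {
    float ci = std::tanh(gate_ci[s]);                   // Cell input (CI).
    float gi = 1.0f / (1.0f + std::exp(-gate_gi[s]));   // Input gate (GI).
    float gf = 1.0f / (1.0f + std::exp(-gate_gf1[s]));  // Forget gate (GF1).
    float go = 1.0f / (1.0f + std::exp(-gate_go[s]));   // Output gate (GO).
    state[s] = state[s] * gf + ci * gi;                 // Memory cell update.
    output[s] = std::tanh(state[s]) * go;               // Gated output.
  }
}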

816
lstm/lstmrecognizer.cpp Normal file

@ -0,0 +1,816 @@
///////////////////////////////////////////////////////////////////////
// File: lstmrecognizer.cpp
// Description: Top-level line recognizer class for LSTM-based networks.
// Author: Ray Smith
// Created: Thu May 02 10:59:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "lstmrecognizer.h"
#include "allheaders.h"
#include "callcpp.h"
#include "dict.h"
#include "genericheap.h"
#include "helpers.h"
#include "imagedata.h"
#include "input.h"
#include "lstm.h"
#include "normalis.h"
#include "pageres.h"
#include "ratngs.h"
#include "recodebeam.h"
#include "scrollview.h"
#include "shapetable.h"
#include "statistc.h"
#include "tprintf.h"
namespace tesseract {
// Max number of blob choices to return in any given position.
const int kMaxChoices = 4;
// Default ratio between dict and non-dict words.
const double kDictRatio = 2.25;
// Default certainty offset to give the dictionary a chance.
const double kCertOffset = -0.085;
LSTMRecognizer::LSTMRecognizer()
: network_(NULL),
training_flags_(0),
training_iteration_(0),
sample_iteration_(0),
null_char_(UNICHAR_BROKEN),
weight_range_(0.0f),
learning_rate_(0.0f),
momentum_(0.0f),
dict_(NULL),
search_(NULL),
debug_win_(NULL) {}
LSTMRecognizer::~LSTMRecognizer() {
delete network_;
delete dict_;
delete search_;
}
// Writes to the given file. Returns false in case of error.
bool LSTMRecognizer::Serialize(TFile* fp) const {
if (!network_->Serialize(fp)) return false;
if (!GetUnicharset().save_to_file(fp)) return false;
if (!network_str_.Serialize(fp)) return false;
if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1)
return false;
if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1)
return false;
if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
return false;
if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false;
if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false;
if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false;
if (IsRecoding() && !recoder_.Serialize(fp)) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool LSTMRecognizer::DeSerialize(bool swap, TFile* fp) {
delete network_;
network_ = Network::CreateFromFile(swap, fp);
if (network_ == NULL) return false;
if (!ccutil_.unicharset.load_from_file(fp, false)) return false;
if (!network_str_.DeSerialize(swap, fp)) return false;
if (fp->FRead(&training_flags_, sizeof(training_flags_), 1) != 1)
return false;
if (fp->FRead(&training_iteration_, sizeof(training_iteration_), 1) != 1)
return false;
if (fp->FRead(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
return false;
if (fp->FRead(&null_char_, sizeof(null_char_), 1) != 1) return false;
if (fp->FRead(&weight_range_, sizeof(weight_range_), 1) != 1) return false;
if (fp->FRead(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
if (fp->FRead(&momentum_, sizeof(momentum_), 1) != 1) return false;
if (IsRecoding()) {
if (!recoder_.DeSerialize(swap, fp)) return false;
RecodedCharID code;
recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
if (code(0) != UNICHAR_SPACE) {
tprintf("Space was garbled in recoding!!\n");
return false;
}
}
// TODO(rays) swaps!
network_->SetRandomizer(&randomizer_);
network_->CacheXScaleFactor(network_->XScaleFactor());
return true;
}
// Loads the dictionary if possible from the traineddata file.
// If loading fails, prints a warning message and returns false, but
// otherwise continues to work silently without a dictionary.
// Note that dictionary load is independent of DeSerialize, but dependent
// on the unicharset matching. This enables training to deserialize a model
// from checkpoint or restore without having to go back and reload the
// dictionary.
bool LSTMRecognizer::LoadDictionary(const char* data_file_name,
const char* lang) {
delete dict_;
dict_ = new Dict(&ccutil_);
dict_->SetupForLoad(Dict::GlobalDawgCache());
dict_->LoadLSTM(data_file_name, lang);
if (dict_->FinishLoad()) return true; // Success.
tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n",
lang);
delete dict_;
dict_ = NULL;
return false;
}
// Recognizes the line image, contained within image_data, returning the
// ratings matrix and matching box_word for each WERD_RES in the output.
void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
bool debug, double worst_dict_cert,
bool use_alternates,
const UNICHARSET* target_unicharset,
const TBOX& line_box, float score_ratio,
bool one_word,
PointerVector<WERD_RES>* words) {
NetworkIO outputs;
float label_threshold = use_alternates ? 0.75f : 0.0f;
float scale_factor;
NetworkIO inputs;
if (!RecognizeLine(image_data, invert, debug, false, label_threshold,
&scale_factor, &inputs, &outputs))
return;
if (IsRecoding()) {
if (search_ == NULL) {
search_ =
new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
}
search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
&GetUnicharset(), words);
} else {
GenericVector<int> label_coords;
GenericVector<int> labels;
LabelsFromOutputs(outputs, label_threshold, &labels, &label_coords);
WordsFromOutputs(outputs, labels, label_coords, line_box, debug,
use_alternates, one_word, score_ratio, scale_factor,
target_unicharset, words);
}
}
// Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
// corresponding to the network output in outputs, labels, label_coords.
// one_word generates a single word output that may include spaces inside.
// use_alternates generates alternative BLOB_CHOICEs and segmentation paths.
// If not NULL, we attempt to translate the output to target_unicharset, but do
// not guarantee success, due to mismatches. In that case the output words are
// marked with our UNICHARSET, not the caller's.
void LSTMRecognizer::WordsFromOutputs(
const NetworkIO& outputs, const GenericVector<int>& labels,
const GenericVector<int>& label_coords, const TBOX& line_box, bool debug,
bool use_alternates, bool one_word, float score_ratio, float scale_factor,
const UNICHARSET* target_unicharset, PointerVector<WERD_RES>* words) {
// Convert labels to unichar-ids.
int word_end = 0;
float prev_space_cert = 0.0f;
for (int i = 0; i < labels.size(); i = word_end) {
word_end = i + 1;
if (labels[i] == null_char_ || labels[i] == UNICHAR_SPACE) {
continue;
}
float space_cert = 0.0f;
if (one_word) {
word_end = labels.size();
} else {
// Find the end of the word at the first null_char_ that leads to the
// first UNICHAR_SPACE.
while (word_end < labels.size() && labels[word_end] != UNICHAR_SPACE)
++word_end;
if (word_end < labels.size()) {
float rating;
outputs.ScoresOverRange(label_coords[word_end],
label_coords[word_end] + 1, UNICHAR_SPACE,
null_char_, &rating, &space_cert);
}
while (word_end > i && labels[word_end - 1] == null_char_) --word_end;
}
ASSERT_HOST(word_end > i);
// Create a WERD_RES for the output word.
if (debug)
tprintf("Creating word from outputs over [%d,%d)\n", i, word_end);
WERD_RES* word =
WordFromOutput(line_box, outputs, i, word_end, score_ratio,
MIN(prev_space_cert, space_cert), debug,
use_alternates && !SimpleTextOutput(), target_unicharset,
labels, label_coords, scale_factor);
if (word == NULL && target_unicharset != NULL) {
// Unicharset translation failed - fall back to our own unicharset, and
// disable the segmentation search on output, as it won't understand the
// encoding.
word = WordFromOutput(line_box, outputs, i, word_end, score_ratio,
MIN(prev_space_cert, space_cert), debug, false,
NULL, labels, label_coords, scale_factor);
}
prev_space_cert = space_cert;
words->push_back(word);
}
}
// Helper computes min and mean best results in the output.
void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
float* mean_output, float* sd) {
const int kOutputScale = MAX_INT8;
STATS stats(0, kOutputScale + 1);
for (int t = 0; t < outputs.Width(); ++t) {
int best_label = outputs.BestLabel(t, NULL);
if (best_label != null_char_ || t == 0) {
float best_output = outputs.f(t)[best_label];
stats.add(static_cast<int>(kOutputScale * best_output), 1);
}
}
*min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
*mean_output = stats.mean() / kOutputScale;
*sd = stats.sd() / kOutputScale;
}
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
// If label_threshold is positive, uses it for making the labels, otherwise
// uses standard ctc.
bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
bool debug, bool re_invert,
float label_threshold, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs) {
// Maximum width of image to train on.
const int kMaxImageWidth = 2048;
// This ensures consistent recognition results.
SetRandomSeed();
int min_width = network_->XScaleFactor();
Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width,
&randomizer_, scale_factor);
if (pix == NULL) {
tprintf("Line cannot be recognized!!\n");
return false;
}
if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) {
tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
pixGetHeight(pix));
pixDestroy(&pix);
return false;
}
// Reduction factor from image to coords.
*scale_factor = min_width / *scale_factor;
inputs->set_int_mode(IsIntMode());
SetRandomSeed();
Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs);
network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
// Check for auto inversion.
float pos_min, pos_mean, pos_sd;
OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
if (invert && pos_min < 0.5) {
// Run again inverted and see if it is any better.
float inv_scale;
NetworkIO inv_inputs, inv_outputs;
inv_inputs.set_int_mode(IsIntMode());
SetRandomSeed();
pixInvert(pix, pix);
Input::PreparePixInput(network_->InputShape(), pix, &randomizer_,
&inv_inputs);
network_->Forward(debug, inv_inputs, NULL, &scratch_space_, &inv_outputs);
float inv_min, inv_mean, inv_sd;
OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) {
// Inverted did better. Use inverted data.
if (debug) {
tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n",
pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd);
}
*outputs = inv_outputs;
*inputs = inv_inputs;
} else if (re_invert) {
// Inverting was not an improvement, so undo and run again, so the
// outputs match the best forward result.
SetRandomSeed();
network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
}
}
pixDestroy(&pix);
if (debug) {
GenericVector<int> labels, coords;
LabelsFromOutputs(*outputs, label_threshold, &labels, &coords);
DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
DebugActivationPath(*outputs, labels, coords);
}
return true;
}
// Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
// line_box should be the bounding box of the line image in the main image,
// outputs the output of the network,
// [word_start, word_end) the interval over which to convert,
// score_ratio for choosing alternate classifier choices,
// use_alternates to control generation of alternative segmentations,
// labels, label_coords, scale_factor from RecognizeLine above.
// If target_unicharset is not NULL, attempts to translate the internal
// unichar_ids to the target_unicharset, but falls back to untranslated ids
// if the translation fails.
WERD_RES* LSTMRecognizer::WordFromOutput(
const TBOX& line_box, const NetworkIO& outputs, int word_start,
int word_end, float score_ratio, float space_certainty, bool debug,
bool use_alternates, const UNICHARSET* target_unicharset,
const GenericVector<int>& labels, const GenericVector<int>& label_coords,
float scale_factor) {
WERD_RES* word_res = InitializeWord(
line_box, word_start, word_end, space_certainty, use_alternates,
target_unicharset, labels, label_coords, scale_factor);
int max_blob_run = word_res->ratings->bandwidth();
for (int width = 1; width <= max_blob_run; ++width) {
int col = 0;
for (int i = word_start; i + width <= word_end; ++i) {
if (labels[i] != null_char_) {
// Starting at i, use width labels, but stop at the next null_char_.
// This forms all combinations of blobs between regions of null_char_.
int j = i + 1;
while (j - i < width && labels[j] != null_char_) ++j;
if (j - i == width) {
// Make the blob choices.
int end_coord = label_coords[j];
if (j < word_end && labels[j] == null_char_)
end_coord = label_coords[j + 1];
BLOB_CHOICE_LIST* choices = GetBlobChoices(
col, col + width - 1, debug, outputs, target_unicharset,
label_coords[i], end_coord, score_ratio);
if (choices == NULL) {
delete word_res;
return NULL;
}
word_res->ratings->put(col, col + width - 1, choices);
}
++col;
}
}
}
if (use_alternates) {
// Merge adjacent single results over null_char boundaries.
int col = 0;
for (int i = word_start; i + 2 < word_end; ++i) {
if (labels[i] != null_char_ && labels[i + 1] == null_char_ &&
labels[i + 2] != null_char_ &&
(i == word_start || labels[i - 1] == null_char_) &&
(i + 3 == word_end || labels[i + 3] == null_char_)) {
int end_coord = label_coords[i + 3];
if (i + 3 < word_end && labels[i + 3] == null_char_)
end_coord = label_coords[i + 4];
BLOB_CHOICE_LIST* choices =
GetBlobChoices(col, col + 1, debug, outputs, target_unicharset,
label_coords[i], end_coord, score_ratio);
if (choices == NULL) {
delete word_res;
return NULL;
}
word_res->ratings->put(col, col + 1, choices);
}
if (labels[i] != null_char_) ++col;
}
} else {
word_res->FakeWordFromRatings(TOP_CHOICE_PERM);
}
return word_res;
}
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
WERD_RES* LSTMRecognizer::InitializeWord(const TBOX& line_box, int word_start,
int word_end, float space_certainty,
bool use_alternates,
const UNICHARSET* target_unicharset,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
float scale_factor) {
// Make a fake blob for each non-zero label.
C_BLOB_LIST blobs;
C_BLOB_IT b_it(&blobs);
// num_blobs is the length of the diagonal of the ratings matrix.
int num_blobs = 0;
// max_blob_run is the diagonal width of the ratings matrix.
int max_blob_run = 0;
int blob_run = 0;
for (int i = word_start; i < word_end; ++i) {
if (IsRecoding() && !recoder_.IsValidFirstCode(labels[i])) continue;
if (labels[i] != null_char_) {
// Make a fake blob.
TBOX box(label_coords[i], 0, label_coords[i + 1], line_box.height());
box.scale(scale_factor);
box.move(ICOORD(line_box.left(), line_box.bottom()));
box.set_top(line_box.top());
b_it.add_after_then_move(C_BLOB::FakeBlob(box));
++num_blobs;
++blob_run;
}
if (labels[i] == null_char_ || i + 1 == word_end) {
if (blob_run > max_blob_run)
max_blob_run = blob_run;
}
}
if (!use_alternates) max_blob_run = 1;
ASSERT_HOST(label_coords.size() >= word_end);
// Make a fake word from the blobs.
WERD* word = new WERD(&blobs, word_start > 1 ? 1 : 0, NULL);
// Make a WERD_RES from the word.
WERD_RES* word_res = new WERD_RES(word);
word_res->uch_set =
target_unicharset != NULL ? target_unicharset : &GetUnicharset();
word_res->combination = true; // Give it ownership of the word.
word_res->space_certainty = space_certainty;
word_res->ratings = new MATRIX(num_blobs, max_blob_run);
return word_res;
}
// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
STRING LSTMRecognizer::DecodeLabels(const GenericVector<int>& labels) {
STRING result;
int end = 1;
for (int start = 0; start < labels.size(); start = end) {
if (labels[start] == null_char_) {
end = start + 1;
} else {
result += DecodeLabel(labels, start, &end, NULL);
}
}
return result;
}
// Displays the forward results in a window with the characters and
// boundaries as determined by the labels and label_coords.
void LSTMRecognizer::DisplayForward(const NetworkIO& inputs,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
const char* window_name,
ScrollView** window) {
#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
int x_scale = network_->XScaleFactor();
Pix* input_pix = inputs.ToPix();
Network::ClearWindow(false, window_name, pixGetWidth(input_pix),
pixGetHeight(input_pix), window);
int line_height = Network::DisplayImage(input_pix, *window);
DisplayLSTMOutput(labels, label_coords, line_height, *window);
#endif // GRAPHICS_DISABLED
}
// Displays the labels and cuts at the corresponding xcoords.
// Size of labels should match xcoords.
void LSTMRecognizer::DisplayLSTMOutput(const GenericVector<int>& labels,
const GenericVector<int>& xcoords,
int height, ScrollView* window) {
#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
int x_scale = network_->XScaleFactor();
window->TextAttributes("Arial", height / 4, false, false, false);
int end = 1;
for (int start = 0; start < labels.size(); start = end) {
int xpos = xcoords[start] * x_scale;
if (labels[start] == null_char_) {
end = start + 1;
window->Pen(ScrollView::RED);
} else {
window->Pen(ScrollView::GREEN);
const char* str = DecodeLabel(labels, start, &end, NULL);
if (*str == '\\') str = "\\\\";
xpos = xcoords[(start + end) / 2] * x_scale;
window->Text(xpos, height, str);
}
window->Line(xpos, 0, xpos, height * 3 / 2);
}
window->Update();
#endif // GRAPHICS_DISABLED
}
// Prints debug output detailing the activation path that is implied by the
// label_coords.
void LSTMRecognizer::DebugActivationPath(const NetworkIO& outputs,
const GenericVector<int>& labels,
const GenericVector<int>& xcoords) {
if (xcoords[0] > 0)
DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
int end = 1;
for (int start = 0; start < labels.size(); start = end) {
if (labels[start] == null_char_) {
end = start + 1;
DebugActivationRange(outputs, "<null>", null_char_, xcoords[start],
xcoords[end]);
continue;
} else {
int decoded;
const char* label = DecodeLabel(labels, start, &end, &decoded);
DebugActivationRange(outputs, label, labels[start], xcoords[start],
xcoords[start + 1]);
for (int i = start + 1; i < end; ++i) {
DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i],
xcoords[i], xcoords[i + 1]);
}
}
}
}
// Prints debug output detailing activations and 2nd choice over a range
// of positions.
void LSTMRecognizer::DebugActivationRange(const NetworkIO& outputs,
const char* label, int best_choice,
int x_start, int x_end) {
tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
double max_score = 0.0;
double mean_score = 0.0;
int width = x_end - x_start;
for (int x = x_start; x < x_end; ++x) {
const float* line = outputs.f(x);
double score = line[best_choice] * 100.0;
if (score > max_score) max_score = score;
mean_score += score / width;
int best_c = 0;
double best_score = 0.0;
for (int c = 0; c < outputs.NumFeatures(); ++c) {
if (c != best_choice && line[c] > best_score) {
best_c = c;
best_score = line[c];
}
}
tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c,
best_score * 100.0);
}
tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
}
// Helper returns true if the null_char is the winner at t, and it beats the
// null_threshold, or the next choice is space, in which case we will use the
// null anyway.
static bool NullIsBest(const NetworkIO& output, float null_thr,
int null_char, int t) {
if (output.f(t)[null_char] >= null_thr) return true;
if (output.BestLabel(t, null_char, null_char, NULL) != UNICHAR_SPACE)
return false;
return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
}
// Converts the network output to a sequence of labels. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
if (SimpleTextOutput()) {
LabelsViaSimpleText(outputs, labels, xcoords);
} else if (IsRecoding()) {
LabelsViaReEncode(outputs, labels, xcoords);
} else if (null_thr <= 0.0) {
LabelsViaCTC(outputs, labels, xcoords);
} else {
LabelsViaThreshold(outputs, null_thr, labels, xcoords);
}
}
// Converts the network output to a sequence of labels, using a threshold
// on the null_char_ to determine character boundaries. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The label output is the one with the highest score in the interval between
// null_chars_.
void LSTMRecognizer::LabelsViaThreshold(const NetworkIO& output,
float null_thr,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
labels->truncate(0);
xcoords->truncate(0);
int width = output.Width();
int t = 0;
// Skip any initial non-char.
while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
++t;
}
while (t < width) {
ASSERT_HOST(!isnan(output.f(t)[null_char_]));
int label = output.BestLabel(t, null_char_, null_char_, NULL);
int char_start = t++;
while (t < width && !NullIsBest(output, null_thr, null_char_, t) &&
label == output.BestLabel(t, null_char_, null_char_, NULL)) {
++t;
}
int char_end = t;
labels->push_back(label);
xcoords->push_back(char_start);
// Find the end of the non-char, and compute its score.
while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
++t;
}
if (t > char_end) {
labels->push_back(null_char_);
xcoords->push_back(char_end);
}
}
xcoords->push_back(width);
}
// Converts the network output to a sequence of labels, with scores and
// start x-coords of the character labels. Retains the null_char_ as the
// end x-coord, where already present, otherwise the start of the next
// character is the end.
// The number of labels, scores, and xcoords is always matched, except that
// there is always an additional xcoord for the last end position.
void LSTMRecognizer::LabelsViaCTC(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
labels->truncate(0);
xcoords->truncate(0);
int width = output.Width();
int t = 0;
while (t < width) {
float score = 0.0f;
int label = output.BestLabel(t, &score);
labels->push_back(label);
xcoords->push_back(t);
while (++t < width && output.BestLabel(t, NULL) == label) {
}
}
xcoords->push_back(width);
}
// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for recoder_.
void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
if (search_ == NULL) {
search_ =
new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
}
search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, NULL);
search_->ExtractBestPathAsLabels(labels, xcoords);
}
// Converts the network output to a sequence of labels, with scores, using
// the simple character model (each position is a char, and the null_char_ is
// mainly intended for tail padding.)
void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords) {
labels->truncate(0);
xcoords->truncate(0);
int width = output.Width();
for (int t = 0; t < width; ++t) {
float score = 0.0f;
int label = output.BestLabel(t, &score);
if (label != null_char_) {
labels->push_back(label);
xcoords->push_back(t);
}
}
xcoords->push_back(width);
}
// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
// Handles either LSTM labels or direct unichar-ids.
// Score ratio determines the worst ratio between top choice and remainder.
// If target_unicharset is not NULL, attempts to translate to the target
// unicharset, returning NULL on failure.
BLOB_CHOICE_LIST* LSTMRecognizer::GetBlobChoices(
int col, int row, bool debug, const NetworkIO& output,
const UNICHARSET* target_unicharset, int x_start, int x_end,
float score_ratio) {
float rating = 0.0f, certainty = 0.0f;
int label = output.BestChoiceOverRange(x_start, x_end, UNICHAR_SPACE,
null_char_, &rating, &certainty);
int unichar_id = label == null_char_ ? UNICHAR_SPACE : label;
if (debug) {
tprintf("Best choice over range %d,%d=unichar%d=%s r = %g, cert=%g\n",
x_start, x_end, unichar_id, DecodeSingleLabel(label), rating,
certainty);
}
BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
BLOB_CHOICE_IT bc_it(choices);
if (!AddBlobChoices(unichar_id, rating, certainty, col, row,
target_unicharset, &bc_it)) {
delete choices;
return NULL;
}
// Get the other choices.
double best_cert = certainty;
for (int c = 0; c < output.NumFeatures(); ++c) {
if (c == label || c == UNICHAR_SPACE || c == null_char_) continue;
// Compute the score over the range.
output.ScoresOverRange(x_start, x_end, c, null_char_, &rating, &certainty);
int unichar_id = c == null_char_ ? UNICHAR_SPACE : c;
if (certainty >= best_cert - score_ratio &&
!AddBlobChoices(unichar_id, rating, certainty, col, row,
target_unicharset, &bc_it)) {
delete choices;
return NULL;
}
}
choices->sort(&BLOB_CHOICE::SortByRating);
if (bc_it.length() > kMaxChoices) {
bc_it.move_to_first();
for (int i = 0; i < kMaxChoices; ++i)
bc_it.forward();
while (!bc_it.at_first()) {
delete bc_it.extract();
bc_it.forward();
}
}
return choices;
}
// Adds to the given iterator, the blob choices for the target_unicharset
// that correspond to the given LSTM unichar_id.
// Returns false if unicharset translation failed.
bool LSTMRecognizer::AddBlobChoices(int unichar_id, float rating,
float certainty, int col, int row,
const UNICHARSET* target_unicharset,
BLOB_CHOICE_IT* bc_it) {
int target_id = unichar_id;
if (target_unicharset != NULL) {
const char* utf8 = GetUnicharset().id_to_unichar(unichar_id);
if (target_unicharset->contains_unichar(utf8)) {
target_id = target_unicharset->unichar_to_id(utf8);
} else {
return false;
}
}
BLOB_CHOICE* choice = new BLOB_CHOICE(target_id, rating, certainty, -1, 1.0f,
static_cast<float>(MAX_INT16), 0.0f,
BCC_STATIC_CLASSIFIER);
choice->set_matrix_cell(col, row);
bc_it->add_after_then_move(choice);
return true;
}
// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels,
int start, int* end, int* decoded) {
*end = start + 1;
if (IsRecoding()) {
// Decode labels via recoder_.
RecodedCharID code;
if (labels[start] == null_char_) {
if (decoded != NULL) {
code.Set(0, null_char_);
*decoded = recoder_.DecodeUnichar(code);
}
return "<null>";
}
int index = start;
while (index < labels.size() &&
code.length() < RecodedCharID::kMaxCodeLen) {
code.Set(code.length(), labels[index++]);
while (index < labels.size() && labels[index] == null_char_) ++index;
int uni_id = recoder_.DecodeUnichar(code);
// If the next label isn't a valid first code, then we need to continue
// extending even if we have a valid uni_id from this prefix.
if (uni_id != INVALID_UNICHAR_ID &&
(index == labels.size() ||
code.length() == RecodedCharID::kMaxCodeLen ||
recoder_.IsValidFirstCode(labels[index]))) {
*end = index;
if (decoded != NULL) *decoded = uni_id;
if (uni_id == UNICHAR_SPACE) return " ";
return GetUnicharset().get_normed_unichar(uni_id);
}
}
return "<Undecodable>";
} else {
if (decoded != NULL) *decoded = labels[start];
if (labels[start] == null_char_) return "<null>";
if (labels[start] == UNICHAR_SPACE) return " ";
return GetUnicharset().get_normed_unichar(labels[start]);
}
}
// Returns a string corresponding to a given single label id, falling back to
// a default of ".." for part of a multi-label unichar-id.
const char* LSTMRecognizer::DecodeSingleLabel(int label) {
if (label == null_char_) return "<null>";
if (IsRecoding()) {
// Decode label via recoder_.
RecodedCharID code;
code.Set(0, label);
label = recoder_.DecodeUnichar(code);
if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code.
}
if (label == UNICHAR_SPACE) return " ";
return GetUnicharset().get_normed_unichar(label);
}
} // namespace tesseract.
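A minimal usage sketch for the RecognizeLine entry point defined above.
Everything except the LSTMRecognizer calls is assumed scaffolding: the
deserialized model file, the line ImageData and its bounding box come from
the caller, and the certainty/score values are arbitrary placeholders.

#include "lstmrecognizer.h"

// Recognizes one pre-extracted text line with a deserialized model.
bool ExampleRecognizeLine(tesseract::TFile* model_fp,
                          const tesseract::ImageData& line_data,
                          const TBOX& line_box,
                          tesseract::PointerVector<WERD_RES>* words) {
  tesseract::LSTMRecognizer recognizer;
  if (!recognizer.DeSerialize(false, model_fp)) return false;  // Load model.
  // invert=true retries white-on-black lines; no alternates, no target
  // unicharset, one_word=false so words split on recognized spaces.
  recognizer.RecognizeLine(line_data, true, false, -5.0, false, NULL,
                           line_box, 1.5f, false, words);
  return true;
}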

392
lstm/lstmrecognizer.h Normal file

@ -0,0 +1,392 @@
///////////////////////////////////////////////////////////////////////
// File: lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author: Ray Smith
// Created: Thu May 02 08:57:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_
#include "ccutil.h"
#include "helpers.h"
#include "imagedata.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "recodebeam.h"
#include "series.h"
#include "strngs.h"
#include "unicharcompress.h"
class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;
namespace tesseract {
class Dict;
class ImageData;
// Enum indicating training mode control flags.
enum TrainingFlags {
TF_INT_MODE = 1,
TF_AUTO_HARDEN = 2,
TF_ROUND_ROBIN_TRAINING = 16,
TF_COMPRESS_UNICHARSET = 64,
};
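// Illustrative note, not part of this commit: the flags form a bitmask, so a
// recognizer running in integer mode on a compressed unicharset carries both
// bits, which IsIntMode() and IsRecoding() below test.
static const inT32 kExampleTrainingFlags = TF_INT_MODE | TF_COMPRESS_UNICHARSET;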
// Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer is used for training.
class LSTMRecognizer {
public:
LSTMRecognizer();
~LSTMRecognizer();
int NumOutputs() const {
return network_->NumOutputs();
}
int training_iteration() const {
return training_iteration_;
}
int sample_iteration() const {
return sample_iteration_;
}
double learning_rate() const {
return learning_rate_;
}
bool IsHardening() const {
return (training_flags_ & TF_AUTO_HARDEN) != 0;
}
LossType OutputLossType() const {
if (network_ == nullptr) return LT_NONE;
StaticShape shape;
shape = network_->OutputShape(shape);
return shape.loss_type();
}
bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
// True if recoder_ is active to re-encode text to a smaller space.
bool IsRecoding() const {
return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
}
// Returns the cache strategy for the DocumentCache.
CachingStrategy CacheStrategy() const {
return training_flags_ & TF_ROUND_ROBIN_TRAINING ? CS_ROUND_ROBIN
: CS_SEQUENTIAL;
}
// Returns true if the network is a TensorFlow network.
bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
// Returns a vector of layer ids that can be passed to other layer functions
// to access a specific layer.
GenericVector<STRING> EnumerateLayers() const {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
Series* series = reinterpret_cast<Series*>(network_);
GenericVector<STRING> layers;
series->EnumerateLayers(NULL, &layers);
return layers;
}
// Returns a specific layer from its id (from EnumerateLayers).
Network* GetLayer(const STRING& id) const {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
ASSERT_HOST(id.length() > 1 && id[0] == ':');
Series* series = reinterpret_cast<Series*>(network_);
return series->GetLayer(&id[1]);
}
// Returns the learning rate of the layer from its id.
float GetLayerLearningRate(const STRING& id) const {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
ASSERT_HOST(id.length() > 1 && id[0] == ':');
Series* series = reinterpret_cast<Series*>(network_);
return series->LayerLearningRate(&id[1]);
} else {
return learning_rate_;
}
}
// Multiplies all the learning rates by the given factor.
void ScaleLearningRate(double factor) {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
learning_rate_ *= factor;
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
GenericVector<STRING> layers = EnumerateLayers();
for (int i = 0; i < layers.size(); ++i) {
ScaleLayerLearningRate(layers[i], factor);
}
}
}
// Multiplies the learning rate of the layer with id, by the given factor.
void ScaleLayerLearningRate(const STRING& id, double factor) {
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
ASSERT_HOST(id.length() > 1 && id[0] == ':');
Series* series = reinterpret_cast<Series*>(network_);
series->ScaleLayerLearningRate(&id[1], factor);
}
// True if the network is using adagrad to train.
bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); }
// Provides access to the UNICHARSET that this classifier works with.
const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
// Sets the sample iteration to the given value. The sample_iteration_
// determines the seed for the random number generator. The training
// iteration is incremented only by a successful training iteration.
void SetIteration(int iteration) {
sample_iteration_ = iteration;
}
// Accessors for textline image normalization.
int NumInputs() const {
return network_->NumInputs();
}
int null_char() const { return null_char_; }
// Writes to the given file. Returns false in case of error.
bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool swap, TFile* fp);
// Loads the dictionary if possible from the traineddata file.
// If loading fails, prints a warning message and returns false, but
// otherwise continues to work silently without a dictionary.
// Note that dictionary load is independent of DeSerialize, but dependent
// on the unicharset matching. This enables training to deserialize a model
// from checkpoint or restore without having to go back and reload the
// dictionary.
bool LoadDictionary(const char* data_file_name, const char* lang);
// Recognizes the line image, contained within image_data, returning the
// ratings matrix and matching box_word for each WERD_RES in the output.
// If invert, tries inverted as well if the normal interpretation doesn't
// produce a good enough result. If use_alternates, the ratings matrix is
// filled with segmentation and classifier alternatives that may be searched
// using the standard beam search, otherwise, just a diagonal and prebuilt
// best_choice. The line_box is used for computing the box_word in the
// output words. Score_ratio is used to determine the classifier alternates.
// If one_word, then a single WERD_RES is formed, regardless of the spaces
// found during recognition.
// If not NULL, we attempt to translate the output to target_unicharset, but
// do not guarantee success, due to mismatches. In that case the output words
// are marked with our UNICHARSET, not the caller's.
void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
double worst_dict_cert, bool use_alternates,
const UNICHARSET* target_unicharset, const TBOX& line_box,
float score_ratio, bool one_word,
PointerVector<WERD_RES>* words);
// Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
// corresponding to the network output in outputs, labels, label_coords.
// one_word generates a single word output, that may include spaces inside.
// use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
// with cut-offs determined by scale_factor.
// If not NULL, we attempt to translate the output to target_unicharset, but
// do not guarantee success, due to mismatches. In that case the output words
// are marked with our UNICHARSET, not the caller's.
void WordsFromOutputs(const NetworkIO& outputs,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
const TBOX& line_box, bool debug, bool use_alternates,
bool one_word, float score_ratio, float scale_factor,
const UNICHARSET* target_unicharset,
PointerVector<WERD_RES>* words);
// Helper computes min and mean best results in the output.
void OutputStats(const NetworkIO& outputs,
float* min_output, float* mean_output, float* sd);
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
// If label_threshold is positive, uses it for making the labels, otherwise
// uses standard ctc. Returned in scale_factor is the reduction factor
// between the image and the output coords, for computing bounding boxes.
// If re_invert is true, the input is inverted back to its original
// photometric interpretation if inversion is attempted but fails to
// improve the results. This ensures that outputs contains the correct
// forward outputs for the best photometric interpretation.
// inputs is filled with the used inputs to the network, and if not null,
// target boxes is filled with scaled truth boxes if present in image_data.
bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
bool re_invert, float label_threshold, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs);
// Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
// line_box should be the bounding box of the line image in the main image,
// outputs the output of the network,
// [word_start, word_end) the interval over which to convert,
// score_ratio for choosing alternate classifier choices,
// use_alternates to control generation of alternative segmentations,
// labels, label_coords, scale_factor from RecognizeLine above.
// If target_unicharset is not NULL, attempts to translate the internal
// unichar_ids to the target_unicharset, but falls back to untranslated ids
// if the translation fails.
WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
int word_start, int word_end, float score_ratio,
float space_certainty, bool debug,
bool use_alternates,
const UNICHARSET* target_unicharset,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
float scale_factor);
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
float space_certainty, bool use_alternates,
const UNICHARSET* target_unicharset,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
float scale_factor);
// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
STRING DecodeLabels(const GenericVector<int>& labels);
// Displays the forward results in a window with the characters and
// boundaries as determined by the labels and label_coords.
void DisplayForward(const NetworkIO& inputs,
const GenericVector<int>& labels,
const GenericVector<int>& label_coords,
const char* window_name,
ScrollView** window);
protected:
// Sets the random seed from the sample_iteration_.
void SetRandomSeed() {
inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
randomizer_.set_seed(seed);
randomizer_.IntRand();
}
// Displays the labels and cuts at the corresponding xcoords.
// Size of labels should match xcoords.
void DisplayLSTMOutput(const GenericVector<int>& labels,
const GenericVector<int>& xcoords,
int height, ScrollView* window);
// Prints debug output detailing the activation path that is implied by the
// xcoords.
void DebugActivationPath(const NetworkIO& outputs,
const GenericVector<int>& labels,
const GenericVector<int>& xcoords);
// Prints debug output detailing activations and 2nd choice over a range
// of positions.
void DebugActivationRange(const NetworkIO& outputs, const char* label,
int best_choice, int x_start, int x_end);
// Converts the network output to a sequence of labels. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Converts the network output to a sequence of labels, using a threshold
// on the null_char_ to determine character boundaries. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The label output is the one with the highest score in the interval between
// null_chars_.
void LabelsViaThreshold(const NetworkIO& output,
float null_threshold,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Converts the network output to a sequence of labels, with scores and
// start x-coords of the character labels. Retains the null_char_ character as
// the end x-coord, where already present, otherwise the start of the next
// character is the end.
// The number of labels, scores, and xcoords is always matched, except that
// there is always an additional xcoord for the last end position.
void LabelsViaCTC(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for recoder_.
void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Converts the network output to a sequence of labels, with scores, using
// the simple character model (each position is a char, and the null_char_ is
// mainly intended for tail padding.)
void LabelsViaSimpleText(const NetworkIO& output,
GenericVector<int>* labels,
GenericVector<int>* xcoords);
// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
// Handles either LSTM labels or direct unichar-ids.
// Score ratio determines the worst ratio between top choice and remainder.
// If target_unicharset is not NULL, attempts to translate to the target
// unicharset, returning NULL on failure.
BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
const NetworkIO& output,
const UNICHARSET* target_unicharset,
int x_start, int x_end, float score_ratio);
// Adds to the given iterator, the blob choices for the target_unicharset
// that correspond to the given LSTM unichar_id.
// Returns false if unicharset translation failed.
bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
int row, const UNICHARSET* target_unicharset,
BLOB_CHOICE_IT* bc_it);
// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
int* decoded);
// Returns a string corresponding to a given single label id, falling back to
// a default of ".." for part of a multi-label unichar-id.
const char* DecodeSingleLabel(int label);
protected:
// The network hierarchy.
Network* network_;
// The unicharset. Only the unicharset element is serialized.
// Has to be a CCUtil, so Dict can point to it.
CCUtil ccutil_;
// For backward compatibility, recoder_ is serialized iff
// training_flags_ & TF_COMPRESS_UNICHARSET.
// Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
UnicharCompress recoder_;
// ==Training parameters that are serialized to provide a record of them.==
STRING network_str_;
// Flags used to determine the training method of the network.
// See enum TrainingFlags above.
inT32 training_flags_;
// Number of actual backward training steps used.
inT32 training_iteration_;
// Index into training sample set. sample_iteration_ >= training_iteration_.
inT32 sample_iteration_;
// Index in softmax of null character. May take the value UNICHAR_BROKEN or
// ccutil_.unicharset.size().
inT32 null_char_;
// Range used for the initial random numbers in the weights.
float weight_range_;
// Learning rate and momentum multipliers of deltas in backprop.
float learning_rate_;
float momentum_;
// === NOT SERIALIZED.
TRand randomizer_;
NetworkScratch scratch_space_;
// Language model (optional) to use with the beam search.
Dict* dict_;
// Beam search held between uses to optimize memory allocation/use.
RecodeBeamSearch* search_;
// == Debugging parameters.==
// Recognition debug display window.
ScrollView* debug_win_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
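A sketch of the layer accessors declared above: enumerating the layers of a
series network and scaling each layer-specific learning rate. Assumes the
network was built as NT_SERIES with NF_LAYER_SPECIFIC_LR set; the function
name is hypothetical.

#include "lstmrecognizer.h"

// Halves every layer-specific learning rate of a series network.
void ExampleHalveLayerRates(tesseract::LSTMRecognizer* recognizer) {
  GenericVector<STRING> layers = recognizer->EnumerateLayers();
  for (int i = 0; i < layers.size(); ++i) {
    // Ids look like ":0"; the same id keys all the layer accessors.
    float rate = recognizer->GetLayerLearningRate(layers[i]);
    if (rate > 0.0f) recognizer->ScaleLayerLearningRate(layers[i], 0.5);
  }
}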

1331
lstm/lstmtrainer.cpp Normal file

File diff suppressed because it is too large
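Since the lstmtrainer.cpp diff is suppressed, a brief sketch of one of its
callback contracts instead: a memory-backed CheckPointWriter matching the
typedef in lstmtrainer.h below. It assumes LSTMTrainer exposes
Serialize(TFile*) like its LSTMRecognizer base; wrapping the function in a
TessResultCallback3 is left to the caller.

#include "lstmtrainer.h"

// Serializes the trainer state into the caller's memory buffer.
bool ExampleWriteCheckpoint(tesseract::SerializeAmount serialize_amount,
                            const tesseract::LSTMTrainer* trainer,
                            GenericVector<char>* data) {
  trainer->SetSerializeMode(serialize_amount);  // Choose how much to save.
  tesseract::TFile fp;
  fp.OpenWrite(data);  // Subsequent writes go into the GenericVector.
  return trainer->Serialize(&fp);
}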

477
lstm/lstmtrainer.h Normal file

@ -0,0 +1,477 @@
///////////////////////////////////////////////////////////////////////
// File: lstmtrainer.h
// Description: Top-level line trainer class for LSTM-based networks.
// Author: Ray Smith
// Created: Fri May 03 09:07:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
#define TESSERACT_LSTM_LSTMTRAINER_H_
#include "imagedata.h"
#include "lstmrecognizer.h"
#include "rect.h"
#include "tesscallback.h"
namespace tesseract {
class LSTM;
class LSTMTrainer;
class Parallel;
class Reversed;
class Softmax;
class Series;
// Enum for the types of errors that are counted.
enum ErrorTypes {
ET_RMS, // RMS activation error.
ET_DELTA, // Number of big errors in deltas.
ET_WORD_RECERR, // Output text string word recall error.
ET_CHAR_ERROR, // Output text string total char error.
ET_SKIP_RATIO, // Fraction of samples skipped.
ET_COUNT // For array sizing.
};
// Enum for the trainability_ flags.
enum Trainability {
TRAINABLE, // Non-zero delta error.
PERFECT, // Zero delta error.
UNENCODABLE, // Not trainable due to coding/alignment trouble.
HI_PRECISION_ERR, // Hi confidence disagreement.
NOT_BOXED, // Early in training and has no character boxes.
};
// Enum to define the amount of data to get serialized.
enum SerializeAmount {
LIGHT, // Minimal data for remote training.
NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
FULL, // All data including best_trainer_.
};
// Enum to indicate how the sub_trainer_ training went.
enum SubTrainerResult {
STR_NONE, // Did nothing as not good enough.
STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
STR_REPLACED // Subtrainer replaced *this.
};
class LSTMTrainer;
// Function to restore the trainer state from a given checkpoint.
// Returns false on failure.
typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>*
CheckPointReader;
// Function to save a checkpoint of the current trainer state.
// Returns false on failure. SerializeAmount determines the amount of the
// trainer to serialize, typically used for saving the best state.
typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
GenericVector<char>*>* CheckPointWriter;
// Function to compute and record error rates on some external test set(s).
// Args are: iteration, mean errors, model, training stage.
// Returns a STRING containing logging information about the tests.
typedef TessResultCallback4<STRING, int, const double*,
const GenericVector<char>&, int>* TestCallback;
// Trainer class for LSTM networks. Most of the effort is in creating the
// ideal target outputs from the transcription. A box file is used if it is
// available, otherwise estimates of the char widths from the unicharset are
// used to guide a DP search for the best fit to the transcription.
class LSTMTrainer : public LSTMRecognizer {
public:
LSTMTrainer();
// Callbacks may be null, in which case defaults are used.
LSTMTrainer(FileReader file_reader, FileWriter file_writer,
CheckPointReader checkpoint_reader,
CheckPointWriter checkpoint_writer,
const char* model_base, const char* checkpoint_name,
int debug_interval, inT64 max_memory);
virtual ~LSTMTrainer();
// Tries to deserialize a trainer from the given file and silently returns
// false in case of failure.
bool TryLoadingCheckpoint(const char* filename);
// Initializes the character set encode/decode mechanism.
// train_flags control training behavior according to the TrainingFlags
// enum, including character set encoding.
// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
// fully initializes the unicharset from the universal unicharsets.
// Note: Call before InitNetwork!
void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir,
int train_flags);
// Initializes the character set encode/decode mechanism directly from a
// previously setup UNICHARSET and UnicharCompress.
// ctc_mode controls how the truth text is mapped to the network targets.
// Note: Call before InitNetwork!
void InitCharSet(const UNICHARSET& unicharset, const UnicharCompress& recoder);
// Initializes the trainer with a network_spec in the network description
// language. net_flags control network behavior according to the NetworkFlags
// enum. There isn't really much difference between net_flags and train_flags
// - only where the effects are implemented.
// For other args see NetworkBuilder::InitNetwork.
// Note: Be sure to call InitCharSet before InitNetwork!
bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
float weight_range, float learning_rate, float momentum);
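// Typical call order (a sketch only; the VGSL spec string and hyperparameters
// below are illustrative assumptions, not fixed values):
//   trainer.InitCharSet(my_unicharset, STRING("langdata"),
//                       TF_COMPRESS_UNICHARSET);
//   trainer.InitNetwork(STRING("[1,36,0,1 Lfx96 O1c105]"), -1, 0,
//                       0.1f, 1e-3f, 0.9f);  // weight_range, lr, momentum.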
// Initializes a trainer from a serialized TFNetworkModel proto.
// Returns the global step of TensorFlow graph or 0 if failed.
// Building a compatible TF graph: See tfnetwork.proto.
int InitTensorFlowNetwork(const std::string& tf_proto);
// Accessors.
double ActivationError() const {
return error_rates_[ET_DELTA];
}
double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
const double* error_rates() const {
return error_rates_;
}
double best_error_rate() const {
return best_error_rate_;
}
int best_iteration() const {
return best_iteration_;
}
int learning_iteration() const { return learning_iteration_; }
int improvement_steps() const { return improvement_steps_; }
void set_perfect_delay(int delay) { perfect_delay_ = delay; }
const GenericVector<char>& best_trainer() const { return best_trainer_; }
// Returns the error that was just calculated by PrepareForBackward.
double NewSingleError(ErrorTypes type) const {
return error_buffers_[type][training_iteration() % kRollingBufferSize_];
}
// Returns the error that was just calculated by TrainOnLine. Since
// TrainOnLine rolls the error buffers, this is one further back than
// NewSingleError.
double LastSingleError(ErrorTypes type) const {
return error_buffers_[type]
[(training_iteration() + kRollingBufferSize_ - 1) %
kRollingBufferSize_];
}
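// Worked example: with kRollingBufferSize_ == 1000 and training_iteration()
// == 1002, NewSingleError reads slot 2 while LastSingleError reads slot 1,
// i.e. the entry written one roll earlier.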
const DocumentCache& training_data() const {
return training_data_;
}
DocumentCache* mutable_training_data() { return &training_data_; }
// If the training sample is usable, grid searches for the optimal
// dict_ratio/cert_offset, and returns the results in a string of space-
// separated triplets of ratio,offset=worderr.
Trainability GridSearchDictParams(
const ImageData* trainingdata, int iteration, double min_dict_ratio,
double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
double cert_offset_step, double max_cert_offset, STRING* results);
void SetSerializeMode(SerializeAmount serialize_amount) const {
serialize_amount_ = serialize_amount;
}
// Provides output on the distribution of weight values.
void DebugNetwork();
// Loads into memory a set of lstmf files (created by running tesseract with
// the lstm.train config), ready for training. Returns false if nothing was
// loaded.
bool LoadAllTrainingData(const GenericVector<STRING>& filenames);
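// Sketch (the file name is hypothetical):
//   GenericVector<STRING> filenames;
//   filenames.push_back(STRING("train/eng.page1.lstmf"));
//   if (!trainer.LoadAllTrainingData(filenames)) tprintf("No data loaded!\n");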
// Keeps track of best and locally worst error rate, using internally computed
// values. See MaintainCheckpointsSpecific for more detail.
bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
// Keeps track of best and locally worst error_rate (whatever it is) and
// launches tests using rec_model, when a new min or max is reached.
// Writes checkpoints using train_model at appropriate times and builds and
// returns a log message to indicate progress. Returns false if nothing
// interesting happened.
bool MaintainCheckpointsSpecific(int iteration,
const GenericVector<char>* train_model,
const GenericVector<char>* rec_model,
TestCallback tester, STRING* log_msg);
// Builds a string containing a progress message with current error rates.
void PrepareLogMsg(STRING* log_msg) const;
// Appends <intro_str> iteration learning_iteration()/training_iteration()/
// sample_iteration() to the log_msg.
void LogIterations(const char* intro_str, STRING* log_msg) const;
// TODO(rays) Add curriculum learning.
// Returns true and increments the training_stage_ if the error rate has just
// passed through the given threshold for the first time.
bool TransitionTrainingStage(float error_threshold);
// Returns the current training stage.
int CurrentTrainingStage() const { return training_stage_; }
// Writes to the given file. Returns false in case of error.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
// learning rates (by scaling reduction, or layer specific, according to
// NF_LAYER_SPECIFIC_LR).
void StartSubtrainer(STRING* log_msg);
// While the sub_trainer_ is behind the current training iteration and its
// training error is at least kSubTrainerMarginFraction better than the
// current training error, trains the sub_trainer_, and returns STR_UPDATED if
// it did anything. If it catches up, and has a better error rate than the
// current best, as well as a margin over the current error rate, then the
// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
// returned. STR_NONE is returned if the subtrainer wasn't good enough to
// receive any training iterations.
SubTrainerResult UpdateSubtrainer(STRING* log_msg);
// Reduces network learning rates, either for everything, or for layers
// independently, according to NF_LAYER_SPECIFIC_LR.
void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
// Considers reducing the learning rate independently for each layer down by
// factor(<1), or leaving it the same, by double-training the given number of
// samples and minimizing the amount of changing of sign of weight updates.
// Even if it looks like all weights should remain the same, an adjustment
// will be made to guarantee a different result when reverting to an old best.
// Returns the number of layer learning rates that were reduced.
int ReduceLayerLearningRates(double factor, int num_samples,
LSTMTrainer* samples_trainer);
// Converts the string to integer class labels, with appropriate null_char_s
// in between if not in SimpleTextOutput mode. Returns false on failure.
bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : NULL,
SimpleTextOutput(), null_char_, labels);
}
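// E.g. (sketch):
//   GenericVector<int> labels;
//   if (trainer.EncodeString(STRING("hello"), &labels)) {
//     // labels now holds class ids, with null_char_ separators unless in
//     // SimpleTextOutput mode, as described above.
//   }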
// Static version operates on supplied unicharset, encoder, simple_text.
static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
const UnicharCompress* recoder, bool simple_text,
int null_char, GenericVector<int>* labels);
// Converts the network to int if not already.
void ConvertToInt() {
if ((training_flags_ & TF_INT_MODE) == 0) {
network_->ConvertToInt();
training_flags_ |= TF_INT_MODE;
}
}
// Performs forward-backward on the given trainingdata.
// Returns the sample that was used or NULL if the next sample was deemed
// unusable. samples_trainer could be this or an alternative trainer that
// holds the training samples.
const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
int sample_index = sample_iteration();
const ImageData* image =
samples_trainer->training_data_.GetPageBySerial(sample_index);
if (image != NULL) {
Trainability trainable = TrainOnLine(image, batch);
if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
return NULL; // Sample was unusable.
}
} else {
++sample_iteration_;
}
return image;
}
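// A bare-bones training loop built from the pieces above (a sketch; the
// stopping condition and the NULL tester are assumptions):
//   while (trainer.learning_iteration() < max_iterations) {
//     trainer.TrainOnLine(&trainer, false);
//     STRING log_msg;
//     if (trainer.MaintainCheckpoints(NULL, &log_msg))
//       tprintf("%s\n", log_msg.string());
//   }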
Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
// Prepares the ground truth, runs forward, and prepares the targets.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability PrepareForBackward(const ImageData* trainingdata,
NetworkIO* fwd_outputs, NetworkIO* targets);
// Writes the trainer to memory, so that the current training state can be
// restored.
bool SaveTrainingDump(SerializeAmount serialize_amount,
const LSTMTrainer* trainer,
GenericVector<char>* data) const;
// Reads previously saved trainer from memory.
bool ReadTrainingDump(const GenericVector<char>& data, LSTMTrainer* trainer);
bool ReadSizedTrainingDump(const char* data, int size);
// Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
void SetupCheckpointInfo();
// Writes the recognizer to memory, so that it can be used for testing later.
void SaveRecognitionDump(GenericVector<char>* data) const;
// Reads and returns a previously saved recognizer from memory.
static LSTMRecognizer* ReadRecognitionDump(const GenericVector<char>& data);
// Writes current best model to a file, unless it has already been written.
bool SaveBestModel(FileWriter writer) const;
// Returns a suitable filename for a training dump, based on the model_base_,
// the iteration and the error rates.
STRING DumpFilename() const;
// Fills the whole error buffer of the given type with the given value.
void FillErrorBuffer(double new_error, ErrorTypes type);
protected:
// Factored sub-constructor sets up reasonable default values.
void EmptyConstructor();
// Sets the unicharset properties using the given script_dir as a source of
// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
// up the recoder_ to simplify the unicharset.
void SetUnicharsetProperties(const STRING& script_dir);
// Outputs the string and periodically displays the given network inputs
// as an image in the given window, and the corresponding labels at the
// corresponding x_starts.
// Returns false if the truth string is empty.
bool DebugLSTMTraining(const NetworkIO& inputs,
const ImageData& trainingdata,
const NetworkIO& fwd_outputs,
const GenericVector<int>& truth_labels,
const NetworkIO& outputs);
// Displays the network targets as a line graph.
void DisplayTargets(const NetworkIO& targets, const char* window_name,
ScrollView** window);
// Builds a no-compromises target where the first positions should be the
// truth labels and the rest is padded with the null_char_.
bool ComputeTextTargets(const NetworkIO& outputs,
const GenericVector<int>& truth_labels,
NetworkIO* targets);
// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
NetworkIO* outputs, NetworkIO* targets);
// Computes network errors, and stores the results in the rolling buffers,
// along with the supplied text_error.
// Returns the delta error of the current sample (not running average.)
double ComputeErrorRates(const NetworkIO& deltas, double char_error,
double word_error);
// Computes the network activation RMS error rate.
double ComputeRMSError(const NetworkIO& deltas);
// Computes network activation winner error rate. (Number of values that are
// in error by >= 0.5 divided by number of time-steps.) More closely related
// to final character error than RMS, but still directly calculable from
// just the deltas. Because of the binary nature of the targets, zero winner
// error is a sufficient but not necessary condition for zero char error.
double ComputeWinnerError(const NetworkIO& deltas);
// Computes a very simple bag of chars char error rate.
double ComputeCharError(const GenericVector<int>& truth_str,
const GenericVector<int>& ocr_str);
// Computes a very simple bag of words word recall error rate.
// NOTE that this is destructive on both input strings.
double ComputeWordError(STRING* truth_str, STRING* ocr_str);
// Updates the error buffer and corresponding mean of the given type with
// the new_error.
void UpdateErrorBuffer(double new_error, ErrorTypes type);
// Rolls error buffers and reports the current means.
void RollErrorBuffers();
// Given that error_rate is either a new min or max, updates the best/worst
// error rates, and record of progress.
STRING UpdateErrorGraph(int iteration, double error_rate,
const GenericVector<char>& model_data,
TestCallback tester);
protected:
// Alignment display window.
ScrollView* align_win_;
// CTC target display window.
ScrollView* target_win_;
// CTC output display window.
ScrollView* ctc_win_;
// Reconstructed image window.
ScrollView* recon_win_;
// How often to display a debug image.
int debug_interval_;
// Iteration at which the last checkpoint was dumped.
int checkpoint_iteration_;
// Basename of files to save best models to.
STRING model_base_;
// Checkpoint filename.
STRING checkpoint_name_;
// Training data.
DocumentCache training_data_;
// A hack to serialize less data for batch training and record file version.
mutable SerializeAmount serialize_amount_;
// Name to use when saving best_trainer_.
STRING best_model_name_;
// Number of available training stages.
int num_training_stages_;
// Checkpointing callbacks.
FileReader file_reader_;
FileWriter file_writer_;
// TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
// when we can commit to c++11.
CheckPointReader checkpoint_reader_;
CheckPointWriter checkpoint_writer_;
// ===Serialized data to ensure that a restart produces the same results.===
// These members are only serialized when serialize_amount_ != LIGHT.
// Best error rate so far.
double best_error_rate_;
// Snapshot of all error rates at best_iteration_.
double best_error_rates_[ET_COUNT];
// Iteration of best_error_rate_.
int best_iteration_;
// Worst error rate since best_error_rate_.
double worst_error_rate_;
// Snapshot of all error rates at worst_iteration_.
double worst_error_rates_[ET_COUNT];
// Iteration of worst_error_rate_.
int worst_iteration_;
// Iteration at which the training process will be considered stalled.
int stall_iteration_;
// Saved recognition models for computing test error for graph points.
GenericVector<char> best_model_data_;
GenericVector<char> worst_model_data_;
// Saved trainer for reverting back to last known best.
GenericVector<char> best_trainer_;
// A subsidiary trainer running with a different learning rate until either
// *this or sub_trainer_ hits a new best.
LSTMTrainer* sub_trainer_;
// Error rate at which last best model was dumped.
float error_rate_of_last_saved_best_;
// Current stage of training.
int training_stage_;
// History of best error rate against iteration. Used for computing the
// number of steps to each 2% improvement.
GenericVector<double> best_error_history_;
GenericVector<int> best_error_iterations_;
// Number of iterations since the best_error_rate_ was 2% more than it is now.
int improvement_steps_;
// Number of iterations that yielded a non-zero delta error and thus provided
// significant learning. learning_iteration_ <= training_iteration_.
// learning_iteration_ is used to measure rate of learning progress.
int learning_iteration_;
// Saved value of sample_iteration_ before looking for the next sample.
int prev_sample_iteration_;
// How often to include a PERFECT training sample in backprop.
// A PERFECT training sample is used if the current
// training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
// so with perfect_delay_ == 0, all samples are used, and with
// perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
int perfect_delay_;
// Value of training_iteration_ at which the last PERFECT training sample
// was used in back prop.
int last_perfect_training_iteration_;
// Rolling buffers storing recent training errors are indexed by
// training_iteration % kRollingBufferSize_.
static const int kRollingBufferSize_ = 1000;
GenericVector<double> error_buffers_[ET_COUNT];
// Rounded mean percent trailing training errors in the buffers.
double error_rates_[ET_COUNT]; // RMS training error.
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_LSTMTRAINER_H_

87
lstm/maxpool.cpp Normal file
View File

@@ -0,0 +1,87 @@
///////////////////////////////////////////////////////////////////////
// File: maxpool.cpp
// Description: Standard Max-Pooling layer.
// Author: Ray Smith
// Created: Tue Mar 18 16:28:18 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "maxpool.h"
#include "tprintf.h"
namespace tesseract {
Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale)
: Reconfig(name, ni, x_scale, y_scale) {
type_ = NT_MAXPOOL;
no_ = ni;
}
Maxpool::~Maxpool() {
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Maxpool::DeSerialize(bool swap, TFile* fp) {
bool result = Reconfig::DeSerialize(swap, fp);
no_ = ni_;
return result;
}
// Runs forward propagation of activations on the input line.
// See Network (network.h) for a detailed discussion of the arguments.
void Maxpool::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
output->ResizeScaled(input, x_scale_, y_scale_, no_);
maxes_.ResizeNoInit(output->Width(), ni_);
back_map_ = input.stride_map();
StrideMap::Index dest_index(output->stride_map());
do {
int out_t = dest_index.t();
StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
dest_index.index(FD_HEIGHT) * y_scale_,
dest_index.index(FD_WIDTH) * x_scale_);
// Find the max input out of x_scale_ groups of y_scale_ inputs.
// Do it independently for each input dimension.
int* max_line = maxes_[out_t];
int in_t = src_index.t();
output->CopyTimeStepFrom(out_t, input, in_t);
for (int i = 0; i < ni_; ++i) {
max_line[i] = in_t;
}
for (int x = 0; x < x_scale_; ++x) {
for (int y = 0; y < y_scale_; ++y) {
StrideMap::Index src_xy(src_index);
if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line);
}
}
}
} while (dest_index.Increment());
}
// Runs backward propagation of errors on the deltas line.
// See Network (network.h) for a detailed discussion of the arguments.
bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
back_deltas->MaxpoolBackward(fwd_deltas, maxes_);
return true;
}
} // namespace tesseract.

71
lstm/maxpool.h Normal file
View File

@@ -0,0 +1,71 @@
///////////////////////////////////////////////////////////////////////
// File: maxpool.h
// Description: Standard Max-Pooling layer.
// Author: Ray Smith
// Created: Tue Mar 18 16:28:18 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_MAXPOOL_H_
#define TESSERACT_LSTM_MAXPOOL_H_
#include "reconfig.h"
namespace tesseract {
// Maxpooling reduction. Independently for each input, selects the location
// in the rectangle that contains the max value.
// Backprop propagates only to the position that was the max.
class Maxpool : public Reconfig {
public:
Maxpool(const STRING& name, int ni, int x_scale, int y_scale);
virtual ~Maxpool();
// Accessors.
virtual STRING spec() const {
STRING spec;
spec.add_str_int("Mp", y_scale_);
spec.add_str_int(",", x_scale_);
return spec;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
private:
// Memory of which input was the max.
GENERIC_2D_ARRAY<int> maxes_;
};
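// Usage note (a sketch, tied to the parser in networkbuilder.cpp): in the
// VGSL description language, "Mp3,3" is parsed by NetworkBuilder::ParseM into
// Maxpool("Maxpool", input_depth, 3, 3), reducing each 3x3 tile to its
// per-channel max and shrinking both x and y by a factor of 3.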
} // namespace tesseract.
#endif // TESSERACT_LSTM_MAXPOOL_H_

309
lstm/network.cpp Normal file
View File

@@ -0,0 +1,309 @@
///////////////////////////////////////////////////////////////////////
// File: network.cpp
// Description: Base class for neural network implementations.
// Author: Ray Smith
// Created: Wed May 01 17:25:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "network.h"
#include <stdlib.h>
// This base class needs to know about all its sub-classes because of the
// factory deserializing method: CreateFromFile.
#include "allheaders.h"
#include "convolve.h"
#include "fullyconnected.h"
#include "input.h"
#include "lstm.h"
#include "maxpool.h"
#include "parallel.h"
#include "reconfig.h"
#include "reversed.h"
#include "scrollview.h"
#include "series.h"
#include "statistc.h"
#ifdef INCLUDE_TENSORFLOW
#include "tfnetwork.h"
#endif
#include "tprintf.h"
namespace tesseract {
// Min and max window sizes.
const int kMinWinSize = 500;
const int kMaxWinSize = 2000;
// Window frame sizes need adding on to make the content fit.
const int kXWinFrameSize = 30;
const int kYWinFrameSize = 80;
// String names corresponding to the NetworkType enum. Keep in sync.
// Names used in Serialization to allow re-ordering/addition/deletion of
// layer types in NetworkType without invalidating existing network files.
char const* const Network::kTypeNames[NT_COUNT] = {
"Invalid", "Input",
"Convolve", "Maxpool",
"Parallel", "Replicated",
"ParBidiLSTM", "DepParUDLSTM",
"Par2dLSTM", "Series",
"Reconfig", "RTLReversed",
"TTBReversed", "XYTranspose",
"LSTM", "SummLSTM",
"Logistic", "LinLogistic",
"LinTanh", "Tanh",
"Relu", "Linear",
"Softmax", "SoftmaxNoCTC",
"LSTMSoftmax", "LSTMBinarySoftmax",
"TensorFlow",
};
Network::Network()
: type_(NT_NONE), training_(true), needs_to_backprop_(true),
network_flags_(0), ni_(0), no_(0), num_weights_(0),
forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
}
Network::Network(NetworkType type, const STRING& name, int ni, int no)
: type_(type), training_(true), needs_to_backprop_(true),
network_flags_(0), ni_(ni), no_(no), num_weights_(0),
name_(name), forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
}
Network::~Network() {
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void Network::SetEnableTraining(bool state) {
training_ = state;
}
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
void Network::SetNetworkFlags(uinT32 flags) {
network_flags_ = flags;
}
// Sets up the network for training. Initializes weights to random values of
// scale `range` drawn from the random number generator `randomizer`.
int Network::InitWeights(float range, TRand* randomizer) {
randomizer_ = randomizer;
return 0;
}
// Provides a pointer to a TRand for any networks that care to use it.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
void Network::SetRandomizer(TRand* randomizer) {
randomizer_ = randomizer;
}
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
bool Network::SetupNeedsBackprop(bool needs_backprop) {
needs_to_backprop_ = needs_backprop;
return needs_backprop || num_weights_ > 0;
}
// Writes to the given file. Returns false in case of error.
bool Network::Serialize(TFile* fp) const {
inT8 data = NT_NONE;
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
STRING type_name = kTypeNames[type_];
if (!type_name.Serialize(fp)) return false;
data = training_;
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
data = needs_to_backprop_;
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false;
if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false;
if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
if (!name_.Serialize(fp)) return false;
return true;
}
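// The resulting layout, in write order: an inT8 NT_NONE marker, the type name
// STRING (so layer types can be reordered without breaking old files), inT8
// training_, inT8 needs_to_backprop_, then inT32 network_flags_, ni_, no_,
// num_weights_, and finally the name_ STRING.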
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
bool Network::DeSerialize(bool swap, TFile* fp) {
inT8 data = 0;
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
if (data == NT_NONE) {
STRING type_name;
if (!type_name.DeSerialize(swap, fp)) return false;
for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
}
if (data == NT_COUNT) {
tprintf("Invalid network layer type:%s\n", type_name.string());
return false;
}
}
type_ = static_cast<NetworkType>(data);
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
training_ = data != 0;
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
needs_to_backprop_ = data != 0;
if (fp->FRead(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
if (fp->FRead(&ni_, sizeof(ni_), 1) != 1) return false;
if (fp->FRead(&no_, sizeof(no_), 1) != 1) return false;
if (fp->FRead(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
if (!name_.DeSerialize(swap, fp)) return false;
if (swap) {
ReverseN(&network_flags_, sizeof(network_flags_));
ReverseN(&ni_, sizeof(ni_));
ReverseN(&no_, sizeof(no_));
ReverseN(&num_weights_, sizeof(num_weights_));
}
return true;
}
// Reads from the given file. Returns NULL in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Determines the type of the serialized class and calls its DeSerialize
// on a new object of the appropriate type, which is returned.
Network* Network::CreateFromFile(bool swap, TFile* fp) {
Network stub;
if (!stub.DeSerialize(swap, fp)) return NULL;
Network* network = NULL;
switch (stub.type_) {
case NT_CONVOLVE:
network = new Convolve(stub.name_, stub.ni_, 0, 0);
break;
case NT_INPUT:
network = new Input(stub.name_, stub.ni_, stub.no_);
break;
case NT_LSTM:
case NT_LSTM_SOFTMAX:
case NT_LSTM_SOFTMAX_ENCODED:
case NT_LSTM_SUMMARY:
network =
new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
break;
case NT_MAXPOOL:
network = new Maxpool(stub.name_, stub.ni_, 0, 0);
break;
// All variants of Parallel.
case NT_PARALLEL:
case NT_REPLICATED:
case NT_PAR_RL_LSTM:
case NT_PAR_UD_LSTM:
case NT_PAR_2D_LSTM:
network = new Parallel(stub.name_, stub.type_);
break;
case NT_RECONFIG:
network = new Reconfig(stub.name_, stub.ni_, 0, 0);
break;
// All variants of reversed.
case NT_XREVERSED:
case NT_YREVERSED:
case NT_XYTRANSPOSE:
network = new Reversed(stub.name_, stub.type_);
break;
case NT_SERIES:
network = new Series(stub.name_);
break;
case NT_TENSORFLOW:
#ifdef INCLUDE_TENSORFLOW
network = new TFNetwork(stub.name_);
#else
tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
return NULL;
#endif
break;
// All variants of FullyConnected.
case NT_SOFTMAX:
case NT_SOFTMAX_NO_CTC:
case NT_RELU:
case NT_TANH:
case NT_LINEAR:
case NT_LOGISTIC:
case NT_POSCLIP:
case NT_SYMCLIP:
network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
break;
default:
return NULL;
}
network->training_ = stub.training_;
network->needs_to_backprop_ = stub.needs_to_backprop_;
network->network_flags_ = stub.network_flags_;
network->num_weights_ = stub.num_weights_;
if (!network->DeSerialize(swap, fp)) {
delete network;
return NULL;
}
return network;
}
// Returns a random number in [-range, range].
double Network::Random(double range) {
ASSERT_HOST(randomizer_ != NULL);
return randomizer_->SignedRand(range);
}
#ifndef GRAPHICS_DISABLED
// === Debug image display methods. ===
// Displays the image of the matrix to the forward window.
void Network::DisplayForward(const NetworkIO& matrix) {
Pix* image = matrix.ToPix();
ClearWindow(false, name_.string(), pixGetWidth(image),
pixGetHeight(image), &forward_win_);
DisplayImage(image, forward_win_);
forward_win_->Update();
}
// Displays the image of the matrix to the backward window.
void Network::DisplayBackward(const NetworkIO& matrix) {
Pix* image = matrix.ToPix();
STRING window_name = name_ + "-back";
ClearWindow(false, window_name.string(), pixGetWidth(image),
pixGetHeight(image), &backward_win_);
DisplayImage(image, backward_win_);
backward_win_->Update();
}
// Creates the window if needed, otherwise clears it.
void Network::ClearWindow(bool tess_coords, const char* window_name,
int width, int height, ScrollView** window) {
if (*window == NULL) {
int min_size = MIN(width, height);
if (min_size < kMinWinSize) {
if (min_size < 1) min_size = 1;
width = width * kMinWinSize / min_size;
height = height * kMinWinSize / min_size;
}
width += kXWinFrameSize;
height += kYWinFrameSize;
if (width > kMaxWinSize) width = kMaxWinSize;
if (height > kMaxWinSize) height = kMaxWinSize;
*window = new ScrollView(window_name, 80, 100, width, height, width, height,
tess_coords);
tprintf("Created window %s of size %d, %d\n", window_name, width, height);
} else {
(*window)->Clear();
}
}
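// Sizing example: a 100x50 content area has min_size 50 < kMinWinSize, so it
// scales to 1000x500; adding the frame sizes yields a 1030x580 window, well
// under the kMaxWinSize clip of 2000.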
// Displays the pix in the given window, and returns the height of the pix.
// The pix is pixDestroyed.
int Network::DisplayImage(Pix* pix, ScrollView* window) {
int height = pixGetHeight(pix);
window->Image(pix, 0, 0);
pixDestroy(&pix);
return height;
}
#endif // GRAPHICS_DISABLED
} // namespace tesseract.

292
lstm/network.h Normal file
View File

@@ -0,0 +1,292 @@
///////////////////////////////////////////////////////////////////////
// File: network.h
// Description: Base class for neural network implementations.
// Author: Ray Smith
// Created: Wed May 01 16:38:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_NETWORK_H_
#define TESSERACT_LSTM_NETWORK_H_
#include <stdio.h>
#include <cmath>
#include "genericvector.h"
#include "helpers.h"
#include "matrix.h"
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "tprintf.h"
struct Pix;
class ScrollView;
class TBOX;
namespace tesseract {
class ImageData;
class NetworkScratch;
// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
NT_NONE, // The naked base class.
NT_INPUT, // Inputs from an image.
// Plumbing networks combine other networks or rearrange the inputs.
NT_CONVOLVE, // Duplicates inputs in a sliding window neighborhood.
NT_MAXPOOL, // Chooses the max result from a rectangle.
NT_PARALLEL, // Runs networks in parallel.
NT_REPLICATED, // Runs identical networks in parallel.
NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel.
NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel.
NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel.
NT_SERIES, // Executes a sequence of layers.
NT_RECONFIG, // Scales the time/y size but makes the output deeper.
NT_XREVERSED, // Reverses the x direction of the inputs/outputs.
NT_YREVERSED, // Reverses the y-direction of the inputs/outputs.
NT_XYTRANSPOSE, // Transposes x and y (for just a single op).
// Functional networks actually calculate stuff.
NT_LSTM, // Long-Short-Term-Memory block.
NT_LSTM_SUMMARY, // LSTM that only keeps its last output.
NT_LOGISTIC, // Fully connected logistic nonlinearity.
NT_POSCLIP, // Fully connected rect lin version of logistic.
NT_SYMCLIP, // Fully connected rect lin version of tanh.
NT_TANH, // Fully connected with tanh nonlinearity.
NT_RELU, // Fully connected with rectifier nonlinearity.
NT_LINEAR, // Fully connected with no nonlinearity.
NT_SOFTMAX, // Softmax uses exponential normalization, with CTC.
NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC.
// The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with
// the outputs fed back to the input of the LSTM at the next timestep.
// The ENCODED version binary encodes the softmax outputs, providing log2 of
// the number of outputs as additional inputs, and the other version just
// provides all the softmax outputs as additional inputs.
NT_LSTM_SOFTMAX, // 1-d LSTM with built-in fully connected softmax.
NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax.
// A TensorFlow graph encapsulated as a Tesseract network.
NT_TENSORFLOW,
NT_COUNT // Array size.
};
// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
// Network forward/backprop behavior.
NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer.
NF_ADA_GRAD = 128, // Weight-specific learning rate.
};
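// E.g. (sketch): enable per-layer learning rates and adagrad together with
//   network->SetNetworkFlags(NF_LAYER_SPECIFIC_LR | NF_ADA_GRAD);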
// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class Network {
public:
Network();
Network(NetworkType type, const STRING& name, int ni, int no);
virtual ~Network();
// Accessors.
NetworkType type() const {
return type_;
}
bool training() const {
return training_;
}
bool needs_to_backprop() const {
return needs_to_backprop_;
}
int num_weights() const { return num_weights_; }
int NumInputs() const {
return ni_;
}
int NumOutputs() const {
return no_;
}
// Returns the required shape input to the network.
virtual StaticShape InputShape() const {
StaticShape result;
return result;
}
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const {
StaticShape result(input_shape);
result.set_depth(no_);
return result;
}
const STRING& name() const {
return name_;
}
virtual STRING spec() const {
return "?";
}
bool TestFlag(NetworkFlags flag) const {
return (network_flags_ & flag) != 0;
}
// Initialization and administrative functions that are mostly provided
// by Plumbing.
// Returns true if the given type is derived from Plumbing, and thus contains
// multiple sub-networks that can have their own learning rate.
virtual bool IsPlumbingType() const { return false; }
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
virtual void SetEnableTraining(bool state);
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
virtual void SetNetworkFlags(uinT32 flags);
// Sets up the network for training. Initializes weights to random values of
// scale `range` drawn from the random number generator `randomizer`.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
// Returns the number of weights initialized.
virtual int InitWeights(float range, TRand* randomizer);
// Converts a float network to an int network.
virtual void ConvertToInt() {}
// Provides a pointer to a TRand for any networks that care to use it.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
virtual void SetRandomizer(TRand* randomizer);
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
virtual bool SetupNeedsBackprop(bool needs_backprop);
// Returns the most recent reduction factor that the network applied to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data and calculating result bounding boxes.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
virtual int XScaleFactor() const {
return 1;
}
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
virtual void CacheXScaleFactor(int factor) {}
// Provides debug output on the weights.
virtual void DebugWeights() {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}
// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
virtual bool DeSerialize(bool swap, TFile* fp);
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
virtual void Update(float learning_rate, float momentum, int num_samples) {}
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
virtual void CountAlternators(const Network& other, double* same,
double* changed) const {}
// Reads from the given file. Returns NULL in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Determines the type of the serialized class and calls its DeSerialize
// on a new object of the appropriate type, which is returned.
static Network* CreateFromFile(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// Note that input and output are both 2-d arrays.
// The 1st index is the time element. In a 1-d network, it might be the pixel
// position on the textline. In a 2-d network, the linearization is defined
// by the stride_map. (See networkio.h).
// The 2nd index of input is the network inputs/outputs, and the dimension
// of the input must match NumInputs() of this network.
// The output array will be resized as needed so that its 1st dimension is
// always equal to the number of output values, and its second dimension is
// always NumOutputs(). Note that all this detail is encapsulated away inside
// NetworkIO, as are the internals of the scratch memory space used by the
// network. See networkscratch.h for that.
// If input_transpose is not NULL, then it contains the transpose of input,
// and the caller guarantees that it will still be valid on the next call to
// backward. The callee is therefore at liberty to save the pointer and
// reference it on a call to backward. This is a bit ugly, but it makes it
// possible for a replicating parallel to calculate the input transpose once
// instead of all the replicated networks having to do it.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
tprintf("Must override Network::Forward for type %d\n", type_);
}
// Runs backward propagation of errors on fwd_deltas.
// Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
// Returns false if back_deltas was not set, due to there being no point in
// propagating further backwards. Thus most complete networks will always
// return false from Backward!
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
tprintf("Must override Network::Backward for type %d\n", type_);
return false;
}
// === Debug image display methods. ===
// Displays the image of the matrix to the forward window.
void DisplayForward(const NetworkIO& matrix);
// Displays the image of the matrix to the backward window.
void DisplayBackward(const NetworkIO& matrix);
// Creates the window if needed, otherwise clears it.
static void ClearWindow(bool tess_coords, const char* window_name,
int width, int height, ScrollView** window);
// Displays the pix in the given window, and returns the height of the pix.
// The pix is pixDestroyed.
static int DisplayImage(Pix* pix, ScrollView* window);
protected:
// Returns a random number in [-range, range].
double Random(double range);
protected:
NetworkType type_; // Type of the derived network class.
bool training_; // Are we currently training?
bool needs_to_backprop_; // This network needs to output back_deltas.
inT32 network_flags_; // Behavior control flags in NetworkFlags.
inT32 ni_; // Number of input values.
inT32 no_; // Number of output values.
inT32 num_weights_; // Number of weights in this and sub-network.
STRING name_; // A unique name for this layer.
// NOT-serialized debug data.
ScrollView* forward_win_; // Recognition debug display window.
ScrollView* backward_win_; // Training debug display window.
TRand* randomizer_; // Random number generator.
// Static serialized name/type_ mapping. Keep in sync with NetworkType.
static char const* const kTypeNames[NT_COUNT];
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_NETWORK_H_

488
lstm/networkbuilder.cpp Normal file
View File

@@ -0,0 +1,488 @@
///////////////////////////////////////////////////////////////////////
// File: networkbuilder.cpp
// Description: Class to parse the network description language and
// build a corresponding network.
// Author: Ray Smith
// Created: Wed Jul 16 18:35:38 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "networkbuilder.h"
#include "convolve.h"
#include "fullyconnected.h"
#include "input.h"
#include "lstm.h"
#include "maxpool.h"
#include "network.h"
#include "parallel.h"
#include "reconfig.h"
#include "reversed.h"
#include "series.h"
#include "unicharset.h"
namespace tesseract {
// Builds a network with a network_spec in the network description
// language, to recognize a character set of num_outputs size.
// If append_index is non-negative, then *network must be non-null and the
// given network_spec will be appended to *network AFTER append_index, with
// the top of the input *network discarded.
// Note that network_spec is call by value to allow a non-const char* pointer
// into the string for BuildFromString.
// net_flags control network behavior according to the NetworkFlags enum.
// The resulting network is returned via **network.
// Returns false if something failed.
bool NetworkBuilder::InitNetwork(int num_outputs, STRING network_spec,
int append_index, int net_flags,
float weight_range, TRand* randomizer,
Network** network) {
NetworkBuilder builder(num_outputs);
Series* bottom_series = NULL;
StaticShape input_shape;
if (append_index >= 0) {
// Split the current network after the given append_index.
ASSERT_HOST(*network != NULL && (*network)->type() == NT_SERIES);
Series* series = reinterpret_cast<Series*>(*network);
Series* top_series = NULL;
series->SplitAt(append_index, &bottom_series, &top_series);
if (bottom_series == NULL || top_series == NULL) {
tprintf("Yikes! Splitting current network failed!!\n");
return false;
}
input_shape = bottom_series->OutputShape(input_shape);
delete top_series;
}
char* str_ptr = &network_spec[0];
*network = builder.BuildFromString(input_shape, &str_ptr);
if (*network == NULL) return false;
(*network)->SetNetworkFlags(net_flags);
(*network)->InitWeights(weight_range, randomizer);
(*network)->SetupNeedsBackprop(false);
if (bottom_series != NULL) {
bottom_series->AppendSeries(*network);
*network = bottom_series;
}
(*network)->CacheXScaleFactor((*network)->XScaleFactor());
return true;
}
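// Example invocation (a sketch; the spec string is an illustrative
// assumption):
//   Network* net = NULL;
//   TRand randomizer;
//   if (!NetworkBuilder::InitNetwork(105, "[1,36,0,1 Lfx96 O1c105]", -1,
//                                    0, 0.1f, &randomizer, &net)) {
//     // Handle the failed build.
//   }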
// Helper skips whitespace.
static void SkipWhitespace(char** str) {
while (**str == ' ' || **str == '\t' || **str == '\n') ++*str;
}
// Parses the given string and returns a network according to the network
// description language in networkbuilder.h
Network* NetworkBuilder::BuildFromString(const StaticShape& input_shape,
char** str) {
SkipWhitespace(str);
char code_ch = **str;
if (code_ch == '[') {
return ParseSeries(input_shape, nullptr, str);
}
if (input_shape.depth() == 0) {
// There must be an input at this point.
return ParseInput(str);
}
switch (code_ch) {
case '(':
return ParseParallel(input_shape, str);
case 'R':
return ParseR(input_shape, str);
case 'S':
return ParseS(input_shape, str);
case 'C':
return ParseC(input_shape, str);
case 'M':
return ParseM(input_shape, str);
case 'L':
return ParseLSTM(input_shape, str);
case 'F':
return ParseFullyConnected(input_shape, str);
case 'O':
return ParseOutput(input_shape, str);
default:
tprintf("Invalid network spec:%s\n", *str);
return nullptr;
}
return nullptr;
}
// Parses an input specification and returns the result, which may include a
// series.
Network* NetworkBuilder::ParseInput(char** str) {
// There must be an input at this point.
int length = 0;
int batch, height, width, depth;
int num_converted =
sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
StaticShape shape;
shape.SetShape(batch, height, width, depth);
// num_converted may or may not include the length.
if (num_converted != 4 && num_converted != 5) {
tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
return nullptr;
}
*str += length;
Input* input = new Input("Input", shape);
// We want to allow [<input>rest of net... or <input>[rest of net... so we
// have to check explicitly for '[' here.
SkipWhitespace(str);
if (**str == '[') return ParseSeries(shape, input, str);
return input;
}
// Parses a sequential series of networks, defined by [<net><net>...].
Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape,
Input* input_layer, char** str) {
StaticShape shape = input_shape;
Series* series = new Series("Series");
++*str;
if (input_layer != nullptr) {
series->AddToStack(input_layer);
shape = input_layer->OutputShape(shape);
}
Network* network = NULL;
while (**str != '\0' && **str != ']' &&
(network = BuildFromString(shape, str)) != NULL) {
shape = network->OutputShape(shape);
series->AddToStack(network);
}
if (**str != ']') {
tprintf("Missing ] at end of [Series]!\n");
delete series;
return NULL;
}
++*str;
return series;
}
// Parses a parallel set of networks, defined by (<net><net>...).
Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape,
char** str) {
Parallel* parallel = new Parallel("Parallel", NT_PARALLEL);
++*str;
Network* network = NULL;
while (**str != '\0' && **str != ')' &&
(network = BuildFromString(input_shape, str)) != NULL) {
parallel->AddToStack(network);
}
if (**str != ')') {
tprintf("Missing ) at end of (Parallel)!\n");
delete parallel;
return nullptr;
}
++*str;
return parallel;
}
// Parses a network that begins with 'R'.
Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) {
char dir = (*str)[1];
if (dir == 'x' || dir == 'y') {
STRING name = "Reverse";
name += dir;
*str += 2;
Network* network = BuildFromString(input_shape, str);
if (network == nullptr) return nullptr;
Reversed* rev =
new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
rev->SetNetwork(network);
return rev;
}
int replicas = strtol(*str + 1, str, 10);
if (replicas <= 0) {
tprintf("Invalid R spec!:%s\n", *str);
return nullptr;
}
Parallel* parallel = new Parallel("Replicated", NT_REPLICATED);
char* str_copy = *str;
for (int i = 0; i < replicas; ++i) {
str_copy = *str;
Network* network = BuildFromString(input_shape, &str_copy);
if (network == NULL) {
tprintf("Invalid replicated network!\n");
delete parallel;
return nullptr;
}
parallel->AddToStack(network);
}
*str = str_copy;
return parallel;
}
// Parses a network that begins with 'S'.
Network* NetworkBuilder::ParseS(const StaticShape& input_shape, char** str) {
int y = strtol(*str + 1, str, 10);
if (**str == ',') {
int x = strtol(*str + 1, str, 10);
if (y <= 0 || x <= 0) {
tprintf("Invalid S spec!:%s\n", *str);
return nullptr;
}
return new Reconfig("Reconfig", input_shape.depth(), x, y);
} else if (**str == '(') {
// TODO(rays) Add Generic reshape.
tprintf("Generic reshape not yet implemented!!\n");
return nullptr;
}
tprintf("Invalid S spec!:%s\n", *str);
return nullptr;
}
// Helper returns the fully-connected type for the character code.
static NetworkType NonLinearity(char func) {
switch (func) {
case 's':
return NT_LOGISTIC;
case 't':
return NT_TANH;
case 'r':
return NT_RELU;
case 'l':
return NT_LINEAR;
case 'm':
return NT_SOFTMAX;
case 'p':
return NT_POSCLIP;
case 'n':
return NT_SYMCLIP;
default:
return NT_NONE;
}
}
// Parses a network that begins with 'C'.
Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) {
NetworkType type = NonLinearity((*str)[1]);
if (type == NT_NONE) {
tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
return nullptr;
}
int y = 0, x = 0, d = 0;
if ((y = strtol(*str + 2, str, 10)) <= 0 || **str != ',' ||
(x = strtol(*str + 1, str, 10)) <= 0 || **str != ',' ||
(d = strtol(*str + 1, str, 10)) <= 0) {
tprintf("Invalid C spec!:%s\n", *str);
return nullptr;
}
if (x == 1 && y == 1) {
// No actual convolution. Just a FullyConnected on the current depth, to
// be slid over all batch,y,x.
return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
}
Series* series = new Series("ConvSeries");
Convolve* convolve =
new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
series->AddToStack(convolve);
StaticShape fc_input = convolve->OutputShape(input_shape);
series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
return series;
}
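// Spec examples (sketch): "Ct5,5,16" yields a Convolve with half-widths 2,2
// feeding a tanh FullyConnected of depth 16; "Cr1,1,32" degenerates to a bare
// relu FullyConnected of depth 32 slid over every (batch, y, x) position.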
// Parses a network that begins with 'M'.
Network* NetworkBuilder::ParseM(const StaticShape& input_shape, char** str) {
int y = 0, x = 0;
if ((*str)[1] != 'p' || (y = strtol(*str + 2, str, 10)) <= 0 ||
**str != ',' || (x = strtol(*str + 1, str, 10)) <= 0) {
tprintf("Invalid Mp spec!:%s\n", *str);
return nullptr;
}
return new Maxpool("Maxpool", input_shape.depth(), x, y);
}
// Parses an LSTM network, either individual, bi- or quad-directional.
Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) {
bool two_d = false;
NetworkType type = NT_LSTM;
char* spec_start = *str;
int chars_consumed = 1;
int num_outputs = 0;
char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
if (key == 'S') {
type = NT_LSTM_SOFTMAX;
num_outputs = num_softmax_outputs_;
++chars_consumed;
} else if (key == 'E') {
type = NT_LSTM_SOFTMAX_ENCODED;
num_outputs = num_softmax_outputs_;
++chars_consumed;
} else if (key == '2' && (((*str)[2] == 'x' && (*str)[3] == 'y') ||
((*str)[2] == 'y' && (*str)[3] == 'x'))) {
chars_consumed = 4;
dim = (*str)[3];
two_d = true;
} else if (key == 'f' || key == 'r' || key == 'b') {
dir = key;
dim = (*str)[2];
if (dim != 'x' && dim != 'y') {
tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
return nullptr;
}
chars_consumed = 3;
if ((*str)[chars_consumed] == 's') {
++chars_consumed;
type = NT_LSTM_SUMMARY;
}
} else {
tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
return nullptr;
}
int num_states = strtol(*str + chars_consumed, str, 10);
if (num_states <= 0) {
tprintf("Invalid number of states in L Spec!:%s\n", *str);
return nullptr;
}
Network* lstm = nullptr;
if (two_d) {
lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
} else {
if (num_outputs == 0) num_outputs = num_states;
STRING name(spec_start, *str - spec_start);
lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false,
type);
if (dir != 'f') {
Reversed* rev = new Reversed("RevLSTM", NT_XREVERSED);
rev->SetNetwork(lstm);
lstm = rev;
}
if (dir == 'b') {
name += "LTR";
Parallel* parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
parallel->AddToStack(new LSTM(name, input_shape.depth(), num_states,
num_outputs, false, type));
parallel->AddToStack(lstm);
lstm = parallel;
}
}
if (dim == 'y') {
Reversed* rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
rev->SetNetwork(lstm);
lstm = rev;
}
return lstm;
}
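// Spec examples (sketch): "Lfx96" is a forward x-LSTM with 96 states;
// "Lrx96" wraps the same LSTM in an x-Reversed; "Lbx96" pairs LTR and RTL
// copies inside an NT_PAR_RL_LSTM Parallel; "Lfys64" keeps only the final
// output (NT_LSTM_SUMMARY); and "L2xy64" builds the 4-way 2-d quad below.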
// Builds a set of 4 lstms with x and y reversal, running in true parallel.
Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
Parallel* parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states,
num_states, true, NT_LSTM));
Reversed* rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states,
true, NT_LSTM));
parallel->AddToStack(rev);
rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
rev->SetNetwork(
new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
Reversed* rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
rev2->SetNetwork(rev);
parallel->AddToStack(rev2);
rev = new Reversed("L2DXRevY", NT_YREVERSED);
rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states,
true, NT_LSTM));
parallel->AddToStack(rev);
return parallel;
}
// Helper builds a truly (0-d) fully connected layer of the given type.
static Network* BuildFullyConnected(const StaticShape& input_shape,
NetworkType type, const STRING& name,
int depth) {
if (input_shape.height() == 0 || input_shape.width() == 0) {
tprintf("Fully connected requires positive height and width, had %d,%d\n",
input_shape.height(), input_shape.width());
return nullptr;
}
int input_size = input_shape.height() * input_shape.width();
int input_depth = input_size * input_shape.depth();
Network* fc = new FullyConnected(name, input_depth, depth, type);
if (input_size > 1) {
Series* series = new Series("FCSeries");
series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(),
input_shape.width(), input_shape.height()));
series->AddToStack(fc);
fc = series;
}
return fc;
}
// Parses a Fully connected network.
Network* NetworkBuilder::ParseFullyConnected(const StaticShape& input_shape,
char** str) {
char* spec_start = *str;
NetworkType type = NonLinearity((*str)[1]);
if (type == NT_NONE) {
tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
return nullptr;
}
int depth = strtol(*str + 1, str, 10);
if (depth <= 0) {
tprintf("Invalid F spec!:%s\n", *str);
return nullptr;
}
STRING name(spec_start, *str - spec_start);
return BuildFullyConnected(input_shape, type, name, depth);
}
// Parses an Output spec.
Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape,
char** str) {
char dims_ch = (*str)[1];
if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
return nullptr;
}
char type_ch = (*str)[2];
if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
return nullptr;
}
int depth = strtol(*str + 3, str, 10);
if (depth != num_softmax_outputs_) {
tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
num_softmax_outputs_);
depth = num_softmax_outputs_;
}
NetworkType type = NT_SOFTMAX;
if (type_ch == 'l')
type = NT_LOGISTIC;
else if (type_ch == 's')
type = NT_SOFTMAX_NO_CTC;
if (dims_ch == '0') {
// Same as standard fully connected.
return BuildFullyConnected(input_shape, type, "Output", depth);
} else if (dims_ch == '2') {
// We don't care if x and/or y are variable.
return new FullyConnected("Output2d", input_shape.depth(), depth, type);
}
// For 1-d y has to be fixed, and if not 1, moved to depth.
if (input_shape.height() == 0) {
tprintf("Fully connected requires fixed height!\n");
return nullptr;
}
int input_size = input_shape.height();
int input_depth = input_size * input_shape.depth();
Network* fc = new FullyConnected("Output", input_depth, depth, type);
if (input_size > 1) {
Series* series = new Series("FCSeries");
series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1,
input_shape.height()));
series->AddToStack(fc);
fc = series;
}
return fc;
}
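// Illustrative example (not in the original source): the spec "O1c105" parses
// as dims_ch='1' (sequence), type_ch='c' (softmax with CTC) and depth=105;
// depth is then overridden by num_softmax_outputs_ from the unicharset, and
// the result is a FullyConnected layer, preceded by a Reconfig when the fixed
// input height is greater than 1.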
} // namespace tesseract.

lstm/networkbuilder.h Normal file
View File

@@ -0,0 +1,160 @@
///////////////////////////////////////////////////////////////////////
// File: networkbuilder.h
// Description: Class to parse the network description language and
// build a corresponding network.
// Author: Ray Smith
// Created: Wed Jul 16 18:35:38 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_NETWORKBUILDER_H_
#define TESSERACT_LSTM_NETWORKBUILDER_H_
#include "static_shape.h"
#include "stridemap.h"
class STRING;
class UNICHARSET;
namespace tesseract {
class Input;
class Network;
class Parallel;
class TRand;
class NetworkBuilder {
public:
explicit NetworkBuilder(int num_softmax_outputs)
: num_softmax_outputs_(num_softmax_outputs) {}
// Builds a network with a network_spec in the network description
// language, to recognize a character set of num_outputs size.
// If append_index is non-negative, then *network must be non-null and the
// given network_spec will be appended to *network AFTER append_index, with
// the top of the input *network discarded.
// Note that network_spec is passed by value to allow a non-const char* pointer
// into the string for BuildFromString.
// net_flags control network behavior according to the NetworkFlags enum.
// The resulting network is returned via **network.
// Returns false if something failed.
static bool InitNetwork(int num_outputs, STRING network_spec,
int append_index, int net_flags, float weight_range,
TRand* randomizer, Network** network);
// Parses the given string and returns a network according to the following
// language:
// ============ Syntax of description below: ============
// <d> represents a number.
// <net> represents any single network element, including (recursively) a
// [...] series or (...) parallel construct.
// (s|t|r|l|m) (regex notation) represents a single required letter.
// NOTE THAT THROUGHOUT, x and y are REVERSED from conventional mathematics,
// to use the same convention as TensorFlow. The reason TF adopts this
// convention is to eliminate the need to transpose images on input, since
// adjacent memory locations in images increase x and then y, while adjacent
// memory locations in tensors in TF, and NetworkIO in tesseract increase the
// rightmost index first, then the next-left and so-on, like C arrays.
// ============ INPUTS ============
// <b>,<h>,<w>,<d> A batch of b images with height h, width w, and depth d.
// b, h and/or w may be zero, to indicate variable size. Some network layer
// (summarizing LSTM) must be used to make a variable h known.
// d may be 1 for greyscale, 3 for color.
// NOTE that throughout the constructed network, the inputs/outputs are all of
// the same [batch,height,width,depth] dimensions, even if a different size.
// ============ PLUMBING ============
// [...] Execute ... networks in series (layers).
// (...) Execute ... networks in parallel, with their output depths added.
// R<d><net> Execute d replicas of net in parallel, with their output depths
// added.
// Rx<net> Execute <net> with x-dimension reversal.
// Ry<net> Execute <net> with y-dimension reversal.
// S<y>,<x> Rescale 2-D input by shrink factor x,y, rearranging the data by
// increasing the depth of the input by factor xy.
// Mp<y>,<x> Maxpool the input, reducing the size by an (x,y) rectangle.
// ============ FUNCTIONAL UNITS ============
// C(s|t|r|l|m)<y>,<x>,<d> Convolves using a (x,y) window, with no shrinkage,
// random infill, producing d outputs, then applies a non-linearity:
// s: Sigmoid, t: Tanh, r: Relu, l: Linear, m: Softmax.
// F(s|t|r|l|m)<d> Truly fully-connected with s|t|r|l|m non-linearity and d
// outputs. Connects to every x,y,depth position of the input, reducing
// height, width to 1, producing a single <d> vector as the output.
// Input height and width must be constant.
// For a sliding-window linear or non-linear map that connects just to the
// input depth, and leaves the input image size as-is, use a 1x1 convolution
// eg. Cr1,1,64 instead of Fr64.
// L(f|r|b)(x|y)[s]<n> LSTM cell with n states/outputs.
// The LSTM must have one of:
// f runs the LSTM forward only.
// r runs the LSTM reversed only.
// b runs the LSTM bidirectionally.
// It will operate on either the x- or y-dimension, treating the other
// dimension independently (as if part of the batch).
// s (optional) summarizes the output in the requested dimension,
// outputting only the final step, collapsing the dimension to a
// single element.
// LS<n> Forward-only LSTM cell in the x-direction, with built-in Softmax.
// LE<n> Forward-only LSTM cell in the x-direction, with built-in softmax,
// with binary Encoding.
// L2xy<n> Full 2-d LSTM operating in quad-directions (bidi in x and y) and
// all the output depths added.
// ============ OUTPUTS ============
// The network description must finish with an output specification:
// O(2|1|0)(l|s|c)<n> output layer with n classes
// 2 (heatmap) Output is a 2-d vector map of the input (possibly at
// different scale).
// 1 (sequence) Output is a 1-d sequence of vector values.
// 0 (category) Output is a 0-d single vector value.
// l uses a logistic non-linearity on the output, allowing multiple
// hot elements in any output vector value.
// s uses a softmax non-linearity, with one-hot output in each value.
// c uses a softmax with CTC. Can only be used with 1 (sequence).
// NOTE1: Only O1s and O1c are currently supported.
// NOTE2: n is totally ignored, and for compatibility purposes only. The
// output number of classes is obtained automatically from the
// unicharset.
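// Illustrative example (an assumption, not part of this header): a line
// recognizer might be specified as
// [1,36,0,1 Ct3,3,16 Mp3,3 Lfys48 Lfx96 Lrx96 Lfx192 O1c1]
// ie a batch of 1 greyscale image of fixed height 36 and variable width;
// a 3x3 tanh convolution with 16 outputs; 3x3 maxpooling; a summarizing
// y-dimension LSTM with 48 outputs (collapsing the height to 1); forward,
// reverse and forward x-dimension LSTMs; and a softmax+CTC output layer
// whose class count comes from the unicharset (the trailing 1 is ignored).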
Network* BuildFromString(const StaticShape& input_shape, char** str);
private:
// Parses an input specification and returns the result, which may include a
// series.
Network* ParseInput(char** str);
// Parses a sequential series of networks, defined by [<net><net>...].
Network* ParseSeries(const StaticShape& input_shape, Input* input_layer,
char** str);
// Parses a parallel set of networks, defined by (<net><net>...).
Network* ParseParallel(const StaticShape& input_shape, char** str);
// Parses a network that begins with 'R'.
Network* ParseR(const StaticShape& input_shape, char** str);
// Parses a network that begins with 'S'.
Network* ParseS(const StaticShape& input_shape, char** str);
// Parses a network that begins with 'C'.
Network* ParseC(const StaticShape& input_shape, char** str);
// Parses a network that begins with 'M'.
Network* ParseM(const StaticShape& input_shape, char** str);
// Parses an LSTM network, either individual, bi- or quad-directional.
Network* ParseLSTM(const StaticShape& input_shape, char** str);
// Builds a set of 4 lstms with x and y reversal, running in true parallel.
static Network* BuildLSTMXYQuad(int num_inputs, int num_states);
// Parses a Fully connected network.
Network* ParseFullyConnected(const StaticShape& input_shape, char** str);
// Parses an Output spec.
Network* ParseOutput(const StaticShape& input_shape, char** str);
private:
int num_softmax_outputs_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_NETWORKBUILDER_H_

lstm/networkio.cpp Normal file
View File

@@ -0,0 +1,981 @@
///////////////////////////////////////////////////////////////////////
// File: networkio.cpp
// Description: Network input/output data, allowing float/int implementations.
// Author: Ray Smith
// Created: Thu Jun 19 13:01:31 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "networkio.h"
#include "allheaders.h"
#include "functions.h"
#include "statistc.h"
#include "tprintf.h"
namespace tesseract {
// Minimum value to output for certainty.
const float kMinCertainty = -20.0f;
// Probability corresponding to kMinCertainty.
const float kMinProb = exp(kMinCertainty);
// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
void NetworkIO::Resize2d(bool int_mode, int width, int num_features) {
stride_map_ = StrideMap();
int_mode_ = int_mode;
if (int_mode_) {
i_.ResizeNoInit(width, num_features);
} else {
f_.ResizeNoInit(width, num_features);
}
}
// Resizes to a specific stride_map.
void NetworkIO::ResizeToMap(bool int_mode, const StrideMap& stride_map,
int num_features) {
// If this assert fails, it most likely got here through an uninitialized
// scratch element, ie call NetworkScratch::IO::Resizexxx() not
// NetworkIO::Resizexxx()!!
ASSERT_HOST(this != NULL);
stride_map_ = stride_map;
int_mode_ = int_mode;
if (int_mode_) {
i_.ResizeNoInit(stride_map.Width(), num_features);
} else {
f_.ResizeNoInit(stride_map.Width(), num_features);
}
ZeroInvalidElements();
}
// Shrinks the image size by x_scale,y_scale, and uses the given number of features.
void NetworkIO::ResizeScaled(const NetworkIO& src,
int x_scale, int y_scale, int num_features) {
StrideMap stride_map = src.stride_map_;
stride_map.ScaleXY(x_scale, y_scale);
ResizeToMap(src.int_mode_, stride_map, num_features);
}
// Resizes to just 1 x-coord, whatever the input.
void NetworkIO::ResizeXTo1(const NetworkIO& src, int num_features) {
StrideMap stride_map = src.stride_map_;
stride_map.ReduceWidthTo1();
ResizeToMap(src.int_mode_, stride_map, num_features);
}
// Initializes the whole array to zero.
void NetworkIO::Zero() {
int width = Width();
// Zero out everything, column-by-column in case it is aligned.
for (int t = 0; t < width; ++t) {
ZeroTimeStep(t);
}
}
// Initializes to zero all elements of the array that do not correspond to
// valid image positions. (If a batch of different-sized images are packed
// together, then there will be padding pixels.)
void NetworkIO::ZeroInvalidElements() {
int num_features = NumFeatures();
int full_width = stride_map_.Size(FD_WIDTH);
int full_height = stride_map_.Size(FD_HEIGHT);
StrideMap::Index b_index(stride_map_);
do {
int end_x = b_index.MaxIndexOfDim(FD_WIDTH) + 1;
if (end_x < full_width) {
// The width is small, so fill for every valid y.
StrideMap::Index y_index(b_index);
int fill_size = num_features * (full_width - end_x);
do {
StrideMap::Index z_index(y_index);
z_index.AddOffset(end_x, FD_WIDTH);
if (int_mode_) {
ZeroVector(fill_size, i_[z_index.t()]);
} else {
ZeroVector(fill_size, f_[z_index.t()]);
}
} while (y_index.AddOffset(1, FD_HEIGHT));
}
int end_y = b_index.MaxIndexOfDim(FD_HEIGHT) + 1;
if (end_y < full_height) {
// The height is small, so fill in the space in one go.
StrideMap::Index y_index(b_index);
y_index.AddOffset(end_y, FD_HEIGHT);
int fill_size = num_features * full_width * (full_height - end_y);
if (int_mode_) {
ZeroVector(fill_size, i_[y_index.t()]);
} else {
ZeroVector(fill_size, f_[y_index.t()]);
}
}
} while (b_index.AddOffset(1, FD_BATCH));
}
// Helper computes a black point and white point to contrast-enhance an image.
// The computation is based on the assumption that the image is of a single line
// of text, so a horizontal line through the middle of the image passes through
// at least some of it, so local minima and maxima are a good proxy for black
// and white pixel samples.
static void ComputeBlackWhite(Pix* pix, float* black, float* white) {
int width = pixGetWidth(pix);
int height = pixGetHeight(pix);
STATS mins(0, 256), maxes(0, 256);
if (width >= 3) {
int y = height / 2;
const l_uint32* line = pixGetData(pix) + pixGetWpl(pix) * y;
int prev = GET_DATA_BYTE(line, 0);
int curr = GET_DATA_BYTE(line, 1);
for (int x = 1; x + 1 < width; ++x) {
int next = GET_DATA_BYTE(line, x + 1);
if ((curr < prev && curr <= next) || (curr <= prev && curr < next)) {
// Local minimum.
mins.add(curr, 1);
}
if ((curr > prev && curr >= next) || (curr >= prev && curr > next)) {
// Local maximum.
maxes.add(curr, 1);
}
prev = curr;
curr = next;
}
}
if (mins.get_total() == 0) mins.add(0, 1);
if (maxes.get_total() == 0) maxes.add(255, 1);
*black = mins.ile(0.25);
*white = maxes.ile(0.75);
}
// Sets up the array from the given image, using the currently set int_mode_.
// If the image width doesn't match the shape, the image is truncated or padded
// with noise to match.
void NetworkIO::FromPix(const StaticShape& shape, const Pix* pix,
TRand* randomizer) {
std::vector<const Pix*> pixes(1, pix);
FromPixes(shape, pixes, randomizer);
}
// Sets up the array from the given set of images, using the currently set
// int_mode_. If the image width doesn't match the shape, the images are
// truncated or padded with noise to match.
void NetworkIO::FromPixes(const StaticShape& shape,
const std::vector<const Pix*>& pixes,
TRand* randomizer) {
int target_height = shape.height();
int target_width = shape.width();
std::vector<std::pair<int, int>> h_w_pairs;
for (auto pix : pixes) {
Pix* var_pix = const_cast<Pix*>(pix);
int width = pixGetWidth(var_pix);
if (target_width != 0) width = target_width;
int height = pixGetHeight(var_pix);
if (target_height != 0) height = target_height;
h_w_pairs.emplace_back(height, width);
}
stride_map_.SetStride(h_w_pairs);
ResizeToMap(int_mode(), stride_map_, shape.depth());
// Iterate over the images again to copy the data.
for (int b = 0; b < pixes.size(); ++b) {
Pix* pix = const_cast<Pix*>(pixes[b]);
float black = 0.0f, white = 255.0f;
if (shape.depth() != 3) ComputeBlackWhite(pix, &black, &white);
float contrast = (white - black) / 2.0f;
if (contrast <= 0.0f) contrast = 1.0f;
if (shape.height() == 1) {
Copy1DGreyImage(b, pix, black, contrast, randomizer);
} else {
Copy2DImage(b, pix, black, contrast, randomizer);
}
}
}
// Copies the given pix to *this at the given batch index, stretching and
// clipping the pixel values so that [black, black + 2*contrast] maps to the
// dynamic range of *this, ie [-1,1] for a float and (-127,127) for int.
// This is a 2-d operation in the sense that the output depth is the number
// of input channels, the height is the height of the image, and the width
// is the width of the image, or truncated/padded with noise if the width
// is a fixed size.
void NetworkIO::Copy2DImage(int batch, Pix* pix, float black, float contrast,
TRand* randomizer) {
int width = pixGetWidth(pix);
int height = pixGetHeight(pix);
int wpl = pixGetWpl(pix);
StrideMap::Index index(stride_map_);
index.AddOffset(batch, FD_BATCH);
int t = index.t();
int target_height = stride_map_.Size(FD_HEIGHT);
int target_width = stride_map_.Size(FD_WIDTH);
int num_features = NumFeatures();
bool color = num_features == 3;
if (width > target_width) width = target_width;
const uinT32* line = pixGetData(pix);
for (int y = 0; y < target_height; ++y, line += wpl) {
int x = 0;
if (y < height) {
for (x = 0; x < width; ++x, ++t) {
if (color) {
int f = 0;
for (int c = COLOR_RED; c <= COLOR_BLUE; ++c) {
int pixel = GET_DATA_BYTE(line + x, c);
SetPixel(t, f++, pixel, black, contrast);
}
} else {
int pixel = GET_DATA_BYTE(line, x);
SetPixel(t, 0, pixel, black, contrast);
}
}
}
for (; x < target_width; ++x) Randomize(t++, 0, num_features, randomizer);
}
}
// Copies the given pix to *this at the given batch index, as Copy2DImage
// above, except that the output depth is the height of the input image, the
// output height is 1, and the output width as for Copy2DImage.
// The image is thus treated as a 1-d set of vertical pixel strips.
void NetworkIO::Copy1DGreyImage(int batch, Pix* pix, float black,
float contrast, TRand* randomizer) {
int width = pixGetWidth(pix);
int height = pixGetHeight(pix);
ASSERT_HOST(height == NumFeatures());
int wpl = pixGetWpl(pix);
StrideMap::Index index(stride_map_);
index.AddOffset(batch, FD_BATCH);
int t = index.t();
int target_width = stride_map_.Size(FD_WIDTH);
if (width > target_width) width = target_width;
int x;
for (x = 0; x < width; ++x, ++t) {
for (int y = 0; y < height; ++y) {
const uinT32* line = pixGetData(pix) + wpl * y;
int pixel = GET_DATA_BYTE(line, x);
SetPixel(t, y, pixel, black, contrast);
}
}
for (; x < target_width; ++x) Randomize(t++, 0, height, randomizer);
}
// Helper stores the pixel value in i_ or f_ according to int_mode_.
// t: is the index from the StrideMap corresponding to the current
// [batch,y,x] position
// f: is the index into the depth/channel
// pixel: the value of the pixel from the image (in one channel)
// black: the pixel value to map to the lowest of the range of *this
// contrast: the range of pixel values to stretch to half the range of *this.
void NetworkIO::SetPixel(int t, int f, int pixel, float black, float contrast) {
float float_pixel = (pixel - black) / contrast - 1.0f;
if (int_mode_) {
i_[t][f] = ClipToRange(IntCastRounded((MAX_INT8 + 1) * float_pixel),
-MAX_INT8, MAX_INT8);
} else {
f_[t][f] = float_pixel;
}
}
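// Worked example (illustrative): if ComputeBlackWhite returns black=50 and
// white=200, then contrast=(200-50)/2=75, and SetPixel maps pixel values
// 50 -> -1.0, 125 -> 0.0 and 200 -> +1.0; in int mode these become -127, 0
// and +127 after rounding and clipping.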
// Converts the array to a Pix. Must be pixDestroyed after use.
Pix* NetworkIO::ToPix() const {
// Count the width of the image, and find the max multiplication factor.
int im_width = stride_map_.Size(FD_WIDTH);
int im_height = stride_map_.Size(FD_HEIGHT);
int num_features = NumFeatures();
int feature_factor = 1;
if (num_features == 3) {
// Special hack for color.
num_features = 1;
feature_factor = 3;
}
Pix* pix = pixCreate(im_width, im_height * num_features, 32);
StrideMap::Index index(stride_map_);
do {
int im_x = index.index(FD_WIDTH);
int top_im_y = index.index(FD_HEIGHT);
int im_y = top_im_y;
int t = index.t();
if (int_mode_) {
const inT8* features = i_[t];
for (int y = 0; y < num_features; ++y, im_y += im_height) {
int pixel = features[y * feature_factor];
// 1 or 2 features use greyscale.
int red = ClipToRange(pixel + 128, 0, 255);
int green = red, blue = red;
if (feature_factor == 3) {
// With 3 features assume RGB color.
green = ClipToRange(features[y * feature_factor + 1] + 128, 0, 255);
blue = ClipToRange(features[y * feature_factor + 2] + 128, 0, 255);
} else if (num_features > 3) {
// More than 3 features use false yellow/blue color, assuming a signed
// input in the range [-1,1].
red = abs(pixel) * 2;
if (pixel >= 0) {
green = red;
blue = 0;
} else {
blue = red;
green = red = 0;
}
}
pixSetPixel(pix, im_x, im_y, (red << L_RED_SHIFT) |
(green << L_GREEN_SHIFT) |
(blue << L_BLUE_SHIFT));
}
} else {
const float* features = f_[t];
for (int y = 0; y < num_features; ++y, im_y += im_height) {
float pixel = features[y * feature_factor];
// 1 or 2 features use greyscale.
int red = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
int green = red, blue = red;
if (feature_factor == 3) {
// With 3 features assume RGB color.
pixel = features[y * feature_factor + 1];
green = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
pixel = features[y * feature_factor + 2];
blue = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
} else if (num_features > 3) {
// More than 3 features use false yellow/blue color, assuming a signed
// input in the range [-1,1].
red = ClipToRange(IntCastRounded(fabs(pixel) * 255), 0, 255);
if (pixel >= 0) {
green = red;
blue = 0;
} else {
blue = red;
green = red = 0;
}
}
pixSetPixel(pix, im_x, im_y, (red << L_RED_SHIFT) |
(green << L_GREEN_SHIFT) |
(blue << L_BLUE_SHIFT));
}
}
} while (index.Increment());
return pix;
}
// Prints the first and last num timesteps of the array for each feature.
void NetworkIO::Print(int num) const {
int num_features = NumFeatures();
for (int y = 0; y < num_features; ++y) {
for (int t = 0; t < Width(); ++t) {
if (num == 0 || t < num || t + num >= Width()) {
if (int_mode_) {
tprintf(" %g", static_cast<float>(i_[t][y]) / MAX_INT8);
} else {
tprintf(" %g", f_[t][y]);
}
}
}
tprintf("\n");
}
}
// Copies a single time step from src.
void NetworkIO::CopyTimeStepFrom(int dest_t, const NetworkIO& src, int src_t) {
ASSERT_HOST(int_mode_ == src.int_mode_);
if (int_mode_) {
memcpy(i_[dest_t], src.i_[src_t], i_.dim2() * sizeof(i_[0][0]));
} else {
memcpy(f_[dest_t], src.f_[src_t], f_.dim2() * sizeof(f_[0][0]));
}
}
// Copies a part of single time step from src.
void NetworkIO::CopyTimeStepGeneral(int dest_t, int dest_offset,
int num_features, const NetworkIO& src,
int src_t, int src_offset) {
ASSERT_HOST(int_mode_ == src.int_mode_);
if (int_mode_) {
memcpy(i_[dest_t] + dest_offset, src.i_[src_t] + src_offset,
num_features * sizeof(i_[0][0]));
} else {
memcpy(f_[dest_t] + dest_offset, src.f_[src_t] + src_offset,
num_features * sizeof(f_[0][0]));
}
}
// Zeroes a single time step.
void NetworkIO::ZeroTimeStepGeneral(int t, int offset, int num_features) {
if (int_mode_) {
ZeroVector(num_features, i_[t] + offset);
} else {
ZeroVector(num_features, f_[t] + offset);
}
}
// Sets the given range to random values.
void NetworkIO::Randomize(int t, int offset, int num_features,
TRand* randomizer) {
if (int_mode_) {
inT8* line = i_[t] + offset;
for (int i = 0; i < num_features; ++i)
line[i] = IntCastRounded(randomizer->SignedRand(MAX_INT8));
} else {
// float mode.
float* line = f_[t] + offset;
for (int i = 0; i < num_features; ++i)
line[i] = randomizer->SignedRand(1.0);
}
}
// Helper returns the label and score of the best choice over a range.
int NetworkIO::BestChoiceOverRange(int t_start, int t_end, int not_this,
int null_ch, float* rating,
float* certainty) const {
if (t_end <= t_start) return -1;
int max_char = -1;
float min_score = 0.0f;
for (int c = 0; c < NumFeatures(); ++c) {
if (c == not_this || c == null_ch) continue;
ScoresOverRange(t_start, t_end, c, null_ch, rating, certainty);
if (max_char < 0 || *rating < min_score) {
min_score = *rating;
max_char = c;
}
}
ScoresOverRange(t_start, t_end, max_char, null_ch, rating, certainty);
return max_char;
}
// Helper returns the rating and certainty of the choice over a range in output.
void NetworkIO::ScoresOverRange(int t_start, int t_end, int choice, int null_ch,
float* rating, float* certainty) const {
ASSERT_HOST(!int_mode_);
*rating = 0.0f;
*certainty = 0.0f;
if (t_end <= t_start || t_end <= 0) return;
float ratings[3] = {0.0f, 0.0f, 0.0f};
float certs[3] = {0.0f, 0.0f, 0.0f};
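// Illustrative note (not in the original source): ratings/certs track three
// hypotheses about where the run of 'choice' falls within [t_start, t_end):
// [0] the run has not started yet (nulls only so far), [1] the run is in
// progress at t, and [2] the run has already finished (back to nulls). Each
// step allows a transition from hypothesis i-1 to i by taking the min, since
// lower ratings are better.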
for (int t = t_start; t < t_end; ++t) {
const float* line = f_[t];
float score = ProbToCertainty(line[choice]);
float zero = ProbToCertainty(line[null_ch]);
if (t == t_start) {
ratings[2] = MAX_FLOAT32;
ratings[1] = -score;
certs[1] = score;
} else {
for (int i = 2; i >= 1; --i) {
if (ratings[i] > ratings[i - 1]) {
ratings[i] = ratings[i - 1];
certs[i] = certs[i - 1];
}
}
ratings[2] -= zero;
if (zero < certs[2]) certs[2] = zero;
ratings[1] -= score;
if (score < certs[1]) certs[1] = score;
}
ratings[0] -= zero;
if (zero < certs[0]) certs[0] = zero;
}
int best_i = ratings[2] < ratings[1] ? 2 : 1;
*rating = ratings[best_i] + t_end - t_start;
*certainty = certs[best_i];
}
// Returns the index (label) of the best value at the given timestep,
// excluding not_this and not_that, and if not null, sets the score to the
// log of the corresponding value.
int NetworkIO::BestLabel(int t, int not_this, int not_that,
float* score) const {
ASSERT_HOST(!int_mode_);
int best_index = -1;
float best_score = -MAX_FLOAT32;
const float* line = f_[t];
for (int i = 0; i < f_.dim2(); ++i) {
if (line[i] > best_score && i != not_this && i != not_that) {
best_score = line[i];
best_index = i;
}
}
if (score != NULL) *score = ProbToCertainty(best_score);
return best_index;
}
// Returns the best start position out of [start, end) (into which all labels
// must fit) to obtain the highest cumulative score for the given labels.
int NetworkIO::PositionOfBestMatch(const GenericVector<int>& labels, int start,
int end) const {
int length = labels.size();
int last_start = end - length;
int best_start = -1;
double best_score = 0.0;
for (int s = start; s <= last_start; ++s) {
double score = ScoreOfLabels(labels, s);
if (score > best_score || best_start < 0) {
best_score = score;
best_start = s;
}
}
return best_start;
}
// Returns the cumulative score of the given labels starting at start, and
// using one label per time-step.
double NetworkIO::ScoreOfLabels(const GenericVector<int>& labels,
int start) const {
int length = labels.size();
double score = 0.0;
for (int i = 0; i < length; ++i) {
score += f_(start + i, labels[i]);
}
return score;
}
// Helper function sets all the outputs for a single timestep, such that
// label has value ok_score, and the other labels share 1 - ok_score.
void NetworkIO::SetActivations(int t, int label, float ok_score) {
ASSERT_HOST(!int_mode_);
int num_classes = NumFeatures();
float bad_score = (1.0f - ok_score) / (num_classes - 1);
float* targets = f_[t];
for (int i = 0; i < num_classes; ++i)
targets[i] = bad_score;
targets[label] = ok_score;
}
// Modifies the values, only if needed, so that the given label is
// the winner at the given time step t.
void NetworkIO::EnsureBestLabel(int t, int label) {
ASSERT_HOST(!int_mode_);
if (BestLabel(t, NULL) != label) {
// Output value needs enhancing. Scale all the other elements down to a
// third, and move the label two thirds of the way to 1.
int num_classes = NumFeatures();
float* targets = f_[t];
float enhancement = (1.0f - targets[label]) / 3.0f;
for (int c = 0; c < num_classes; ++c) {
if (c == label) {
targets[c] += (1.0 - targets[c]) * (2 / 3.0);
} else {
targets[c] /= 3.0;
}
}
}
}
// Helper function converts prob to certainty taking the minimum into account.
/* static */
float NetworkIO::ProbToCertainty(float prob) {
return prob > kMinProb ? log(prob) : kMinCertainty;
}
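// For example, a probability of 1.0 maps to a certainty of 0.0, exp(-1.0)
// maps to -1.0, and anything at or below kMinProb saturates at kMinCertainty.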
// Returns true if there is any bad value that is suspiciously like a GT
// error. Assuming that *this is the difference(gradient) between target
// and forward output, returns true if there is a large negative value
// (correcting a very confident output) for which there is no corresponding
// positive value in an adjacent timestep for the same feature index. This
// allows the box-truthed samples to make fine adjustments to position while
// stopping other disagreements of confident output with ground truth.
bool NetworkIO::AnySuspiciousTruth(float confidence_thr) const {
int num_features = NumFeatures();
for (int t = 0; t < Width(); ++t) {
const float* features = f_[t];
for (int y = 0; y < num_features; ++y) {
float grad = features[y];
if (grad < -confidence_thr) {
// Correcting strong output. Check for movement.
if ((t == 0 || f_[t - 1][y] < confidence_thr / 2) &&
(t + 1 == Width() || f_[t + 1][y] < confidence_thr / 2)) {
return true; // No strong positive on either side.
}
}
}
}
return false;
}
// Reads a single timestep to floats in the range [-1, 1].
void NetworkIO::ReadTimeStep(int t, double* output) const {
if (int_mode_) {
const inT8* line = i_[t];
for (int i = 0; i < i_.dim2(); ++i) {
output[i] = static_cast<double>(line[i]) / MAX_INT8;
}
} else {
const float* line = f_[t];
for (int i = 0; i < f_.dim2(); ++i) {
output[i] = static_cast<double>(line[i]);
}
}
}
// Adds a single timestep to floats.
void NetworkIO::AddTimeStep(int t, double* inout) const {
int num_features = NumFeatures();
if (int_mode_) {
const inT8* line = i_[t];
for (int i = 0; i < num_features; ++i) {
inout[i] += static_cast<double>(line[i]) / MAX_INT8;
}
} else {
const float* line = f_[t];
for (int i = 0; i < num_features; ++i) {
inout[i] += line[i];
}
}
}
// Adds part of a single timestep to floats.
void NetworkIO::AddTimeStepPart(int t, int offset, int num_features,
float* inout) const {
if (int_mode_) {
const inT8* line = i_[t] + offset;
for (int i = 0; i < num_features; ++i) {
inout[i] += static_cast<float>(line[i]) / MAX_INT8;
}
} else {
const float* line = f_[t] + offset;
for (int i = 0; i < num_features; ++i) {
inout[i] += line[i];
}
}
}
// Writes a single timestep from floats in the range [-1, 1].
void NetworkIO::WriteTimeStep(int t, const double* input) {
WriteTimeStepPart(t, 0, NumFeatures(), input);
}
// Writes a single timestep from floats in the range [-1, 1] writing only
// num_features elements of input to (*this)[t], starting at offset.
void NetworkIO::WriteTimeStepPart(int t, int offset, int num_features,
const double* input) {
if (int_mode_) {
inT8* line = i_[t] + offset;
for (int i = 0; i < num_features; ++i) {
line[i] = ClipToRange(IntCastRounded(input[i] * MAX_INT8),
-MAX_INT8, MAX_INT8);
}
} else {
float* line = f_[t] + offset;
for (int i = 0; i < num_features; ++i) {
line[i] = static_cast<float>(input[i]);
}
}
}
// Maxpools a single time step from src.
void NetworkIO::MaxpoolTimeStep(int dest_t, const NetworkIO& src, int src_t,
int* max_line) {
ASSERT_HOST(int_mode_ == src.int_mode_);
if (int_mode_) {
int dim = i_.dim2();
inT8* dest_line = i_[dest_t];
const inT8* src_line = src.i_[src_t];
for (int i = 0; i < dim; ++i) {
if (dest_line[i] < src_line[i]) {
dest_line[i] = src_line[i];
max_line[i] = src_t;
}
}
} else {
int dim = f_.dim2();
float* dest_line = f_[dest_t];
const float* src_line = src.f_[src_t];
for (int i = 0; i < dim; ++i) {
if (dest_line[i] < src_line[i]) {
dest_line[i] = src_line[i];
max_line[i] = src_t;
}
}
}
}
// Runs maxpool backward, using maxes to index timesteps in *this.
void NetworkIO::MaxpoolBackward(const NetworkIO& fwd,
const GENERIC_2D_ARRAY<int>& maxes) {
ASSERT_HOST(!int_mode_);
int width = fwd.Width();
Zero();
StrideMap::Index index(fwd.stride_map_);
do {
int t = index.t();
const int* max_line = maxes[t];
const float* fwd_line = fwd.f_[t];
int num_features = fwd.f_.dim2();
for (int i = 0; i < num_features; ++i) {
f_[max_line[i]][i] = fwd_line[i];
}
} while (index.Increment());
}
// Returns the min over time of the maxes over features of the outputs.
float NetworkIO::MinOfMaxes() const {
float min_max = 0.0f;
int width = Width();
int num_features = NumFeatures();
for (int t = 0; t < width; ++t) {
float max_value = -MAX_FLOAT32;
if (int_mode_) {
const inT8* column = i_[t];
for (int i = 0; i < num_features; ++i) {
if (column[i] > max_value) max_value = column[i];
}
} else {
const float* column = f_[t];
for (int i = 0; i < num_features; ++i) {
if (column[i] > max_value) max_value = column[i];
}
}
if (t == 0 || max_value < min_max) min_max = max_value;
}
return min_max;
}
// Computes combined results for a combiner that chooses between an existing
// input and itself, with an additional output to indicate the choice.
void NetworkIO::CombineOutputs(const NetworkIO& base_output,
const NetworkIO& combiner_output) {
int no = base_output.NumFeatures();
ASSERT_HOST(combiner_output.NumFeatures() == no + 1);
Resize(base_output, no);
int width = Width();
if (int_mode_) {
// Number of outputs from base and final result.
for (int t = 0; t < width; ++t) {
inT8* out_line = i_[t];
const inT8* base_line = base_output.i_[t];
const inT8* comb_line = combiner_output.i_[t];
float base_weight = static_cast<float>(comb_line[no]) / MAX_INT8;
float boost_weight = 1.0f - base_weight;
for (int i = 0; i < no; ++i) {
out_line[i] = IntCastRounded(base_line[i] * base_weight +
comb_line[i] * boost_weight);
}
}
} else {
for (int t = 0; t < width; ++t) {
float* out_line = f_[t];
const float* base_line = base_output.f_[t];
const float* comb_line = combiner_output.f_[t];
float base_weight = comb_line[no];
float boost_weight = 1.0f - base_weight;
for (int i = 0; i < no; ++i) {
out_line[i] = base_line[i] * base_weight + comb_line[i] * boost_weight;
}
}
}
}
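// Illustrative formula (restating the loops above): with w = comb_line[no],
// each output is out[i] = base[i] * w + comb[i] * (1 - w), so the extra
// combiner feature acts as a per-timestep mixing weight between the base
// network's output and the combiner's own.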
// Computes deltas for a combiner that chooses between 2 sets of inputs.
void NetworkIO::ComputeCombinerDeltas(const NetworkIO& fwd_deltas,
const NetworkIO& base_output) {
ASSERT_HOST(!int_mode_);
// Compute the deltas for the combiner.
int width = Width();
int no = NumFeatures() - 1;
ASSERT_HOST(fwd_deltas.NumFeatures() == no);
ASSERT_HOST(base_output.NumFeatures() == no);
// Number of outputs from base and final result.
for (int t = 0; t < width; ++t) {
const float* delta_line = fwd_deltas.f_[t];
const float* base_line = base_output.f_[t];
float* comb_line = f_[t];
float base_weight = comb_line[no];
float boost_weight = 1.0f - base_weight;
float max_base_delta = 0.0;
for (int i = 0; i < no; ++i) {
// What did the combiner actually produce?
float output = base_line[i] * base_weight + comb_line[i] * boost_weight;
// Reconstruct the target from the delta.
float comb_target = delta_line[i] + output;
comb_line[i] = comb_target - comb_line[i];
float base_delta = fabs(comb_target - base_line[i]);
if (base_delta > max_base_delta) max_base_delta = base_delta;
}
if (max_base_delta >= 0.5) {
// The base network got it wrong. The combiner should output the right
// answer and 0 for the base network.
comb_line[no] = 0.0 - base_weight;
} else {
// The base network was right. The combiner should flag that.
for (int i = 0; i < no; ++i) {
// All other targets are 0.
if (comb_line[i] > 0.0) comb_line[i] -= 1.0;
}
comb_line[no] = 1.0 - base_weight;
}
}
}
// Copies the array checking that the types match.
void NetworkIO::CopyAll(const NetworkIO& src) {
ASSERT_HOST(src.int_mode_ == int_mode_);
f_ = src.f_;
}
// Checks that both are floats and adds the src array to *this.
void NetworkIO::AddAllToFloat(const NetworkIO& src) {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!src.int_mode_);
f_ += src.f_;
}
// Subtracts the array from a float array. src must also be float.
void NetworkIO::SubtractAllFromFloat(const NetworkIO& src) {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!src.int_mode_);
f_ -= src.f_;
}
// Copies src to *this, with maxabs normalization to match scale.
void NetworkIO::CopyWithNormalization(const NetworkIO& src,
const NetworkIO& scale) {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!src.int_mode_);
ASSERT_HOST(!scale.int_mode_);
float src_max = src.f_.MaxAbs();
ASSERT_HOST(std::isfinite(src_max));
float scale_max = scale.f_.MaxAbs();
ASSERT_HOST(std::isfinite(scale_max));
if (src_max > 0.0f) {
float factor = scale_max / src_max;
for (int t = 0; t < src.Width(); ++t) {
const float* src_ptr = src.f_[t];
float* dest_ptr = f_[t];
for (int i = 0; i < src.f_.dim2(); ++i) dest_ptr[i] = src_ptr[i] * factor;
}
} else {
f_.Clear();
}
}
// Copies src to *this with independent reversal of the y dimension.
void NetworkIO::CopyWithYReversal(const NetworkIO& src) {
int num_features = src.NumFeatures();
Resize(src, num_features);
StrideMap::Index b_index(src.stride_map_);
do {
int width = b_index.MaxIndexOfDim(FD_WIDTH) + 1;
StrideMap::Index fwd_index(b_index);
StrideMap::Index rev_index(b_index);
rev_index.AddOffset(rev_index.MaxIndexOfDim(FD_HEIGHT), FD_HEIGHT);
do {
int fwd_t = fwd_index.t();
int rev_t = rev_index.t();
for (int x = 0; x < width; ++x) CopyTimeStepFrom(rev_t++, src, fwd_t++);
} while (fwd_index.AddOffset(1, FD_HEIGHT) &&
rev_index.AddOffset(-1, FD_HEIGHT));
} while (b_index.AddOffset(1, FD_BATCH));
}
// Copies src to *this with independent reversal of the x dimension.
void NetworkIO::CopyWithXReversal(const NetworkIO& src) {
int num_features = src.NumFeatures();
Resize(src, num_features);
StrideMap::Index b_index(src.stride_map_);
do {
StrideMap::Index y_index(b_index);
do {
StrideMap::Index fwd_index(y_index);
StrideMap::Index rev_index(y_index);
rev_index.AddOffset(rev_index.MaxIndexOfDim(FD_WIDTH), FD_WIDTH);
do {
CopyTimeStepFrom(rev_index.t(), src, fwd_index.t());
} while (fwd_index.AddOffset(1, FD_WIDTH) &&
rev_index.AddOffset(-1, FD_WIDTH));
} while (y_index.AddOffset(1, FD_HEIGHT));
} while (b_index.AddOffset(1, FD_BATCH));
}
// Copies src to *this with independent transpose of the x and y dimensions.
void NetworkIO::CopyWithXYTranspose(const NetworkIO& src) {
int num_features = src.NumFeatures();
stride_map_ = src.stride_map_;
stride_map_.TransposeXY();
ResizeToMap(src.int_mode(), stride_map_, num_features);
StrideMap::Index src_b_index(src.stride_map_);
StrideMap::Index dest_b_index(stride_map_);
do {
StrideMap::Index src_y_index(src_b_index);
StrideMap::Index dest_x_index(dest_b_index);
do {
StrideMap::Index src_x_index(src_y_index);
StrideMap::Index dest_y_index(dest_x_index);
do {
CopyTimeStepFrom(dest_y_index.t(), src, src_x_index.t());
} while (src_x_index.AddOffset(1, FD_WIDTH) &&
dest_y_index.AddOffset(1, FD_HEIGHT));
} while (src_y_index.AddOffset(1, FD_HEIGHT) &&
dest_x_index.AddOffset(1, FD_WIDTH));
} while (src_b_index.AddOffset(1, FD_BATCH) &&
dest_b_index.AddOffset(1, FD_BATCH));
}
// Copies src to *this, at the given feature_offset, returning the total
// feature offset after the copy. Multiple calls will stack outputs from
// multiple sources in feature space.
int NetworkIO::CopyPacking(const NetworkIO& src, int feature_offset) {
ASSERT_HOST(int_mode_ == src.int_mode_);
int width = src.Width();
ASSERT_HOST(width <= Width());
int num_features = src.NumFeatures();
ASSERT_HOST(num_features + feature_offset <= NumFeatures());
if (int_mode_) {
for (int t = 0; t < width; ++t) {
memcpy(i_[t] + feature_offset, src.i_[t],
num_features * sizeof(i_[t][0]));
}
for (int t = width; t < i_.dim1(); ++t) {
memset(i_[t], 0, num_features * sizeof(i_[t][0]));
}
} else {
for (int t = 0; t < width; ++t) {
memcpy(f_[t] + feature_offset, src.f_[t],
num_features * sizeof(f_[t][0]));
}
for (int t = width; t < f_.dim1(); ++t) {
memset(f_[t], 0, num_features * sizeof(f_[t][0]));
}
}
return num_features + feature_offset;
}
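// Usage sketch (illustrative): stacking two sources in feature space,
// assuming dest is already sized to the combined feature count:
// int offset = dest.CopyPacking(first, 0);
// offset = dest.CopyPacking(second, offset);
// after which each timestep of dest holds first's features followed by
// second's.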
// Opposite of CopyPacking, fills *this with a part of src, starting at
// feature_offset, and picking num_features.
void NetworkIO::CopyUnpacking(const NetworkIO& src, int feature_offset,
int num_features) {
Resize(src, num_features);
int width = src.Width();
ASSERT_HOST(num_features + feature_offset <= src.NumFeatures());
if (int_mode_) {
for (int t = 0; t < width; ++t) {
memcpy(i_[t], src.i_[t] + feature_offset,
num_features * sizeof(i_[t][0]));
}
} else {
for (int t = 0; t < width; ++t) {
memcpy(f_[t], src.f_[t] + feature_offset,
num_features * sizeof(f_[t][0]));
}
}
}
// Transposes the float part of *this into dest.
void NetworkIO::Transpose(TransposedArray* dest) const {
int width = Width();
dest->ResizeNoInit(NumFeatures(), width);
for (int t = 0; t < width; ++t) dest->WriteStrided(t, f_[t]);
}
// Clips the content of a single time-step to +/-range.
void NetworkIO::ClipVector(int t, float range) {
ASSERT_HOST(!int_mode_);
float* v = f_[t];
int dim = f_.dim2();
for (int i = 0; i < dim; ++i)
v[i] = ClipToRange(v[i], -range, range);
}
} // namespace tesseract.

lstm/networkio.h Normal file
View File

@@ -0,0 +1,341 @@
///////////////////////////////////////////////////////////////////////
// File: networkio.h
// Description: Network input/output data, allowing float/int implementations.
// Author: Ray Smith
// Created: Tue Jun 17 08:43:11 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_NETWORKIO_H_
#define TESSERACT_LSTM_NETWORKIO_H_
#include <math.h>
#include <stdio.h>
#include <vector>
#include "genericvector.h"
#include "helpers.h"
#include "static_shape.h"
#include "stridemap.h"
#include "weightmatrix.h"
struct Pix;
namespace tesseract {
// Class to contain all the input/output of a network, allowing for fixed or
// variable-strided 2d to 1d mapping, and float or inT8 values. Provides
// enough calculating functions to hide the detail of the implementation.
class NetworkIO {
public:
NetworkIO() : int_mode_(false) {}
// Resizes the array (and stride), avoiding realloc if possible, to the given
// size from various size specs:
// Same stride size, but given number of features.
void Resize(const NetworkIO& src, int num_features) {
ResizeToMap(src.int_mode(), src.stride_map(), num_features);
}
// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
void Resize2d(bool int_mode, int width, int num_features);
// Resizes forcing a float representation with the stridemap of src and the
// given number of features.
void ResizeFloat(const NetworkIO& src, int num_features) {
ResizeToMap(false, src.stride_map(), num_features);
}
// Resizes to a specific stride_map.
void ResizeToMap(bool int_mode, const StrideMap& stride_map,
int num_features);
// Shrinks the image size by x_scale,y_scale, and uses the given number of features.
void ResizeScaled(const NetworkIO& src, int x_scale, int y_scale,
int num_features);
// Resizes to just 1 x-coord, whatever the input.
void ResizeXTo1(const NetworkIO& src, int num_features);
// Initializes the whole array to zero.
void Zero();
// Initializes to zero all elements of the array that do not correspond to
// valid image positions. (If a batch of different-sized images are packed
// together, then there will be padding pixels.)
void ZeroInvalidElements();
// Sets up the array from the given image, using the currently set int_mode_.
// If the image width doesn't match the shape, the image is truncated or
// padded with noise to match.
void FromPix(const StaticShape& shape, const Pix* pix, TRand* randomizer);
// Sets up the array from the given set of images, using the currently set
// int_mode_. If the image width doesn't match the shape, the images are
// truncated or padded with noise to match.
void FromPixes(const StaticShape& shape, const std::vector<const Pix*>& pixes,
TRand* randomizer);
// Copies the given pix to *this at the given batch index, stretching and
// clipping the pixel values so that [black, black + 2*contrast] maps to the
// dynamic range of *this, ie [-1,1] for a float and (-127,127) for int.
// This is a 2-d operation in the sense that the output depth is the number
// of input channels, the height is the height of the image, and the width
// is the width of the image, or truncated/padded with noise if the width
// is a fixed size.
void Copy2DImage(int batch, Pix* pix, float black, float contrast,
TRand* randomizer);
// Copies the given pix to *this at the given batch index, as Copy2DImage
// above, except that the output depth is the height of the input image, the
// output height is 1, and the output width is as for Copy2DImage.
// The image is thus treated as a 1-d set of vertical pixel strips.
void Copy1DGreyImage(int batch, Pix* pix, float black, float contrast,
TRand* randomizer);
// Helper stores the pixel value in i_ or f_ according to int_mode_.
// t: is the index from the StrideMap corresponding to the current
// [batch,y,x] position
// f: is the index into the depth/channel
// pixel: the value of the pixel from the image (in one channel)
// black: the pixel value to map to the lowest of the range of *this
// contrast: the range of pixel values to stretch to half the range of *this.
void SetPixel(int t, int f, int pixel, float black, float contrast);
// Converts the array to a Pix. Must be pixDestroyed after use.
Pix* ToPix() const;
// Prints the first and last num timesteps of the array for each feature.
void Print(int num) const;
// Returns the timestep width.
int Width() const {
return int_mode_ ? i_.dim1() : f_.dim1();
}
// Returns the number of features.
int NumFeatures() const {
return int_mode_ ? i_.dim2() : f_.dim2();
}
// Accessor to a timestep of the float matrix.
float* f(int t) {
ASSERT_HOST(!int_mode_);
return f_[t];
}
const float* f(int t) const {
ASSERT_HOST(!int_mode_);
return f_[t];
}
const inT8* i(int t) const {
ASSERT_HOST(int_mode_);
return i_[t];
}
bool int_mode() const {
return int_mode_;
}
void set_int_mode(bool is_quantized) {
int_mode_ = is_quantized;
}
const StrideMap& stride_map() const {
return stride_map_;
}
void set_stride_map(const StrideMap& map) {
stride_map_ = map;
}
const GENERIC_2D_ARRAY<float>& float_array() const { return f_; }
GENERIC_2D_ARRAY<float>* mutable_float_array() { return &f_; }
// Copies a single time step from src.
void CopyTimeStepFrom(int dest_t, const NetworkIO& src, int src_t);
// Copies a part of single time step from src.
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features,
const NetworkIO& src, int src_t, int src_offset);
// Zeroes a single time step.
void ZeroTimeStep(int t) { ZeroTimeStepGeneral(t, 0, NumFeatures()); }
void ZeroTimeStepGeneral(int t, int offset, int num_features);
// Sets the given range to random values.
void Randomize(int t, int offset, int num_features, TRand* randomizer);
// Helper returns the label and score of the best choice over a range.
int BestChoiceOverRange(int t_start, int t_end, int not_this, int null_ch,
float* rating, float* certainty) const;
// Helper returns the rating and certainty of the choice over a range in t.
void ScoresOverRange(int t_start, int t_end, int choice, int null_ch,
float* rating, float* certainty) const;
// Returns the index (label) of the best value at the given timestep,
// and if not null, sets the score to the log of the corresponding value.
int BestLabel(int t, float* score) const {
return BestLabel(t, -1, -1, score);
}
// Returns the index (label) of the best value at the given timestep,
// excluding not_this and not_that, and if not null, sets the score to the
// log of the corresponding value.
int BestLabel(int t, int not_this, int not_that, float* score) const;
// Returns the best start position out of [start, end) (into which all labels
// must fit) to obtain the highest cumulative score for the given labels.
int PositionOfBestMatch(const GenericVector<int>& labels, int start,
int end) const;
// Returns the cumulative score of the given labels starting at start, and
// using one label per time-step.
double ScoreOfLabels(const GenericVector<int>& labels, int start) const;
// Helper function sets all the outputs for a single timestep, such that
// label has value ok_score, and the other labels share 1 - ok_score.
// Assumes float mode.
void SetActivations(int t, int label, float ok_score);
// Modifies the values, only if needed, so that the given label is
// the winner at the given time step t.
// Assumes float mode.
void EnsureBestLabel(int t, int label);
// Helper function converts prob to certainty taking the minimum into account.
static float ProbToCertainty(float prob);
// Returns true if there is any bad value that is suspiciously like a GT
// error. Assuming that *this is the difference(gradient) between target
// and forward output, returns true if there is a large negative value
// (correcting a very confident output) for which there is no corresponding
// positive value in an adjacent timestep for the same feature index. This
// allows the box-truthed samples to make fine adjustments to position while
// stopping other disagreements of confident output with ground truth.
bool AnySuspiciousTruth(float confidence_thr) const;
// Reads a single timestep to floats in the range [-1, 1].
void ReadTimeStep(int t, double* output) const;
// Adds a single timestep to floats.
void AddTimeStep(int t, double* inout) const;
// Adds part of a single timestep to floats.
void AddTimeStepPart(int t, int offset, int num_features, float* inout) const;
// Writes a single timestep from floats in the range [-1, 1].
void WriteTimeStep(int t, const double* input);
// Writes a single timestep from floats in the range [-1, 1] writing only
// num_features elements of input to (*this)[t], starting at offset.
void WriteTimeStepPart(int t, int offset, int num_features,
const double* input);
// Maxpools a single time step from src.
void MaxpoolTimeStep(int dest_t, const NetworkIO& src, int src_t,
int* max_line);
// Runs maxpool backward, using maxes to index timesteps in *this.
void MaxpoolBackward(const NetworkIO& fwd,
const GENERIC_2D_ARRAY<int>& maxes);
// Returns the min over time of the maxes over features of the outputs.
float MinOfMaxes() const;
// Returns the max over all timesteps and features.
float Max() const { return int_mode_ ? i_.Max() : f_.Max(); }
// Computes combined results for a combiner that chooses between an existing
// input and itself, with an additional output to indicate the choice.
void CombineOutputs(const NetworkIO& base_output,
const NetworkIO& combiner_output);
// Computes deltas for a combiner that chooses between 2 sets of inputs.
void ComputeCombinerDeltas(const NetworkIO& fwd_deltas,
const NetworkIO& base_output);
// Copies the array checking that the types match.
void CopyAll(const NetworkIO& src);
// Checks that both are floats and adds the src array to *this.
void AddAllToFloat(const NetworkIO& src);
// Subtracts the array from a float array. src must also be float.
void SubtractAllFromFloat(const NetworkIO& src);
// Copies src to *this, with maxabs normalization to match scale.
void CopyWithNormalization(const NetworkIO& src, const NetworkIO& scale);
// Multiplies the float data by the given factor.
void ScaleFloatBy(float factor) { f_ *= factor; }
// Copies src to *this with independent reversal of the y dimension.
void CopyWithYReversal(const NetworkIO& src);
// Copies src to *this with independent reversal of the x dimension.
void CopyWithXReversal(const NetworkIO& src);
// Copies src to *this with independent transpose of the x and y dimensions.
void CopyWithXYTranspose(const NetworkIO& src);
// Copies src to *this, at the given feature_offset, returning the total
// feature offset after the copy. Multiple calls will stack outputs from
// multiple sources in feature space.
int CopyPacking(const NetworkIO& src, int feature_offset);
// Opposite of CopyPacking, fills *this with a part of src, starting at
// feature_offset, and picking num_features. Resizes *this to match.
void CopyUnpacking(const NetworkIO& src, int feature_offset,
int num_features);
// Transposes the float part of *this into dest.
void Transpose(TransposedArray* dest) const;
// Clips the content of a single time-step to +/-range.
void ClipVector(int t, float range);
// Applies Func to timestep t of *this (u) and multiplies the result by v
// component-wise, putting the product in *product.
// *this and v may be int or float, but must match. The outputs are double.
template <class Func>
void FuncMultiply(const NetworkIO& v_io, int t, double* product) {
Func f;
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!v_io.int_mode_);
int dim = f_.dim2();
if (int_mode_) {
const inT8* u = i_[t];
const inT8* v = v_io.i_[t];
for (int i = 0; i < dim; ++i) {
product[i] = f(u[i] / static_cast<double>(MAX_INT8)) * v[i] /
static_cast<double>(MAX_INT8);
}
} else {
const float* u = f_[t];
const float* v = v_io.f_[t];
for (int i = 0; i < dim; ++i) {
product[i] = f(u[i]) * v[i];
}
}
}
// Applies Func to *this (u) at u_t, and multiplies the result by v[v_t] * w,
// component-wise, putting the product in *product.
// All NetworkIOs are assumed to be float.
template <class Func>
void FuncMultiply3(int u_t, const NetworkIO& v_io, int v_t, const double* w,
double* product) const {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!v_io.int_mode_);
Func f;
const float* u = f_[u_t];
const float* v = v_io.f_[v_t];
int dim = f_.dim2();
for (int i = 0; i < dim; ++i) {
product[i] = f(u[i]) * v[i] * w[i];
}
}
// Applies Func to *this (u) and multiplies the result by v * w,
// component-wise, adding the product to *product, all at timestep t.
// All NetworkIOs are assumed to be float.
template <class Func>
void FuncMultiply3Add(const NetworkIO& v_io, int t, const double* w,
double* product) const {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!v_io.int_mode_);
Func f;
const float* u = f_[t];
const float* v = v_io.f_[t];
int dim = f_.dim2();
for (int i = 0; i < dim; ++i) {
product[i] += f(u[i]) * v[i] * w[i];
}
}
// Applies Func1 to *this (u), Func2 to v, and multiplies the result by w,
// component-wise, putting the product in product, all at timestep t, except
// w, which is a simple array. All NetworkIOs are assumed to be float.
template <class Func1, class Func2>
void Func2Multiply3(const NetworkIO& v_io, int t, const double* w,
double* product) const {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(!v_io.int_mode_);
Func1 f;
Func2 g;
const float* u = f_[t];
const float* v = v_io.f_[t];
int dim = f_.dim2();
for (int i = 0; i < dim; ++i) {
product[i] = f(u[i]) * g(v[i]) * w[i];
}
}
private:
// Choice of float vs 8 bit int for data.
GENERIC_2D_ARRAY<float> f_;
GENERIC_2D_ARRAY<inT8> i_;
// Which of f_ and i_ are we actually using.
bool int_mode_;
// Stride for 2d input data.
StrideMap stride_map_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_NETWORKIO_H_

lstm/networkscratch.h Normal file
View File

@@ -0,0 +1,257 @@
///////////////////////////////////////////////////////////////////////
// File: networkscratch.h
// Description: Scratch space for Network layers that hides distinction
// between float/int implementations.
// Author: Ray Smith
// Created: Thu Jun 19 10:50:29 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
#define TESSERACT_LSTM_NETWORKSCRATCH_H_
#include "genericvector.h"
#include "matrix.h"
#include "networkio.h"
#include "svutil.h"
#include "tprintf.h"
namespace tesseract {
// Generic scratch space for network layers. Provides NetworkIO that can store
// a complete set (over time) of intermediates, and GenericVector<float>
// scratch space that auto-frees after use. The aim here is to provide a set
// of temporary buffers to network layers that can be reused between layers
// and don't have to be reallocated on each call.
class NetworkScratch {
public:
NetworkScratch() : int_mode_(false) {}
~NetworkScratch() {}
// Sets the network representation. If the representation is integer, then
// default (integer) NetworkIOs are separated from the always-float variety.
// This saves memory by having separate int-specific and float-specific
// stacks. If the network representation is float, then all NetworkIOs go
// to the float stack.
void set_int_mode(bool int_mode) {
int_mode_ = int_mode;
}
// Class that acts like a NetworkIO (by having an implicit cast operator),
// yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
// and knows how to unstack the borrowed pointers on destruction.
class IO {
public:
// The NetworkIO should be sized after construction.
IO(const NetworkIO& src, NetworkScratch* scratch)
: int_mode_(scratch->int_mode_ && src.int_mode()),
scratch_space_(scratch) {
network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
: scratch_space_->float_stack_.Borrow();
}
// Default constructor for arrays. Use one of the Resize functions
// below to initialize and size.
IO() : int_mode_(false), network_io_(NULL), scratch_space_(NULL) {}
~IO() {
if (scratch_space_ == NULL) {
ASSERT_HOST(network_io_ == NULL);
} else if (int_mode_) {
scratch_space_->int_stack_.Return(network_io_);
} else {
scratch_space_->float_stack_.Return(network_io_);
}
}
// Resizes the array (and stride), avoiding realloc if possible, to the
// size from various size specs:
// Same time size, given number of features.
void Resize(const NetworkIO& src, int num_features,
NetworkScratch* scratch) {
if (scratch_space_ == NULL) {
int_mode_ = scratch->int_mode_ && src.int_mode();
scratch_space_ = scratch;
network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
: scratch_space_->float_stack_.Borrow();
}
network_io_->Resize(src, num_features);
}
// Resizes to a specific size as a temp buffer. No batches, no y-dim.
void Resize2d(bool int_mode, int width, int num_features,
NetworkScratch* scratch) {
if (scratch_space_ == NULL) {
int_mode_ = scratch->int_mode_ && int_mode;
scratch_space_ = scratch;
network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
: scratch_space_->float_stack_.Borrow();
}
network_io_->Resize2d(int_mode, width, num_features);
}
// Resize forcing a float representation with the width of src and the given
// number of features.
void ResizeFloat(const NetworkIO& src, int num_features,
NetworkScratch* scratch) {
if (scratch_space_ == NULL) {
int_mode_ = false;
scratch_space_ = scratch;
network_io_ = scratch_space_->float_stack_.Borrow();
}
network_io_->ResizeFloat(src, num_features);
}
// Returns a ref to a NetworkIO that enables *this to be treated as if
// it were just a NetworkIO*.
NetworkIO& operator*() {
return *network_io_;
}
NetworkIO* operator->() {
return network_io_;
}
operator NetworkIO*() {
return network_io_;
}
private:
// True if this is from the default (int) stack, otherwise the always-float stack.
bool int_mode_;
// The NetworkIO that we have borrowed from the scratch_space_.
NetworkIO* network_io_;
// The source scratch_space_. Borrowed pointer, used to free the
// NetworkIO. Don't delete!
NetworkScratch* scratch_space_;
}; // class IO.
// Class that acts like a fixed array of double, yet actually uses space
// from a GenericVector<double> in the source NetworkScratch, and knows how
// to unstack the borrowed vector on destruction.
class FloatVec {
public:
// The array will have size elements in it, uninitialized.
FloatVec(int size, NetworkScratch* scratch)
: vec_(NULL), scratch_space_(scratch) {
Init(size, scratch);
}
// Default constructor is for arrays. Use Init to setup.
FloatVec() : vec_(NULL), data_(NULL), scratch_space_(NULL) {}
~FloatVec() {
if (scratch_space_ != NULL) scratch_space_->vec_stack_.Return(vec_);
}
void Init(int size, NetworkScratch* scratch) {
if (scratch_space_ != NULL && vec_ != NULL)
scratch_space_->vec_stack_.Return(vec_);
scratch_space_ = scratch;
vec_ = scratch_space_->vec_stack_.Borrow();
vec_->resize_no_init(size);
data_ = &(*vec_)[0];
}
// Use the cast operator instead of operator[] so the FloatVec can be used
// as a double* argument to a function call.
operator double*() const { return data_; }
double* get() { return data_; }
private:
// Vector borrowed from the scratch space. Use Return to free it.
GenericVector<double>* vec_;
// Short-cut pointer to the underlying array.
double* data_;
// The source scratch_space_. Borrowed pointer, used to free the
// vector. Don't delete!
NetworkScratch* scratch_space_;
}; // class FloatVec
// Class that acts like a 2-D array of double, yet actually uses space
// from the source NetworkScratch, and knows how to unstack the borrowed
// array on destruction.
class GradientStore {
public:
// Default constructor is for arrays. Use Init to setup.
GradientStore() : array_(NULL), scratch_space_(NULL) {}
~GradientStore() {
if (scratch_space_ != NULL) scratch_space_->array_stack_.Return(array_);
}
void Init(int size1, int size2, NetworkScratch* scratch) {
if (scratch_space_ != NULL && array_ != NULL)
scratch_space_->array_stack_.Return(array_);
scratch_space_ = scratch;
array_ = scratch_space_->array_stack_.Borrow();
array_->Resize(size1, size2, 0.0);
}
// Accessors to get to the underlying TransposedArray.
TransposedArray* get() const { return array_; }
const TransposedArray& operator*() const { return *array_; }
private:
// Array borrowed from the scratch space. Use Return to free it.
TransposedArray* array_;
// The source scratch_space_. Borrowed pointer, used to free the
// vector. Don't delete!
NetworkScratch* scratch_space_;
}; // class GradientStore
// Class that does the work of holding a stack of objects, a stack pointer
// and a vector of in-use flags, so objects can be returned out of order.
// It is safe to attempt to Borrow/Return in multiple threads.
template<typename T> class Stack {
public:
Stack() : stack_top_(0) {
}
// Lends out the next free item, creating one if none available, sets
// the used flags and increments the stack top.
T* Borrow() {
SVAutoLock lock(&mutex_);
if (stack_top_ == stack_.size()) {
stack_.push_back(new T);
flags_.push_back(false);
}
flags_[stack_top_] = true;
return stack_[stack_top_++];
}
// Takes back the given item, and marks it free. Item does not have to be
// the most recently lent out, but free slots don't get re-used until the
// blocking item is returned. The assumption is that there will only be
// small, temporary variations from true stack use. (Determined by the order
// of destructors within a local scope.)
void Return(T* item) {
SVAutoLock lock(&mutex_);
// Linear search will do.
int index = stack_top_ - 1;
while (index >= 0 && stack_[index] != item) --index;
if (index >= 0) flags_[index] = false;
while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_;
}
private:
PointerVector<T> stack_;
GenericVector<bool> flags_;
int stack_top_;
SVMutex mutex_;
}; // class Stack.
private:
// If true, the network weights are inT8, if false, float.
bool int_mode_;
// Stacks of NetworkIO, GenericVector<double> and TransposedArray. Once
// allocated, they are not deleted until the NetworkScratch is deleted.
Stack<NetworkIO> int_stack_;
Stack<NetworkIO> float_stack_;
Stack<GenericVector<double> > vec_stack_;
Stack<TransposedArray> array_stack_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_NETWORKSCRATCH_H_
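The point of the IO/FloatVec wrappers above is that borrowing and returning is tied to scope, so a layer's temporaries cost nothing after the first call. A minimal usage sketch; the layer function and buffer sizes are hypothetical, not code from this commit:

// Hypothetical layer body showing the intended borrow/return pattern.
void ForwardSketch(const NetworkIO& input, NetworkScratch* scratch,
                   NetworkIO* output) {
  NetworkScratch::IO temp(input, scratch);  // Borrowed from a scratch stack.
  temp.Resize(input, 64, scratch);          // Reuses a previous allocation.
  NetworkScratch::FloatVec acc;
  acc.Init(64, scratch);                    // Borrowed double[64] workspace.
  // ... compute into *temp and acc, then write *output ...
  // temp and acc return their buffers to the stacks on scope exit, so the
  // next call borrows the same memory instead of reallocating.
}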

180
lstm/parallel.cpp Normal file
View File

@ -0,0 +1,180 @@
///////////////////////////////////////////////////////////////////////
// File: parallel.cpp
// Description: Runs networks in parallel on the same input.
// Author: Ray Smith
// Created: Thu May 02 08:06:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "parallel.h"
#ifdef _OPENMP
#include <omp.h>
#endif
#include "functions.h" // For conditional undef of _OPENMP.
#include "networkscratch.h"
namespace tesseract {
// ni_ and no_ will be set by AddToStack.
Parallel::Parallel(const STRING& name, NetworkType type) : Plumbing(name) {
type_ = type;
}
Parallel::~Parallel() {
}
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape Parallel::OutputShape(const StaticShape& input_shape) const {
StaticShape result = stack_[0]->OutputShape(input_shape);
int stack_size = stack_.size();
for (int i = 1; i < stack_size; ++i) {
StaticShape shape = stack_[i]->OutputShape(input_shape);
result.set_depth(result.depth() + shape.depth());
}
return result;
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void Parallel::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
bool parallel_debug = false;
// If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair,
// or a 2-d LSTM quad, do debug locally, and don't pass the flag on.
if (debug && type_ != NT_PARALLEL) {
parallel_debug = true;
debug = false;
}
int stack_size = stack_.size();
if (type_ == NT_PAR_2D_LSTM) {
// Special case, run parallel in parallel.
GenericVector<NetworkScratch::IO> results;
results.init_to_size(stack_size, NetworkScratch::IO());
for (int i = 0; i < stack_size; ++i) {
results[i].Resize(input, stack_[i]->NumOutputs(), scratch);
}
#ifdef _OPENMP
#pragma omp parallel for num_threads(stack_size)
#endif
for (int i = 0; i < stack_size; ++i) {
stack_[i]->Forward(debug, input, NULL, scratch, results[i]);
}
// Now pack all the results (serially) into the output.
int out_offset = 0;
output->Resize(*results[0], NumOutputs());
for (int i = 0; i < stack_size; ++i) {
out_offset = output->CopyPacking(*results[i], out_offset);
}
} else {
// Revolving intermediate result.
NetworkScratch::IO result(input, scratch);
// Source for divided replicated.
NetworkScratch::IO source_part;
TransposedArray* src_transpose = NULL;
if (training() && type_ == NT_REPLICATED) {
// Make a transposed copy of the input.
input.Transpose(&transposed_input_);
src_transpose = &transposed_input_;
}
// Run each network, putting the outputs into result.
int input_offset = 0;
int out_offset = 0;
for (int i = 0; i < stack_size; ++i) {
stack_[i]->Forward(debug, input, src_transpose, scratch, result);
// All networks must have the same output width.
if (i == 0) {
output->Resize(*result, NumOutputs());
} else {
ASSERT_HOST(result->Width() == output->Width());
}
out_offset = output->CopyPacking(*result, out_offset);
}
}
if (parallel_debug) {
DisplayForward(*output);
}
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool Parallel::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
// If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair,
// or a 2-d LSTM quad, do debug locally, and don't pass the flag on.
if (debug && type_ != NT_PARALLEL) {
DisplayBackward(fwd_deltas);
debug = false;
}
int stack_size = stack_.size();
if (type_ == NT_PAR_2D_LSTM) {
// Special case, run parallel in parallel.
GenericVector<NetworkScratch::IO> in_deltas, out_deltas;
in_deltas.init_to_size(stack_size, NetworkScratch::IO());
out_deltas.init_to_size(stack_size, NetworkScratch::IO());
// Split the forward deltas for each stack element.
int feature_offset = 0;
int out_offset = 0;
for (int i = 0; i < stack_.size(); ++i) {
int num_features = stack_[i]->NumOutputs();
in_deltas[i].Resize(fwd_deltas, num_features, scratch);
out_deltas[i].Resize(fwd_deltas, stack_[i]->NumInputs(), scratch);
in_deltas[i]->CopyUnpacking(fwd_deltas, feature_offset, num_features);
feature_offset += num_features;
}
#ifdef _OPENMP
#pragma omp parallel for num_threads(stack_size)
#endif
for (int i = 0; i < stack_size; ++i) {
stack_[i]->Backward(debug, *in_deltas[i], scratch,
i == 0 ? back_deltas : out_deltas[i]);
}
if (needs_to_backprop_) {
for (int i = 1; i < stack_size; ++i) {
back_deltas->AddAllToFloat(*out_deltas[i]);
}
}
} else {
// Revolving partial deltas.
NetworkScratch::IO in_deltas(fwd_deltas, scratch);
// The sum of deltas from different sources, which will eventually go into
// back_deltas.
NetworkScratch::IO out_deltas;
int feature_offset = 0;
int out_offset = 0;
for (int i = 0; i < stack_.size(); ++i) {
int num_features = stack_[i]->NumOutputs();
in_deltas->CopyUnpacking(fwd_deltas, feature_offset, num_features);
feature_offset += num_features;
if (stack_[i]->Backward(debug, *in_deltas, scratch, back_deltas)) {
if (i == 0) {
out_deltas.ResizeFloat(*back_deltas, back_deltas->NumFeatures(),
scratch);
out_deltas->CopyAll(*back_deltas);
} else if (back_deltas->NumFeatures() == out_deltas->NumFeatures()) {
// Widths are allowed to be different going back, as we may have
// input nets, so only accumulate the deltas if the widths are the
// same.
out_deltas->AddAllToFloat(*back_deltas);
}
}
}
if (needs_to_backprop_) back_deltas->CopyAll(*out_deltas);
}
if (needs_to_backprop_) back_deltas->ScaleFloatBy(1.0f / stack_size);
return needs_to_backprop_;
}
} // namespace tesseract.
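The CopyPacking calls in Forward amount to a depth-wise concatenation: each timestep of the output carries the features of stack_[0], then stack_[1], and so on. A standalone sketch of that packing over plain vectors, where results[k][t] is the feature vector of network k at timestep t (all names here are illustrative):

#include <vector>

// Packs per-network outputs [width x depth_k] into one
// [width x sum(depth_k)] buffer, mirroring output->CopyPacking(...).
std::vector<std::vector<float>> PackOutputs(
    const std::vector<std::vector<std::vector<float>>>& results) {
  size_t width = results[0].size();
  size_t total_depth = 0;
  for (const auto& r : results) total_depth += r[0].size();
  std::vector<std::vector<float>> output(
      width, std::vector<float>(total_depth));
  size_t offset = 0;  // Plays the role of out_offset in Forward.
  for (const auto& r : results) {
    for (size_t t = 0; t < width; ++t)
      for (size_t f = 0; f < r[t].size(); ++f)
        output[t][offset + f] = r[t][f];
    offset += r[0].size();
  }
  return output;
}

Backward runs the inverse split: CopyUnpacking peels each network's slice of the forward deltas back out at the matching feature offset.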

87
lstm/parallel.h Normal file
View File

@ -0,0 +1,87 @@
///////////////////////////////////////////////////////////////////////
// File: parallel.h
// Description: Runs networks in parallel on the same input.
// Author: Ray Smith
// Created: Thu May 02 08:02:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_PARALLEL_H_
#define TESSERACT_LSTM_PARALLEL_H_
#include "plumbing.h"
namespace tesseract {
// Runs multiple networks in parallel, interlacing their outputs.
class Parallel : public Plumbing {
public:
// ni_ and no_ will be set by AddToStack.
Parallel(const STRING& name, NetworkType type);
virtual ~Parallel();
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
virtual STRING spec() const {
STRING spec;
if (type_ == NT_PAR_2D_LSTM) {
// We have 4 LSTMs operating in parallel here, so the size of each is
// the number of outputs/4.
spec.add_str_int("L2xy", no_ / 4);
} else if (type_ == NT_PAR_RL_LSTM) {
// We have 2 LSTMs operating in parallel here, so the size of each is
// the number of outputs/2.
if (stack_[0]->type() == NT_LSTM_SUMMARY)
spec.add_str_int("Lbxs", no_ / 2);
else
spec.add_str_int("Lbx", no_ / 2);
} else {
if (type_ == NT_REPLICATED) {
spec.add_str_int("R", stack_.size());
spec += "(";
spec += stack_[0]->spec();
} else {
spec = "(";
for (int i = 0; i < stack_.size(); ++i) spec += stack_[i]->spec();
}
spec += ")";
}
return spec;
}
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
private:
// If *this is a NT_REPLICATED, then it feeds a replicated network with
// identical inputs, and it would be extremely wasteful for them to each
// calculate and store the same transpose of the inputs, so Parallel does it
// and passes a pointer to the replicated network, allowing it to use the
// transpose on the next call to Backward.
TransposedArray transposed_input_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_PARALLEL_H_
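As a concrete reading of spec() above: a NT_PAR_RL_LSTM built from two 128-output LSTMs prints "Lbx128" (or "Lbxs128" if they summarize), and a NT_PAR_2D_LSTM built from four 16-output LSTMs prints "L2xy16". A hedged construction sketch; the LSTM objects are hypothetical, only Parallel and AddToStack are from this commit:

Parallel* pair = new Parallel("fwd-rev", NT_PAR_RL_LSTM);
pair->AddToStack(forward_lstm);  // Hypothetical 128-output LSTM.
pair->AddToStack(reverse_lstm);  // Its reversed twin, also 128 outputs.
// pair->spec() now returns "Lbx128": no_ == 256, halved in spec().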

233
lstm/plumbing.cpp Normal file
View File

@ -0,0 +1,233 @@
///////////////////////////////////////////////////////////////////////
// File: plumbing.cpp
// Description: Base class for networks that organize other networks
// eg series or parallel.
// Author: Ray Smith
// Created: Mon May 12 08:17:34 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "plumbing.h"
namespace tesseract {
// ni_ and no_ will be set by AddToStack.
Plumbing::Plumbing(const STRING& name)
: Network(NT_PARALLEL, name, 0, 0) {
}
Plumbing::~Plumbing() {
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void Plumbing::SetEnableTraining(bool state) {
Network::SetEnableTraining(state);
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->SetEnableTraining(state);
}
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
void Plumbing::SetNetworkFlags(uinT32 flags) {
Network::SetNetworkFlags(flags);
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->SetNetworkFlags(flags);
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
// Returns the number of weights initialized.
int Plumbing::InitWeights(float range, TRand* randomizer) {
num_weights_ = 0;
for (int i = 0; i < stack_.size(); ++i)
num_weights_ += stack_[i]->InitWeights(range, randomizer);
return num_weights_;
}
// Converts a float network to an int network.
void Plumbing::ConvertToInt() {
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->ConvertToInt();
}
// Provides a pointer to a TRand for any networks that care to use it.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
void Plumbing::SetRandomizer(TRand* randomizer) {
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->SetRandomizer(randomizer);
}
// Adds the given network to the stack.
void Plumbing::AddToStack(Network* network) {
if (stack_.empty()) {
ni_ = network->NumInputs();
no_ = network->NumOutputs();
} else if (type_ == NT_SERIES) {
// ni is input of first, no output of last, others match output to input.
ASSERT_HOST(no_ == network->NumInputs());
no_ = network->NumOutputs();
} else {
// All parallel types. Output is sum of outputs, inputs all match.
ASSERT_HOST(ni_ == network->NumInputs());
no_ += network->NumOutputs();
}
stack_.push_back(network);
}
// Sets needs_to_backprop_ to needs_backprop and calls SetupNeedsBackprop on
// each sub-network. Returns true if needs_backprop || any sub-network does.
bool Plumbing::SetupNeedsBackprop(bool needs_backprop) {
needs_to_backprop_ = needs_backprop;
bool retval = needs_backprop;
for (int i = 0; i < stack_.size(); ++i) {
if (stack_[i]->SetupNeedsBackprop(needs_backprop))
retval = true;
}
return retval;
}
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
int Plumbing::XScaleFactor() const {
return stack_[0]->XScaleFactor();
}
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
void Plumbing::CacheXScaleFactor(int factor) {
for (int i = 0; i < stack_.size(); ++i) {
stack_[i]->CacheXScaleFactor(factor);
}
}
// Provides debug output on the weights.
void Plumbing::DebugWeights() {
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->DebugWeights();
}
// Returns a set of strings representing the layer-ids of all layers below.
void Plumbing::EnumerateLayers(const STRING* prefix,
GenericVector<STRING>* layers) const {
for (int i = 0; i < stack_.size(); ++i) {
STRING layer_name;
if (prefix) layer_name = *prefix;
layer_name.add_str_int(":", i);
if (stack_[i]->IsPlumbingType()) {
Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[i]);
plumbing->EnumerateLayers(&layer_name, layers);
} else {
layers->push_back(layer_name);
}
}
}
// Returns a pointer to the network layer corresponding to the given id.
Network* Plumbing::GetLayer(const char* id) const {
char* next_id;
int index = strtol(id, &next_id, 10);
if (index < 0 || index >= stack_.size()) return NULL;
if (stack_[index]->IsPlumbingType()) {
Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[index]);
ASSERT_HOST(*next_id == ':');
return plumbing->GetLayer(next_id + 1);
}
return stack_[index];
}
// Returns a pointer to the learning rate for the given layer id.
float* Plumbing::LayerLearningRatePtr(const char* id) const {
char* next_id;
int index = strtol(id, &next_id, 10);
if (index < 0 || index >= stack_.size()) return NULL;
if (stack_[index]->IsPlumbingType()) {
Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[index]);
ASSERT_HOST(*next_id == ':');
return plumbing->LayerLearningRatePtr(next_id + 1);
}
if (index < 0 || index >= learning_rates_.size()) return NULL;
return &learning_rates_[index];
}
// Writes to the given file. Returns false in case of error.
bool Plumbing::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
inT32 size = stack_.size();
// Can't use PointerVector::Serialize here as we need a special DeSerialize.
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
for (int i = 0; i < size; ++i)
if (!stack_[i]->Serialize(fp)) return false;
if ((network_flags_ & NF_LAYER_SPECIFIC_LR) &&
!learning_rates_.Serialize(fp)) {
return false;
}
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Plumbing::DeSerialize(bool swap, TFile* fp) {
stack_.truncate(0);
no_ = 0; // We will be modifying this as we AddToStack.
inT32 size;
if (fp->FRead(&size, sizeof(size), 1) != 1) return false;
for (int i = 0; i < size; ++i) {
Network* network = CreateFromFile(swap, fp);
if (network == NULL) return false;
AddToStack(network);
}
if ((network_flags_ & NF_LAYER_SPECIFIC_LR) &&
!learning_rates_.DeSerialize(swap, fp)) {
return false;
}
return true;
}
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
for (int i = 0; i < stack_.size(); ++i) {
if (network_flags_ & NF_LAYER_SPECIFIC_LR) {
if (i < learning_rates_.size())
learning_rate = learning_rates_[i];
else
learning_rates_.push_back(learning_rate);
}
if (stack_[i]->training())
stack_[i]->Update(learning_rate, momentum, num_samples);
}
}
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void Plumbing::CountAlternators(const Network& other, double* same,
double* changed) const {
ASSERT_HOST(other.type() == type_);
const Plumbing* plumbing = reinterpret_cast<const Plumbing*>(&other);
ASSERT_HOST(plumbing->stack_.size() == stack_.size());
for (int i = 0; i < stack_.size(); ++i)
stack_[i]->CountAlternators(*plumbing->stack_[i], same, changed);
}
} // namespace tesseract.
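EnumerateLayers and GetLayer share a colon-separated addressing scheme: each plumbing level appends ":" plus the index within its stack. A usage sketch; the network shape is hypothetical, and it assumes the caller strips the leading ':' from an enumerated id before passing it to GetLayer:

GenericVector<STRING> ids;
series->EnumerateLayers(NULL, &ids);  // e.g. ":0", ":1:0", ":1:1".
// Address the first child of the plumbing element at stack position 1:
Network* layer = series->GetLayer("1:0");
// Non-NULL only when NF_LAYER_SPECIFIC_LR rates have been populated:
float* lr = series->LayerLearningRatePtr("1:0");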

143
lstm/plumbing.h Normal file
View File

@ -0,0 +1,143 @@
///////////////////////////////////////////////////////////////////////
// File: plumbing.h
// Description: Base class for networks that organize other networks
// eg series or parallel.
// Author: Ray Smith
// Created: Mon May 12 08:11:36 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_PLUMBING_H_
#define TESSERACT_LSTM_PLUMBING_H_
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
namespace tesseract {
// Holds a collection of other networks and forwards calls to each of them.
class Plumbing : public Network {
public:
// ni_ and no_ will be set by AddToStack.
explicit Plumbing(const STRING& name);
virtual ~Plumbing();
// Returns the required shape input to the network.
virtual StaticShape InputShape() const { return stack_[0]->InputShape(); }
virtual STRING spec() const {
return "Sub-classes of Plumbing must implement spec()!";
}
// Returns true if the given type is derived from Plumbing, and thus contains
// multiple sub-networks that can have their own learning rate.
virtual bool IsPlumbingType() const { return true; }
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
virtual void SetEnableTraining(bool state);
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
virtual void SetNetworkFlags(uinT32 flags);
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
// Returns the number of weights initialized.
virtual int InitWeights(float range, TRand* randomizer);
// Converts a float network to an int network.
virtual void ConvertToInt();
// Provides a pointer to a TRand for any networks that care to use it.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
virtual void SetRandomizer(TRand* randomizer);
// Adds the given network to the stack.
virtual void AddToStack(Network* network);
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
virtual bool SetupNeedsBackprop(bool needs_backprop);
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
virtual int XScaleFactor() const;
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
virtual void CacheXScaleFactor(int factor);
// Provides debug output on the weights.
virtual void DebugWeights();
// Returns the current stack.
const PointerVector<Network>& stack() const {
return stack_;
}
// Returns a set of strings representing the layer-ids of all layers below.
void EnumerateLayers(const STRING* prefix,
GenericVector<STRING>* layers) const;
// Returns a pointer to the network layer corresponding to the given id.
Network* GetLayer(const char* id) const;
// Returns the learning rate for a specific layer of the stack.
float LayerLearningRate(const char* id) const {
const float* lr_ptr = LayerLearningRatePtr(id);
ASSERT_HOST(lr_ptr != NULL);
return *lr_ptr;
}
// Scales the learning rate for a specific layer of the stack.
void ScaleLayerLearningRate(const char* id, double factor) {
float* lr_ptr = LayerLearningRatePtr(id);
ASSERT_HOST(lr_ptr != NULL);
*lr_ptr *= factor;
}
// Returns a pointer to the learning rate for the given layer id.
float* LayerLearningRatePtr(const char* id) const;
// Writes to the given file. Returns false in case of error.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
virtual void Update(float learning_rate, float momentum, int num_samples);
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
virtual void CountAlternators(const Network& other, double* same,
double* changed) const;
protected:
// The networks.
PointerVector<Network> stack_;
// Layer-specific learning rate iff network_flags_ & NF_LAYER_SPECIFIC_LR.
// One element for each element of stack_.
GenericVector<float> learning_rates_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_PLUMBING_H_
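The learning_rates_ vector only comes into play when NF_LAYER_SPECIFIC_LR is set: Update then lazily records one rate per stack element and uses it in place of the global rate. A minimal sketch of driving that, assuming a Plumbing* plumbing that is already built; the rates and scale factor are illustrative:

plumbing->SetNetworkFlags(NF_LAYER_SPECIFIC_LR);
plumbing->Update(0.001f, 0.9f, 1);           // Seeds learning_rates_ lazily.
plumbing->ScaleLayerLearningRate("0", 0.5);  // Damp the first layer.
tprintf("layer 0 lr = %g\n", plumbing->LayerLearningRate("0"));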

759
lstm/recodebeam.cpp Normal file
View File

@ -0,0 +1,759 @@
///////////////////////////////////////////////////////////////////////
// File: recodebeam.cpp
// Description: Beam search to decode from the re-encoded CJK as a sequence of
// smaller numbers in place of a single large code.
// Author: Ray Smith
// Created: Fri Mar 13 09:39:01 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#include "recodebeam.h"
#include "networkio.h"
#include "pageres.h"
#include "unicharcompress.h"
namespace tesseract {
// Clipping value for certainty inside Tesseract. Reflects the minimum value
// of certainty that will be returned by ExtractBestPathAsUnicharIds.
// Supposedly on a uniform scale that can be compared across languages and
// engines.
const float RecodeBeamSearch::kMinCertainty = -20.0f;
// The beam width at each code position.
const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
};
// Borrows the pointer, which is expected to survive until *this is deleted.
RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress& recoder,
int null_char, bool simple_text, Dict* dict)
: recoder_(recoder),
dict_(dict),
space_delimited_(true),
is_simple_text_(simple_text),
null_char_(null_char) {
if (dict_ != NULL && !dict_->IsSpaceDelimitedLang()) space_delimited_ = false;
}
// Decodes the set of network outputs, storing the lattice internally.
void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
double cert_offset, double worst_dict_cert,
const UNICHARSET* charset) {
beam_size_ = 0;
int width = output.Width();
for (int t = 0; t < width; ++t) {
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
charset);
}
}
void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
double dict_ratio, double cert_offset,
double worst_dict_cert,
const UNICHARSET* charset) {
beam_size_ = 0;
int width = output.dim1();
for (int t = 0; t < width; ++t) {
ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
}
}
// Returns the best path as labels/scores/xcoords similar to simple CTC.
void RecodeBeamSearch::ExtractBestPathAsLabels(
GenericVector<int>* labels, GenericVector<int>* xcoords) const {
labels->truncate(0);
xcoords->truncate(0);
GenericVector<const RecodeNode*> best_nodes;
ExtractBestPaths(&best_nodes, NULL);
// Now just run CTC on the best nodes.
int t = 0;
int width = best_nodes.size();
while (t < width) {
int label = best_nodes[t]->code;
if (label != null_char_) {
labels->push_back(label);
xcoords->push_back(t);
}
while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
}
}
xcoords->push_back(width);
}
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
bool debug, const UNICHARSET* unicharset, GenericVector<int>* unichar_ids,
GenericVector<float>* certs, GenericVector<float>* ratings,
GenericVector<int>* xcoords) const {
GenericVector<const RecodeNode*> best_nodes;
ExtractBestPaths(&best_nodes, NULL);
ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
if (debug) {
DebugPath(unicharset, best_nodes);
DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
*xcoords);
}
}
// Returns the best path as a set of WERD_RES.
void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
float scale_factor, bool debug,
const UNICHARSET* unicharset,
PointerVector<WERD_RES>* words) {
words->truncate(0);
GenericVector<int> unichar_ids;
GenericVector<float> certs;
GenericVector<float> ratings;
GenericVector<int> xcoords;
GenericVector<const RecodeNode*> best_nodes;
GenericVector<const RecodeNode*> second_nodes;
ExtractBestPaths(&best_nodes, &second_nodes);
if (debug) {
DebugPath(unicharset, best_nodes);
ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
&xcoords);
tprintf("\nSecond choice path:\n");
DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
xcoords);
}
ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords);
int num_ids = unichar_ids.size();
if (debug) {
DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
xcoords);
}
// Group the unichar-ids into words, delimited by spaces and word-start flags.
int word_end = 0;
float prev_space_cert = 0.0f;
for (int word_start = 0; word_start < num_ids; word_start = word_end) {
for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
// A word is terminated when a space character or start_of_word flag is
// hit. We also want to force a separate word for every non
// space-delimited character when not in a dictionary context.
if (unichar_ids[word_end] == UNICHAR_SPACE) break;
int index = xcoords[word_end];
if (best_nodes[index]->start_of_word) break;
if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
(!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
!unicharset->IsSpaceDelimited(unichar_ids[word_end - 1])))
break;
}
float space_cert = 0.0f;
if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE)
space_cert = certs[word_end];
bool leading_space =
word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
// Create a WERD_RES for the output word.
WERD_RES* word_res = InitializeWord(
leading_space, line_box, word_start, word_end,
MIN(space_cert, prev_space_cert), unicharset, xcoords, scale_factor);
for (int i = word_start; i < word_end; ++i) {
BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
BLOB_CHOICE_IT bc_it(choices);
BLOB_CHOICE* choice = new BLOB_CHOICE(
unichar_ids[i], ratings[i], certs[i], -1, 1.0f,
static_cast<float>(MAX_INT16), 0.0f, BCC_STATIC_CLASSIFIER);
int col = i - word_start;
choice->set_matrix_cell(col, col);
bc_it.add_after_then_move(choice);
word_res->ratings->put(col, col, choices);
}
int index = xcoords[word_end - 1];
word_res->FakeWordFromRatings(best_nodes[index]->permuter);
words->push_back(word_res);
prev_space_cert = space_cert;
if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE)
++word_end;
}
}
// Generates debug output of the content of the beams after a Decode.
void RecodeBeamSearch::DebugBeams(const UNICHARSET& unicharset) const {
for (int p = 0; p < beam_size_; ++p) {
// Print all the best scoring nodes for each unichar found.
tprintf("Position %d: Nondict beam\n", p);
DebugBeamPos(unicharset, beam_[p]->beams_[0]);
tprintf("Position %d: Dict beam\n", p);
DebugBeamPos(unicharset, beam_[p]->dawg_beams_[0]);
}
}
// Generates debug output of the content of a single beam position.
void RecodeBeamSearch::DebugBeamPos(const UNICHARSET& unicharset,
const RecodeHeap& heap) const {
GenericVector<const RecodeNode*> unichar_bests;
unichar_bests.init_to_size(unicharset.size(), NULL);
const RecodeNode* null_best = NULL;
int heap_size = heap.size();
for (int i = 0; i < heap_size; ++i) {
const RecodeNode* node = &heap.get(i).data;
if (node->unichar_id == INVALID_UNICHAR_ID) {
if (null_best == NULL || null_best->score < node->score) null_best = node;
} else {
if (unichar_bests[node->unichar_id] == NULL ||
unichar_bests[node->unichar_id]->score < node->score) {
unichar_bests[node->unichar_id] = node;
}
}
}
for (int u = 0; u < unichar_bests.size(); ++u) {
if (unichar_bests[u] != NULL) {
const RecodeNode& node = *unichar_bests[u];
tprintf("label=%d, uid=%d=%s score=%g, c=%g, s=%d, e=%d, perm=%d\n",
node.code, node.unichar_id,
unicharset.debug_str(node.unichar_id).string(), node.score,
node.certainty, node.start_of_word, node.end_of_word,
node.permuter);
}
}
if (null_best != NULL) {
tprintf("null_char score=%g, c=%g, s=%d, e=%d, perm=%d\n", null_best->score,
null_best->certainty, null_best->start_of_word,
null_best->end_of_word, null_best->permuter);
}
}
// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
/* static */
void RecodeBeamSearch::ExtractPathAsUnicharIds(
const GenericVector<const RecodeNode*>& best_nodes,
GenericVector<int>* unichar_ids, GenericVector<float>* certs,
GenericVector<float>* ratings, GenericVector<int>* xcoords) {
unichar_ids->truncate(0);
certs->truncate(0);
ratings->truncate(0);
xcoords->truncate(0);
// Backtrack extracting only valid, non-duplicate unichar-ids.
int t = 0;
int width = best_nodes.size();
while (t < width) {
double certainty = 0.0;
double rating = 0.0;
while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
double cert = best_nodes[t++]->certainty;
if (cert < certainty) certainty = cert;
rating -= cert;
}
if (t < width) {
int unichar_id = best_nodes[t]->unichar_id;
unichar_ids->push_back(unichar_id);
xcoords->push_back(t);
do {
double cert = best_nodes[t++]->certainty;
// Special-case NO-PERM space to forget the certainty of the previous
// nulls. See long comment in ContinueContext.
if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
best_nodes[t - 1]->permuter == NO_PERM)) {
certainty = cert;
}
rating -= cert;
} while (t < width && best_nodes[t]->duplicate);
certs->push_back(certainty);
ratings->push_back(rating);
} else if (!certs->empty()) {
if (certainty < certs->back()) certs->back() = certainty;
ratings->back() += rating;
}
}
xcoords->push_back(width);
}
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
const TBOX& line_box, int word_start,
int word_end, float space_certainty,
const UNICHARSET* unicharset,
const GenericVector<int>& xcoords,
float scale_factor) {
// Make a fake blob for each non-zero label.
C_BLOB_LIST blobs;
C_BLOB_IT b_it(&blobs);
for (int i = word_start; i < word_end; ++i) {
int min_half_width = xcoords[i + 1] - xcoords[i];
if (i > 0 && xcoords[i] - xcoords[i - 1] < min_half_width)
min_half_width = xcoords[i] - xcoords[i - 1];
if (min_half_width < 1) min_half_width = 1;
// Make a fake blob.
TBOX box(xcoords[i] - min_half_width, 0, xcoords[i] + min_half_width,
line_box.height());
box.scale(scale_factor);
box.move(ICOORD(line_box.left(), line_box.bottom()));
box.set_top(line_box.top());
b_it.add_after_then_move(C_BLOB::FakeBlob(box));
}
// Make a fake word from the blobs.
WERD* word = new WERD(&blobs, leading_space, NULL);
// Make a WERD_RES from the word.
WERD_RES* word_res = new WERD_RES(word);
word_res->uch_set = unicharset;
word_res->combination = true; // Give it ownership of the word.
word_res->space_certainty = space_certainty;
word_res->ratings = new MATRIX(word_end - word_start, 1);
return word_res;
}
// Fills top_n_flags_ with bools that are true iff the corresponding output
// is one of the top_n.
void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
int top_n) {
top_n_flags_.init_to_size(num_outputs, false);
top_heap_.clear();
for (int i = 0; i < num_outputs; ++i) {
if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
TopPair entry(outputs[i], i);
top_heap_.Push(&entry);
if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
}
}
while (!top_heap_.empty()) {
TopPair entry;
top_heap_.Pop(&entry);
top_n_flags_[entry.data] = true;
}
}
// Adds the computation for the current time-step to the beam. Call at each
// time-step in sequence from left to right. outputs is the activation vector
// for the current timestep.
void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
double dict_ratio, double cert_offset,
double worst_dict_cert,
const UNICHARSET* charset) {
if (t == beam_.size()) beam_.push_back(new RecodeBeam);
RecodeBeam* step = beam_[t];
beam_size_ = t + 1;
step->Clear();
if (t == 0) {
// The first step can only use singles and initials.
ContinueContext(NULL, 0, outputs, false, true, dict_ratio, cert_offset,
worst_dict_cert, step);
if (dict_ != NULL)
ContinueContext(NULL, 0, outputs, true, true, dict_ratio, cert_offset,
worst_dict_cert, step);
} else {
RecodeBeam* prev = beam_[t - 1];
if (charset != NULL) {
for (int i = prev->dawg_beams_[0].size() - 1; i >= 0; --i) {
GenericVector<const RecodeNode*> path;
ExtractPath(&prev->dawg_beams_[0].get(i).data, &path);
tprintf("Step %d: Dawg beam %d:\n", t, i);
DebugPath(charset, path);
}
}
int total_beam = 0;
// Try true and then false only if the beam is empty. This enables extending
// the context using only the top-n results first, which may have an empty
// intersection with the valid codes, so we fall back to the rest if the
// beam is empty.
for (int flag = 1; flag >= 0 && total_beam == 0; --flag) {
for (int length = 0; length <= RecodedCharID::kMaxCodeLen; ++length) {
// Working backwards through the heaps doesn't guarantee that we see the
// best first, but it comes before a lot of the worst, so it is slightly
// more efficient than going forwards.
for (int i = prev->dawg_beams_[length].size() - 1; i >= 0; --i) {
ContinueContext(&prev->dawg_beams_[length].get(i).data, length,
outputs, true, flag, dict_ratio, cert_offset,
worst_dict_cert, step);
}
for (int i = prev->beams_[length].size() - 1; i >= 0; --i) {
ContinueContext(&prev->beams_[length].get(i).data, length, outputs,
false, flag, dict_ratio, cert_offset, worst_dict_cert,
step);
}
}
for (int length = 0; length <= RecodedCharID::kMaxCodeLen; ++length) {
total_beam += step->beams_[length].size();
total_beam += step->dawg_beams_[length].size();
}
}
// Special case for the best initial dawg. Push it on the heap if good
// enough, but there is only one, so it doesn't blow up the beam.
RecodeHeap* dawg_heap = &step->dawg_beams_[0];
if (step->best_initial_dawg_.code >= 0 &&
(dawg_heap->size() < kBeamWidths[0] ||
step->best_initial_dawg_.score > dawg_heap->PeekTop().data.score)) {
RecodePair entry(step->best_initial_dawg_.score,
step->best_initial_dawg_);
dawg_heap->Push(&entry);
if (dawg_heap->size() > kBeamWidths[0]) dawg_heap->Pop(&entry);
}
}
}
// Adds to the appropriate beams the legal (according to recoder)
// continuations of context prev, which is of the given length, using the
// given network outputs to provide scores to the choices. Uses only those
// choices for which top_n_flags[index] == top_n_flag.
void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int length,
const float* outputs, bool use_dawgs,
bool top_n_flag, double dict_ratio,
double cert_offset,
double worst_dict_cert,
RecodeBeam* step) {
RecodedCharID prefix;
RecodedCharID full_code;
const RecodeNode* previous = prev;
for (int p = length - 1; p >= 0; --p, previous = previous->prev) {
while (previous != NULL &&
(previous->duplicate || previous->code == null_char_)) {
previous = previous->prev;
}
prefix.Set(p, previous->code);
full_code.Set(p, previous->code);
}
if (prev != NULL && !is_simple_text_) {
float cert = NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
if ((cert >= kMinCertainty || prev->code == null_char_) &&
top_n_flags_[prev->code] == top_n_flag) {
if (use_dawgs) {
if (cert > worst_dict_cert) {
PushDupIfBetter(kBeamWidths[length], cert, prev,
&step->dawg_beams_[length]);
}
} else {
PushDupIfBetter(kBeamWidths[length], cert * dict_ratio, prev,
&step->beams_[length]);
}
}
if (prev->code != null_char_ && length > 0 &&
top_n_flags_[null_char_] == top_n_flag) {
// Allow nulls within multi code sequences, as the nulls within are not
// explicitly included in the code sequence.
cert = NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
if (cert >= kMinCertainty && (!use_dawgs || cert > worst_dict_cert)) {
if (use_dawgs) {
PushNoDawgIfBetter(kBeamWidths[length], null_char_,
INVALID_UNICHAR_ID, NO_PERM, cert, prev,
&step->dawg_beams_[length]);
} else {
PushNoDawgIfBetter(kBeamWidths[length], null_char_,
INVALID_UNICHAR_ID, TOP_CHOICE_PERM,
cert * dict_ratio, prev, &step->beams_[length]);
}
}
}
}
const GenericVector<int>* final_codes = recoder_.GetFinalCodes(prefix);
if (final_codes != NULL) {
for (int i = 0; i < final_codes->size(); ++i) {
int code = (*final_codes)[i];
if (top_n_flags_[code] != top_n_flag) continue;
if (prev != NULL && prev->code == code && !is_simple_text_) continue;
float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
if (cert < kMinCertainty && code != null_char_) continue;
full_code.Set(length, code);
int unichar_id = recoder_.DecodeUnichar(full_code);
// Map the null char to INVALID.
if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
if (use_dawgs) {
if (cert > worst_dict_cert) {
ContinueDawg(kBeamWidths[0], code, unichar_id, cert, prev,
&step->dawg_beams_[0], step);
}
} else {
PushNoDawgIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM,
cert * dict_ratio, prev, &step->beams_[0]);
if (dict_ != NULL &&
((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
!dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
// Any top choice position that can start a new word, ie a space or
// any non-space-delimited character, should also be considered
// by the dawg search, so push initial dawg to the dawg heap.
float dawg_cert = cert;
PermuterType permuter = TOP_CHOICE_PERM;
// Since we use the space either side of a dictionary word in the
// certainty of the word, (to properly handle weak spaces) and the
// space is coming from a non-dict word, we need special conditions
// to avoid degrading the certainty of the dict word that follows.
// With a space we don't multiply the certainty by dict_ratio, and we
// flag the space with NO_PERM to indicate that we should not use the
// predecessor nulls to generate the confidence for the space, as they
// have already been multiplied by dict_ratio, and we can't go back to
// insert more entries in any previous heaps.
if (unichar_id == UNICHAR_SPACE)
permuter = NO_PERM;
else
dawg_cert *= dict_ratio;
PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
dawg_cert, prev, &step->best_initial_dawg_);
}
}
}
}
const GenericVector<int>* next_codes = recoder_.GetNextCodes(prefix);
if (next_codes != NULL) {
for (int i = 0; i < next_codes->size(); ++i) {
int code = (*next_codes)[i];
if (top_n_flags_[code] != top_n_flag) continue;
if (prev != NULL && prev->code == code && !is_simple_text_) continue;
float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
if (cert < kMinCertainty && code != null_char_) continue;
if (use_dawgs) {
if (cert > worst_dict_cert) {
ContinueDawg(kBeamWidths[length + 1], code, INVALID_UNICHAR_ID, cert,
prev, &step->dawg_beams_[length + 1], step);
}
} else {
PushNoDawgIfBetter(kBeamWidths[length + 1], code, INVALID_UNICHAR_ID,
TOP_CHOICE_PERM, cert * dict_ratio, prev,
&step->beams_[length + 1]);
}
}
}
}
// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
// appropriate-dawg-args) to the given heap (dawg_beam_) if unichar_id
// is a valid continuation of whatever is in prev.
void RecodeBeamSearch::ContinueDawg(int max_size, int code, int unichar_id,
float cert, const RecodeNode* prev,
RecodeHeap* heap, RecodeBeam* step) {
if (unichar_id == INVALID_UNICHAR_ID) {
PushNoDawgIfBetter(max_size, code, unichar_id, NO_PERM, cert, prev, heap);
return;
}
// Avoid dictionary probe if score a total loss.
float score = cert;
if (prev != NULL) score += prev->score;
if (heap->size() >= max_size && score <= heap->PeekTop().data.score) return;
const RecodeNode* uni_prev = prev;
// Prev may be a partial code, null_char, or duplicate, so scan back to the
// last valid unichar_id.
while (uni_prev != NULL &&
(uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate))
uni_prev = uni_prev->prev;
if (unichar_id == UNICHAR_SPACE) {
if (uni_prev != NULL && uni_prev->end_of_word) {
// Space is good. Push initial state, to the dawg beam and a regular
// space to the top choice beam.
PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
false, cert, prev, &step->best_initial_dawg_);
PushNoDawgIfBetter(max_size, code, unichar_id, uni_prev->permuter, cert,
prev, &step->beams_[0]);
}
return;
} else if (uni_prev != NULL && uni_prev->start_of_dawg &&
uni_prev->unichar_id != UNICHAR_SPACE &&
dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
return; // Can't break words between space delimited chars.
}
DawgPositionVector initial_dawgs;
DawgPositionVector* updated_dawgs = new DawgPositionVector;
DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
bool word_start = false;
if (uni_prev == NULL) {
// Starting from beginning of line.
dict_->default_dawgs(&initial_dawgs, false);
word_start = true;
} else if (uni_prev->dawgs != NULL) {
// Continuing a previous dict word.
dawg_args.active_dawgs = uni_prev->dawgs;
word_start = uni_prev->start_of_dawg;
} else {
return; // Can't continue if not a dict word.
}
PermuterType permuter = static_cast<PermuterType>(
dict_->def_letter_is_okay(&dawg_args, unichar_id, false));
if (permuter != NO_PERM) {
PushHeapIfBetter(max_size, code, unichar_id, permuter, false, word_start,
dawg_args.valid_end, false, cert, prev,
dawg_args.updated_dawgs, heap);
if (dawg_args.valid_end && !space_delimited_) {
// We can start another word right away, so push initial state as well,
// to the dawg beam, and the regular character to the top choice beam,
// since non-dict words can start here too.
PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
cert, prev, &step->best_initial_dawg_);
PushHeapIfBetter(max_size, code, unichar_id, permuter, false, word_start,
true, false, cert, prev, NULL, &step->beams_[0]);
}
} else {
delete updated_dawgs;
}
}
// Replaces *best_initial_dawg with a RecodeNode composed of the tuple (code,
// unichar_id, initial-dawg-state, prev, cert) if the new node's score beats
// the current best.
void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
PermuterType permuter,
bool start, bool end, float cert,
const RecodeNode* prev,
RecodeNode* best_initial_dawg) {
float score = cert;
if (prev != NULL) score += prev->score;
if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
DawgPositionVector* initial_dawgs = new DawgPositionVector;
dict_->default_dawgs(initial_dawgs, false);
RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
score, prev, initial_dawgs);
*best_initial_dawg = node;
}
}
// Adds a copy of the given prev as a duplicate of and successor to prev, if
// there is room or if better than the current worst element if already full.
/* static */
void RecodeBeamSearch::PushDupIfBetter(int max_size, float cert,
const RecodeNode* prev,
RecodeHeap* heap) {
PushHeapIfBetter(max_size, prev->code, prev->unichar_id, prev->permuter,
false, false, false, true, cert, prev, NULL, heap);
}
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
// false, false, false, false, cert, prev, NULL) to heap if there is room
// or if better than the current worst element if already full.
/* static */
void RecodeBeamSearch::PushNoDawgIfBetter(int max_size, int code,
int unichar_id, PermuterType permuter,
float cert, const RecodeNode* prev,
RecodeHeap* heap) {
float score = cert;
if (prev != NULL) score += prev->score;
if (heap->size() < max_size || score > heap->PeekTop().data.score) {
RecodeNode node(code, unichar_id, permuter, false, false, false, false,
cert, score, prev, NULL);
RecodePair entry(score, node);
heap->Push(&entry);
if (heap->size() > max_size) heap->Pop(&entry);
}
}
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
// or if better than the current worst element if already full.
/* static */
void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
PermuterType permuter, bool dawg_start,
bool word_start, bool end, bool dup,
float cert, const RecodeNode* prev,
DawgPositionVector* d,
RecodeHeap* heap) {
float score = cert;
if (prev != NULL) score += prev->score;
if (heap->size() < max_size || score > heap->PeekTop().data.score) {
RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
dup, cert, score, prev, d);
RecodePair entry(score, node);
heap->Push(&entry);
ASSERT_HOST(entry.data.dawgs == NULL);
if (heap->size() > max_size) heap->Pop(&entry);
} else {
delete d;
}
}
// Backtracks to extract the best path through the lattice that was built
// during Decode. On return the best_nodes vector essentially contains the set
// of code, score pairs that make the optimal path with the constraint that
// the recoder can decode the code sequence back to a sequence of unichar-ids.
void RecodeBeamSearch::ExtractBestPaths(
GenericVector<const RecodeNode*>* best_nodes,
GenericVector<const RecodeNode*>* second_nodes) const {
// Scan both beams to extract the best and second best paths.
const RecodeNode* best_node = NULL;
const RecodeNode* second_best_node = NULL;
const RecodeBeam* last_beam = beam_[beam_size_ - 1];
int heap_size = last_beam->beams_[0].size();
for (int i = 0; i < heap_size; ++i) {
const RecodeNode* node = &last_beam->beams_[0].get(i).data;
if (best_node == NULL || node->score > best_node->score) {
second_best_node = best_node;
best_node = node;
} else if (second_best_node == NULL ||
node->score > second_best_node->score) {
second_best_node = node;
}
}
// Scan the entire dawg heap for the best *valid* nodes, if any.
int dawg_size = last_beam->dawg_beams_[0].size();
for (int i = 0; i < dawg_size; ++i) {
const RecodeNode* dawg_node = &last_beam->dawg_beams_[0].get(i).data;
// dawg_node may be a null_char, or duplicate, so scan back to the last
// valid unichar_id.
const RecodeNode* back_dawg_node = dawg_node;
while (back_dawg_node != NULL &&
(back_dawg_node->unichar_id == INVALID_UNICHAR_ID ||
back_dawg_node->duplicate))
back_dawg_node = back_dawg_node->prev;
if (back_dawg_node != NULL &&
(back_dawg_node->end_of_word ||
back_dawg_node->unichar_id == UNICHAR_SPACE)) {
// Dawg node is valid. Use it in preference to back_dawg_node, as the
// score comparison is fair that way.
if (best_node == NULL || dawg_node->score > best_node->score) {
second_best_node = best_node;
best_node = dawg_node;
} else if (second_best_node == NULL ||
dawg_node->score > second_best_node->score) {
second_best_node = dawg_node;
}
}
}
if (second_nodes != NULL) ExtractPath(second_best_node, second_nodes);
ExtractPath(best_node, best_nodes);
}
// Helper backtracks through the lattice from the given node, storing the
// path and reversing it.
void RecodeBeamSearch::ExtractPath(
const RecodeNode* node, GenericVector<const RecodeNode*>* path) const {
path->truncate(0);
while (node != NULL) {
path->push_back(node);
node = node->prev;
}
path->reverse();
}
// Helper prints debug information on the given lattice path.
void RecodeBeamSearch::DebugPath(
const UNICHARSET* unicharset,
const GenericVector<const RecodeNode*>& path) const {
for (int c = 0; c < path.size(); ++c) {
const RecodeNode& node = *path[c];
tprintf("%d %d=%s score=%g, c=%g, s=%d, e=%d, perm=%d\n", c,
node.unichar_id, unicharset->debug_str(node.unichar_id).string(),
node.score, node.certainty, node.start_of_word, node.end_of_word,
node.permuter);
}
}
// Helper prints debug information on the given unichar path.
void RecodeBeamSearch::DebugUnicharPath(
const UNICHARSET* unicharset, const GenericVector<const RecodeNode*>& path,
const GenericVector<int>& unichar_ids, const GenericVector<float>& certs,
const GenericVector<float>& ratings,
const GenericVector<int>& xcoords) const {
int num_ids = unichar_ids.size();
double total_rating = 0.0;
for (int c = 0; c < num_ids; ++c) {
int coord = xcoords[c];
tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
unicharset->debug_str(unichar_ids[c]).string(), ratings[c],
certs[c], path[coord]->start_of_word, path[coord]->end_of_word,
path[coord]->permuter);
total_rating += ratings[c];
}
tprintf("Path total rating = %g\n", total_rating);
}
} // namespace tesseract.

304
lstm/recodebeam.h Normal file
View File

@ -0,0 +1,304 @@
///////////////////////////////////////////////////////////////////////
// File: recodebeam.h
// Description: Beam search to decode from the re-encoded CJK as a sequence of
// smaller numbers in place of a single large code.
// Author: Ray Smith
// Created: Fri Mar 13 09:12:01 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
#define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
#include "dawg.h"
#include "dict.h"
#include "genericheap.h"
#include "kdpair.h"
#include "networkio.h"
#include "ratngs.h"
#include "unicharcompress.h"
namespace tesseract {
// Lattice element for Re-encode beam search.
struct RecodeNode {
RecodeNode()
: code(-1),
unichar_id(INVALID_UNICHAR_ID),
permuter(TOP_CHOICE_PERM),
start_of_dawg(false),
start_of_word(false),
end_of_word(false),
duplicate(false),
certainty(0.0f),
score(0.0f),
prev(NULL),
dawgs(NULL) {}
RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start,
bool word_start, bool end, bool dup, float cert, float s,
const RecodeNode* p, DawgPositionVector* d)
: code(c),
unichar_id(uni_id),
permuter(perm),
start_of_dawg(dawg_start),
start_of_word(word_start),
end_of_word(end),
duplicate(dup),
certainty(cert),
score(s),
prev(p),
dawgs(d) {}
// NOTE: If we could use C++11, then this would be a move constructor.
// Instead we have a copy constructor that does a move!! This is because we
// don't want to copy the whole DawgPositionVector each time, and true
// copying isn't necessary for this struct. It does get moved around a lot
// though inside the heap and during heap push, hence the move semantics.
RecodeNode(RecodeNode& src) : dawgs(NULL) {
*this = src;
ASSERT_HOST(src.dawgs == NULL);
}
RecodeNode& operator=(RecodeNode& src) {
delete dawgs;
memcpy(this, &src, sizeof(src));
src.dawgs = NULL;
return *this;
}
~RecodeNode() { delete dawgs; }
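// Illustrative consequence of the move-as-copy above (sketch only, not
// part of the commit):
//   RecodeNode a = MakeNode();  // fills a.dawgs
//   RecodeNode b(a);            // steals a.dawgs; a.dawgs is now NULL
// After this only b owns the DawgPositionVector, so a moved-from node must
// never be read again except to destroy it, which is exactly the
// discipline the beam heaps follow.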
// The re-encoded code here = index to network output.
int code;
// The decoded unichar_id is only valid for the final code of a sequence.
int unichar_id;
// The type of permuter active at this point. Intervals between start_of_word
// and end_of_word make valid words of type given by permuter where
// end_of_word is true. These aren't necessarily delimited by spaces.
PermuterType permuter;
// True if this is the initial dawg state. May be attached to a space or,
// in a non-space-delimited lang, the end of the previous word.
bool start_of_dawg;
// True if this is the first node in a dictionary word.
bool start_of_word;
// True if this represents a valid candidate end of word position. Does not
// necessarily mark the end of a word, since a word can be extended beyond a
// candidate end by a continuation, eg 'the' continues to 'these'.
bool end_of_word;
// True if this is a duplicate of prev in all respects. Some training modes
// allow the network to output duplicate characters and crush them with CTC,
// but that would mess up the decoding, so we just smash them together on the
// fly using the duplicate flag.
bool duplicate;
// Certainty (log prob) of (just) this position.
float certainty;
// Total certainty of the path to this position.
float score;
// The previous node in this chain. Borrowed pointer.
const RecodeNode* prev;
// The currently active dawgs at this position. Owned pointer.
DawgPositionVector* dawgs;
};
typedef KDPairInc<double, RecodeNode> RecodePair;
typedef GenericHeap<RecodePair> RecodeHeap;
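// Sketch of how the pair/heap types bound a beam (illustrative only, not
// part of the commit): GenericHeap<KDPairInc> keeps the smallest key on
// top, so PeekTop() is the *worst* surviving path and a new candidate only
// has to beat that to enter a full beam. kWidth here is a stand-in for the
// per-code-length widths in RecodeBeamSearch::kBeamWidths.
//   RecodePair entry(score, node);            // key = score, data = node
//   if (heap.size() < kWidth || score > heap.PeekTop().data.score) {
//     heap.Push(&entry);                      // contents are moved in
//     if (heap.size() > kWidth) heap.Pop(&entry);  // evict the worst
//   }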
// Class that holds the entire beam search for recognition of a text line.
class RecodeBeamSearch {
public:
// Borrows the pointer, which is expected to survive until *this is deleted.
RecodeBeamSearch(const UnicharCompress& recoder, int null_char,
bool simple_text, Dict* dict);
// Decodes the set of network outputs, storing the lattice internally.
// If charset is not null, it enables detailed debugging of the beam search.
void Decode(const NetworkIO& output, double dict_ratio, double cert_offset,
double worst_dict_cert, const UNICHARSET* charset);
void Decode(const GENERIC_2D_ARRAY<float>& output, double dict_ratio,
double cert_offset, double worst_dict_cert,
const UNICHARSET* charset);
// Returns the best path as labels/scores/xcoords similar to simple CTC.
void ExtractBestPathAsLabels(GenericVector<int>* labels,
GenericVector<int>* xcoords) const;
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET* unicharset,
GenericVector<int>* unichar_ids,
GenericVector<float>* certs,
GenericVector<float>* ratings,
GenericVector<int>* xcoords) const;
// Returns the best path as a set of WERD_RES.
void ExtractBestPathAsWords(const TBOX& line_box, float scale_factor,
bool debug, const UNICHARSET* unicharset,
PointerVector<WERD_RES>* words);
// Generates debug output of the content of the beams after a Decode.
void DebugBeams(const UNICHARSET& unicharset) const;
// Clipping value for certainty inside Tesseract. Reflects the minimum value
// of certainty that will be returned by ExtractBestPathAsUnicharIds.
// Supposedly on a uniform scale that can be compared across languages and
// engines.
static const float kMinCertainty;
private:
// Struct for the Re-encode beam search. This struct holds the data for
// a single time-step position of the output. Use a PointerVector<RecodeBeam>
// to hold all the timesteps and prevent reallocation of the individual heaps.
struct RecodeBeam {
// Resets to the initial state without deleting all the memory.
void Clear() {
for (int i = 0; i <= RecodedCharID::kMaxCodeLen; ++i) {
beams_[i].clear();
dawg_beams_[i].clear();
}
RecodeNode empty;
best_initial_dawg_ = empty;
}
// A separate beam for each code position. Since there aren't that many
// code positions, this allows the beam to be quite narrow, and yet still
// have a low chance of losing the best path.
// Each heap is stored with the WORST result at the top, so we can quickly
// get the top-n values.
RecodeHeap beams_[RecodedCharID::kMaxCodeLen + 1];
// Although we can only use complete codes in the dawg, we have to separate
// partial code paths that lead back to a mid-dawg word from paths that are
// not part of a dawg word, as they have a different score. Since a dawg
// word can dead-end at any point, we need to keep the non-dawg path going,
// so the dawg_beams_ are a totally separate set, with a heap for each
// length just like the non-dawg beams.
RecodeHeap dawg_beams_[RecodedCharID::kMaxCodeLen + 1];
// While the language model is only a single word dictionary, we can use
// word starts as a choke point in the beam, and keep only a single dict
// start node at each step, so we find the best one here and push it on
// the heap, if it qualifies, once the whole step has been processed.
RecodeNode best_initial_dawg_;
};
typedef KDPairInc<float, int> TopPair;
// Generates debug output of the content of a single beam position.
void DebugBeamPos(const UNICHARSET& unicharset, const RecodeHeap& heap) const;
// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
static void ExtractPathAsUnicharIds(
const GenericVector<const RecodeNode*>& best_nodes,
GenericVector<int>* unichar_ids, GenericVector<float>* certs,
GenericVector<float>* ratings, GenericVector<int>* xcoords);
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
WERD_RES* InitializeWord(bool leading_space, const TBOX& line_box,
int word_start, int word_end, float space_certainty,
const UNICHARSET* unicharset,
const GenericVector<int>& xcoords,
float scale_factor);
// Fills top_n_flags_ with bools that are true iff the corresponding output
// is one of the top_n.
void ComputeTopN(const float* outputs, int num_outputs, int top_n);
// Adds the computation for the current time-step to the beam. Call at each
// time-step in sequence from left to right. outputs is the activation vector
// for the current timestep.
void DecodeStep(const float* outputs, int t, double dict_ratio,
double cert_offset, double worst_dict_cert,
const UNICHARSET* charset);
// Adds to the appropriate beams the legal (according to recoder)
// continuations of context prev, which is of the given length, using the
// given network outputs to provide scores to the choices. Uses only those
// choices for which top_n_flags[index] == top_n_flag.
void ContinueContext(const RecodeNode* prev, int length, const float* outputs,
bool use_dawgs, bool top_n_flag, double dict_ratio,
double cert_offset, double worst_dict_cert,
RecodeBeam* step);
// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
// appropriate-dawg-args) to the given heap (dawg_beam_) if unichar_id
// is a valid continuation of whatever is in prev.
void ContinueDawg(int max_size, int code, int unichar_id, float cert,
const RecodeNode* prev, RecodeHeap* heap, RecodeBeam* step);
// Adds a RecodeNode composed of the tuple (code, unichar_id,
// initial-dawg-state, prev, cert) to the given heap if there is room or if
// better than the current worst element if already full.
void PushInitialDawgIfBetter(int code, int unichar_id, PermuterType permuter,
bool start, bool end, float cert,
const RecodeNode* prev,
RecodeNode* best_initial_dawg);
// Adds a copy of the given prev as a duplicate of and successor to prev, if
// there is room or if better than the current worst element if already full.
static void PushDupIfBetter(int max_size, float cert, const RecodeNode* prev,
RecodeHeap* heap);
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
// false, false, false, false, cert, prev, NULL) to heap if there is room
// or if better than the current worst element if already full.
static void PushNoDawgIfBetter(int max_size, int code, int unichar_id,
PermuterType permuter, float cert,
const RecodeNode* prev, RecodeHeap* heap);
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
// or if better than the current worst element if already full.
static void PushHeapIfBetter(int max_size, int code, int unichar_id,
PermuterType permuter, bool dawg_start,
bool word_start, bool end, bool dup, float cert,
const RecodeNode* prev, DawgPositionVector* d,
RecodeHeap* heap);
// Backtracks to extract the best path through the lattice that was built
// during Decode. On return the best_nodes vector essentially contains the set
// of code, score pairs that make the optimal path with the constraint that
// the recoder can decode the code sequence back to a sequence of unichar-ids.
void ExtractBestPaths(GenericVector<const RecodeNode*>* best_nodes,
GenericVector<const RecodeNode*>* second_nodes) const;
// Helper backtracks through the lattice from the given node, storing the
// path and reversing it.
void ExtractPath(const RecodeNode* node,
GenericVector<const RecodeNode*>* path) const;
// Helper prints debug information on the given lattice path.
void DebugPath(const UNICHARSET* unicharset,
const GenericVector<const RecodeNode*>& path) const;
// Helper prints debug information on the given unichar path.
void DebugUnicharPath(const UNICHARSET* unicharset,
const GenericVector<const RecodeNode*>& path,
const GenericVector<int>& unichar_ids,
const GenericVector<float>& certs,
const GenericVector<float>& ratings,
const GenericVector<int>& xcoords) const;
static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1];
// The encoder/decoder that we will be using.
const UnicharCompress& recoder_;
// The beam for each timestep in the output.
PointerVector<RecodeBeam> beam_;
// The number of timesteps valid in beam_.
int beam_size_;
// A flag to indicate which outputs are the top-n choices. Current timestep
// only.
GenericVector<bool> top_n_flags_;
// Heap used to compute the top_n_flags_.
GenericHeap<TopPair> top_heap_;
// Borrowed pointer to the dictionary to use in the search.
Dict* dict_;
// True if the language is space-delimited, which is true for most languages
// except chi*, jpn, tha.
bool space_delimited_;
// True if the input is simple text, ie adjacent equal chars are not to be
// eliminated.
bool is_simple_text_;
// The encoded class label of the null/reject character.
int null_char_;
};
} // namespace tesseract.
#endif // THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_

128
lstm/reconfig.cpp Normal file
View File

@ -0,0 +1,128 @@
///////////////////////////////////////////////////////////////////////
// File: reconfig.cpp
// Description: Network layer that reconfigures the scaling vs feature
// depth.
// Author: Ray Smith
// Created: Wed Feb 26 15:42:25 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "reconfig.h"
#include "tprintf.h"
namespace tesseract {
Reconfig::Reconfig(const STRING& name, int ni, int x_scale, int y_scale)
: Network(NT_RECONFIG, name, ni, ni * x_scale * y_scale),
x_scale_(x_scale), y_scale_(y_scale) {
}
Reconfig::~Reconfig() {
}
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape Reconfig::OutputShape(const StaticShape& input_shape) const {
StaticShape result = input_shape;
result.set_height(result.height() / y_scale_);
result.set_width(result.width() / x_scale_);
if (type_ != NT_MAXPOOL)
result.set_depth(result.depth() * y_scale_ * x_scale_);
return result;
}
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
int Reconfig::XScaleFactor() const {
return x_scale_;
}
// Writes to the given file. Returns false in case of error.
bool Reconfig::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
if (fp->FWrite(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool Reconfig::DeSerialize(bool swap, TFile* fp) {
if (fp->FRead(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
if (fp->FRead(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
if (swap) {
ReverseN(&x_scale_, sizeof(x_scale_));
ReverseN(&y_scale_, sizeof(y_scale_));
}
no_ = ni_ * x_scale_ * y_scale_;
return true;
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void Reconfig::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
output->ResizeScaled(input, x_scale_, y_scale_, no_);
back_map_ = input.stride_map();
StrideMap::Index dest_index(output->stride_map());
do {
int out_t = dest_index.t();
StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
dest_index.index(FD_HEIGHT) * y_scale_,
dest_index.index(FD_WIDTH) * x_scale_);
// Stack x_scale_ groups of y_scale_ inputs together.
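// Layout sketch (comment only): with x_scale_=2, y_scale_=2 and ni_=16,
// the tile elements land at output channel offsets (x,y)=(0,0)->0,
// (0,1)->16, (1,0)->32, (1,1)->48, giving a single 64-deep output.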
for (int x = 0; x < x_scale_; ++x) {
for (int y = 0; y < y_scale_; ++y) {
StrideMap::Index src_xy(src_index);
if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
output->CopyTimeStepGeneral(out_t, (x * y_scale_ + y) * ni_, ni_,
input, src_xy.t(), 0);
}
}
}
} while (dest_index.Increment());
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool Reconfig::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
StrideMap::Index src_index(fwd_deltas.stride_map());
do {
int in_t = src_index.t();
StrideMap::Index dest_index(back_deltas->stride_map(),
src_index.index(FD_BATCH),
src_index.index(FD_HEIGHT) * y_scale_,
src_index.index(FD_WIDTH) * x_scale_);
// Unstack x_scale_ groups of y_scale_ inputs that are together.
for (int x = 0; x < x_scale_; ++x) {
for (int y = 0; y < y_scale_; ++y) {
StrideMap::Index dest_xy(dest_index);
if (dest_xy.AddOffset(x, FD_WIDTH) && dest_xy.AddOffset(y, FD_HEIGHT)) {
back_deltas->CopyTimeStepGeneral(dest_xy.t(), 0, ni_, fwd_deltas,
in_t, (x * y_scale_ + y) * ni_);
}
}
}
} while (src_index.Increment());
return needs_to_backprop_;
}
} // namespace tesseract.

86
lstm/reconfig.h Normal file
View File

@ -0,0 +1,86 @@
///////////////////////////////////////////////////////////////////////
// File: reconfig.h
// Description: Network layer that reconfigures the scaling vs feature
// depth.
// Author: Ray Smith
// Created: Wed Feb 26 15:37:42 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_RECONFIG_H_
#define TESSERACT_LSTM_RECONFIG_H_
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
namespace tesseract {
// Reconfigures (Shrinks) the inputs by concatenating an x_scale by y_scale tile
// of inputs together, producing a single, deeper output per tile.
// Note that fractional parts are truncated for efficiency, so make sure the
// input stride is a multiple of the y_scale factor!
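// Shape arithmetic sketch (illustrative only): with x_scale=2, y_scale=2,
// an input of batch=1, height=48, width=100, depth=16 is reconfigured to
// height=24, width=50, depth=64, since each 2x2 tile of 16-deep inputs
// becomes one 64-deep output (see OutputShape in reconfig.cpp).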
class Reconfig : public Network {
public:
Reconfig(const STRING& name, int ni, int x_scale, int y_scale);
virtual ~Reconfig();
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
virtual STRING spec() const {
STRING spec;
spec.add_str_int("S", y_scale_);
spec.add_str_int(",", x_scale_);
return spec;
}
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
virtual int XScaleFactor() const;
// Writes to the given file. Returns false in case of error.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
virtual bool DeSerialize(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
protected:
// Non-serialized data used to store parameters between forward and back.
StrideMap back_map_;
// Serialized data.
inT32 x_scale_;
inT32 y_scale_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_RECONFIG_H_

91
lstm/reversed.cpp Normal file
View File

@ -0,0 +1,91 @@
///////////////////////////////////////////////////////////////////////
// File: reversed.cpp
// Description: Runs a single network on time-reversed input, reversing output.
// Author: Ray Smith
// Created: Thu May 02 08:42:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "reversed.h"
#include <stdio.h>
#include "networkscratch.h"
namespace tesseract {
Reversed::Reversed(const STRING& name, NetworkType type) : Plumbing(name) {
type_ = type;
}
Reversed::~Reversed() {
}
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape Reversed::OutputShape(const StaticShape& input_shape) const {
if (type_ == NT_XYTRANSPOSE) {
StaticShape x_shape(input_shape);
x_shape.set_width(input_shape.height());
x_shape.set_height(input_shape.width());
x_shape = stack_[0]->OutputShape(x_shape);
x_shape.SetShape(x_shape.batch(), x_shape.width(), x_shape.height(),
x_shape.depth());
return x_shape;
}
return stack_[0]->OutputShape(input_shape);
}
// Takes ownership of the given network to make it the reversed one.
void Reversed::SetNetwork(Network* network) {
stack_.clear();
AddToStack(network);
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void Reversed::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
NetworkScratch::IO rev_input(input, scratch);
ReverseData(input, rev_input);
NetworkScratch::IO rev_output(input, scratch);
stack_[0]->Forward(debug, *rev_input, NULL, scratch, rev_output);
ReverseData(*rev_output, output);
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool Reversed::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
NetworkScratch::IO rev_input(fwd_deltas, scratch);
ReverseData(fwd_deltas, rev_input);
NetworkScratch::IO rev_output(fwd_deltas, scratch);
if (stack_[0]->Backward(debug, *rev_input, scratch, rev_output)) {
ReverseData(*rev_output, back_deltas);
return true;
}
return false;
}
// Copies src to *dest with the reversal according to type_.
void Reversed::ReverseData(const NetworkIO& src, NetworkIO* dest) const {
if (type_ == NT_XREVERSED)
dest->CopyWithXReversal(src);
else if (type_ == NT_YREVERSED)
dest->CopyWithYReversal(src);
else
dest->CopyWithXYTranspose(src);
}
} // namespace tesseract.

89
lstm/reversed.h Normal file
View File

@ -0,0 +1,89 @@
///////////////////////////////////////////////////////////////////////
// File: reversed.h
// Description: Runs a single network on time-reversed input, reversing output.
// Author: Ray Smith
// Created: Thu May 02 08:38:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_REVERSED_H_
#define TESSERACT_LSTM_REVERSED_H_
#include "matrix.h"
#include "plumbing.h"
namespace tesseract {
// C++ Implementation of the Reversed class from lstm.py.
class Reversed : public Plumbing {
public:
explicit Reversed(const STRING& name, NetworkType type);
virtual ~Reversed();
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
virtual STRING spec() const {
STRING spec(type_ == NT_XREVERSED ? "Rx"
: (type_ == NT_YREVERSED ? "Ry" : "Txy"));
// For most simple cases, we will output Rx<net> or Ry<net> where <net> is
// the network in stack_[0], but in the special case that <net> is an
// LSTM, we will just output the LSTM's spec modified to take the reversal
// into account. This is because when the user specified Lfy64, we actually
// generated TxyLfx64, and if the user specified Lrx64 we actually
// generated RxLfx64, and we want to display what the user asked for.
STRING net_spec = stack_[0]->spec();
if (net_spec[0] == 'L') {
// Setup a from and to character according to the type of the reversal
// such that the LSTM spec gets modified to the spec that the user
// asked for.
char from = 'f';
char to = 'r';
if (type_ == NT_XYTRANSPOSE) {
from = 'x';
to = 'y';
}
// Change the from char to the to char.
for (int i = 0; i < net_spec.length(); ++i) {
if (net_spec[i] == from) net_spec[i] = to;
}
return net_spec;
}
spec += net_spec;
return spec;
}
// Takes ownership of the given network to make it the reversed one.
void SetNetwork(Network* network);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
private:
// Copies src to *dest with the reversal according to type_.
void ReverseData(const NetworkIO& src, NetworkIO* dest) const;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_REVERSED_H_

188
lstm/series.cpp Normal file
View File

@ -0,0 +1,188 @@
///////////////////////////////////////////////////////////////////////
// File: series.cpp
// Description: Runs networks in series on the same input.
// Author: Ray Smith
// Created: Thu May 02 08:26:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "series.h"
#include "fullyconnected.h"
#include "networkscratch.h"
#include "scrollview.h"
#include "tprintf.h"
namespace tesseract {
// ni_ and no_ will be set by AddToStack.
Series::Series(const STRING& name) : Plumbing(name) {
type_ = NT_SERIES;
}
Series::~Series() {
}
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape Series::OutputShape(const StaticShape& input_shape) const {
StaticShape result(input_shape);
int stack_size = stack_.size();
for (int i = 0; i < stack_size; ++i) {
result = stack_[i]->OutputShape(result);
}
return result;
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Note that series has its own implementation just for debug purposes.
int Series::InitWeights(float range, TRand* randomizer) {
num_weights_ = 0;
tprintf("Num outputs,weights in serial:\n");
for (int i = 0; i < stack_.size(); ++i) {
int weights = stack_[i]->InitWeights(range, randomizer);
tprintf(" %s:%d, %d\n",
stack_[i]->spec().string(), stack_[i]->NumOutputs(), weights);
num_weights_ += weights;
}
tprintf("Total weights = %d\n", num_weights_);
return num_weights_;
}
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
bool Series::SetupNeedsBackprop(bool needs_backprop) {
needs_to_backprop_ = needs_backprop;
for (int i = 0; i < stack_.size(); ++i)
needs_backprop = stack_[i]->SetupNeedsBackprop(needs_backprop);
return needs_backprop;
}
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
int Series::XScaleFactor() const {
int factor = 1;
for (int i = 0; i < stack_.size(); ++i)
factor *= stack_[i]->XScaleFactor();
return factor;
}
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
void Series::CacheXScaleFactor(int factor) {
stack_[0]->CacheXScaleFactor(factor);
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void Series::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
int stack_size = stack_.size();
ASSERT_HOST(stack_size > 1);
// Revolving intermediate buffers.
NetworkScratch::IO buffer1(input, scratch);
NetworkScratch::IO buffer2(input, scratch);
// Run each network in turn, giving the output of n as the input to n + 1,
// with the final network providing the real output.
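// Buffer ping-pong sketch (comment only), assuming stack_size == 3:
//   stack_[0]: input   -> buffer1
//   stack_[1]: buffer1 -> buffer2
//   stack_[2]: buffer2 -> output
// The two scratch buffers alternate, and the final network writes its
// result directly to output.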
stack_[0]->Forward(debug, input, input_transpose, scratch, buffer1);
for (int i = 1; i < stack_size; i += 2) {
stack_[i]->Forward(debug, *buffer1, NULL, scratch,
i + 1 < stack_size ? buffer2 : output);
if (i + 1 == stack_size) return;
stack_[i + 1]->Forward(debug, *buffer2, NULL, scratch,
i + 2 < stack_size ? buffer1 : output);
}
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool Series::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
if (!training()) return false;
int stack_size = stack_.size();
ASSERT_HOST(stack_size > 1);
// Revolving intermediate buffers.
NetworkScratch::IO buffer1(fwd_deltas, scratch);
NetworkScratch::IO buffer2(fwd_deltas, scratch);
// Run each network in reverse order, giving the back_deltas output of n as
// the fwd_deltas input to n-1, with network 0 providing the real output.
if (!stack_.back()->training() ||
!stack_.back()->Backward(debug, fwd_deltas, scratch, buffer1))
return false;
for (int i = stack_size - 2; i >= 0; i -= 2) {
if (!stack_[i]->training() ||
!stack_[i]->Backward(debug, *buffer1, scratch,
i > 0 ? buffer2 : back_deltas))
return false;
if (i == 0) return needs_to_backprop_;
if (!stack_[i - 1]->training() ||
!stack_[i - 1]->Backward(debug, *buffer2, scratch,
i > 1 ? buffer1 : back_deltas))
return false;
}
return needs_to_backprop_;
}
// Splits the series after the given index, returning the two parts and
// deleting itself. The first part, up to and including the network with
// index last_start, goes into start, and the rest goes into end.
void Series::SplitAt(int last_start, Series** start, Series** end) {
*start = NULL;
*end = NULL;
if (last_start < 0 || last_start >= stack_.size()) {
tprintf("Invalid split index %d must be in range [0,%d]!\n",
last_start, stack_.size() - 1);
return;
}
Series* master_series = new Series("MasterSeries");
Series* boosted_series = new Series("BoostedSeries");
for (int s = 0; s <= last_start; ++s) {
if (s + 1 == stack_.size() && stack_[s]->type() == NT_SOFTMAX) {
// Change the softmax to a tanh.
FullyConnected* fc = reinterpret_cast<FullyConnected*>(stack_[s]);
fc->ChangeType(NT_TANH);
}
master_series->AddToStack(stack_[s]);
stack_[s] = NULL;
}
for (int s = last_start + 1; s < stack_.size(); ++s) {
boosted_series->AddToStack(stack_[s]);
stack_[s] = NULL;
}
*start = master_series;
*end = boosted_series;
delete this;
}
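// Illustrative use of SplitAt (comment only), assuming a 5-element series s:
//   Series* start;
//   Series* end;
//   s->SplitAt(2, &start, &end);  // start owns layers 0..2, end owns 3..4,
//                                 // and s has deleted itself.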
// Appends the elements of the src series to this, removing from src and
// deleting it.
void Series::AppendSeries(Network* src) {
ASSERT_HOST(src->type() == NT_SERIES);
Series* src_series = reinterpret_cast<Series*>(src);
for (int s = 0; s < src_series->stack_.size(); ++s) {
AddToStack(src_series->stack_[s]);
src_series->stack_[s] = NULL;
}
delete src;
}
} // namespace tesseract.

91
lstm/series.h Normal file
View File

@ -0,0 +1,91 @@
///////////////////////////////////////////////////////////////////////
// File: series.h
// Description: Runs networks in series on the same input.
// Author: Ray Smith
// Created: Thu May 02 08:20:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_SERIES_H_
#define TESSERACT_LSTM_SERIES_H_
#include "plumbing.h"
namespace tesseract {
// Runs two or more networks in series (layers) on the same input.
class Series : public Plumbing {
public:
// ni_ and no_ will be set by AddToStack.
explicit Series(const STRING& name);
virtual ~Series();
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
virtual STRING spec() const {
STRING spec("[");
for (int i = 0; i < stack_.size(); ++i)
spec += stack_[i]->spec();
spec += "]";
return spec;
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Returns the number of weights initialized.
virtual int InitWeights(float range, TRand* randomizer);
// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
virtual bool SetupNeedsBackprop(bool needs_backprop);
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// WARNING: if GlobalMinimax is used to vary the scale, this will return
// the last used scale factor. Call it before any forward, and it will return
// the minimum scale factor of the paths through the GlobalMinimax.
virtual int XScaleFactor() const;
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
virtual void CacheXScaleFactor(int factor);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas);
// Splits the series after the given index, returning the two parts and
// deleting itself. The first part, up to and including the network with
// index last_start, goes into start, and the rest goes into end.
void SplitAt(int last_start, Series** start, Series** end);
// Appends the elements of the src series to this, removing from src and
// deleting it.
void AppendSeries(Network* src);
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_SERIES_H_

80
lstm/static_shape.h Normal file
View File

@ -0,0 +1,80 @@
///////////////////////////////////////////////////////////////////////
// File: static_shape.h
// Description: Defines the size of the 4-d tensor input/output from a network.
// Author: Ray Smith
// Created: Fri Oct 14 09:07:31 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
#define TESSERACT_LSTM_STATIC_SHAPE_H_
#include "tprintf.h"
namespace tesseract {
// Enum describing the loss function to apply during training and/or the
// decoding method to apply at runtime.
enum LossType {
LT_NONE, // Undefined.
LT_CTC, // Softmax with standard CTC for training/decoding.
LT_SOFTMAX, // Outputs sum to 1 in fixed positions.
LT_LOGISTIC, // Logistic outputs with independent values.
};
// Simple class to hold the tensor shape that is known at network build time
// and the LossType of the loss function.
class StaticShape {
public:
StaticShape()
: batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
int batch() const { return batch_; }
void set_batch(int value) { batch_ = value; }
int height() const { return height_; }
void set_height(int value) { height_ = value; }
int width() const { return width_; }
void set_width(int value) { width_ = value; }
int depth() const { return depth_; }
void set_depth(int value) { depth_ = value; }
LossType loss_type() const { return loss_type_; }
void set_loss_type(LossType value) { loss_type_ = value; }
void SetShape(int batch, int height, int width, int depth) {
batch_ = batch;
height_ = height;
width_ = width;
depth_ = depth;
}
void Print() const {
tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_,
height_, width_, depth_, loss_type_);
}
private:
// Size of the 4-D tensor input/output to a network. A value of zero is
// allowed for all except depth_ and means to be determined at runtime, and
// regarded as variable.
// Number of elements in a batch, or number of frames in a video stream.
int batch_;
// Height of the image.
int height_;
// Width of the image.
int width_;
// Depth of the image. (Number of "nodes").
int depth_;
// How to train/interpret the output.
LossType loss_type_;
};
} // namespace tesseract
#endif // TESSERACT_LSTM_STATIC_SHAPE_H_

173
lstm/stridemap.cpp Normal file
View File

@ -0,0 +1,173 @@
///////////////////////////////////////////////////////////////////////
// File: stridemap.cpp
// Description: Indexing into a 4-d tensor held in a 2-d Array.
// Author: Ray Smith
// Created: Fri Sep 20 15:30:31 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "stridemap.h"
namespace tesseract {
// Returns true if *this is a valid index.
bool StrideMap::Index::IsValid() const {
// Cheap check first.
for (int d = 0; d < FD_DIMSIZE; ++d) {
if (indices_[d] < 0) return false;
}
for (int d = 0; d < FD_DIMSIZE; ++d) {
if (indices_[d] > MaxIndexOfDim(static_cast<FlexDimensions>(d)))
return false;
}
return true;
}
// Returns true if the index of the given dimension is the last.
bool StrideMap::Index::IsLast(FlexDimensions dimension) const {
return MaxIndexOfDim(dimension) == indices_[dimension];
}
// Given that the dimensions up to and including dim-1 are valid, returns the
// maximum index for dimension dim.
int StrideMap::Index::MaxIndexOfDim(FlexDimensions dim) const {
int max_index = stride_map_->shape_[dim] - 1;
if (dim == FD_BATCH) return max_index;
int batch = indices_[FD_BATCH];
if (dim == FD_HEIGHT) {
if (batch >= stride_map_->heights_.size() ||
stride_map_->heights_[batch] > max_index)
return max_index;
return stride_map_->heights_[batch] - 1;
}
if (batch >= stride_map_->widths_.size() ||
stride_map_->widths_[batch] > max_index)
return max_index;
return stride_map_->widths_[batch] - 1;
}
// Adds the given offset to the given dimension. Returns true if the result
// makes a valid index.
bool StrideMap::Index::AddOffset(int offset, FlexDimensions dimension) {
indices_[dimension] += offset;
SetTFromIndices();
return IsValid();
}
// Increments the index in some encapsulated way that guarantees to remain
// valid until it returns false, meaning that the iteration is complete.
bool StrideMap::Index::Increment() {
for (int d = FD_DIMSIZE - 1; d >= 0; --d) {
if (!IsLast(static_cast<FlexDimensions>(d))) {
t_ += stride_map_->t_increments_[d];
++indices_[d];
return true;
}
t_ -= stride_map_->t_increments_[d] * indices_[d];
indices_[d] = 0;
// Now carry to the next dimension.
}
return false;
}
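// Iteration order sketch (comment only): for shape batch=1, height=2,
// width=2, successive calls visit (b,y,x) = (0,0,0), (0,0,1), (0,1,0),
// (0,1,1) and then return false, i.e. the last dimension varies fastest,
// like an odometer.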
// Decrements the index in some encapsulated way that guarantees to remain
// valid until it returns false, meaning that the iteration (that started
// with InitToLast()) is complete.
bool StrideMap::Index::Decrement() {
for (int d = FD_DIMSIZE - 1; d >= 0; --d) {
if (indices_[d] > 0) {
--indices_[d];
if (d == FD_BATCH) {
// The upper limits of the other dimensions may have changed as a result
// of a different batch index, so they have to be reset.
InitToLastOfBatch(indices_[FD_BATCH]);
} else {
t_ -= stride_map_->t_increments_[d];
}
return true;
}
indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
t_ += stride_map_->t_increments_[d] * indices_[d];
// Now borrow from the next dimension.
}
return false;
}
// Initializes the indices to the last valid location in the given batch
// index.
void StrideMap::Index::InitToLastOfBatch(int batch) {
indices_[FD_BATCH] = batch;
for (int d = FD_BATCH + 1; d < FD_DIMSIZE; ++d) {
indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
}
SetTFromIndices();
}
// Computes and sets t_ from the current indices_.
void StrideMap::Index::SetTFromIndices() {
t_ = 0;
for (int d = 0; d < FD_DIMSIZE; ++d) {
t_ += stride_map_->t_increments_[d] * indices_[d];
}
}
// Sets up the stride for the given array of height, width pairs.
void StrideMap::SetStride(const std::vector<std::pair<int, int>>& h_w_pairs) {
int max_height = 0;
int max_width = 0;
for (const std::pair<int, int>& hw : h_w_pairs) {
int height = hw.first;
int width = hw.second;
heights_.push_back(height);
widths_.push_back(width);
if (height > max_height) max_height = height;
if (width > max_width) max_width = width;
}
shape_[FD_BATCH] = heights_.size();
shape_[FD_HEIGHT] = max_height;
shape_[FD_WIDTH] = max_width;
ComputeTIncrements();
}
// Scales width and height dimensions by the given factors.
void StrideMap::ScaleXY(int x_factor, int y_factor) {
for (int& height : heights_) height /= y_factor;
for (int& width : widths_) width /= x_factor;
shape_[FD_HEIGHT] /= y_factor;
shape_[FD_WIDTH] /= x_factor;
ComputeTIncrements();
}
// Reduces width to 1, across the batch, whatever the input size.
void StrideMap::ReduceWidthTo1() {
widths_.assign(widths_.size(), 1);
shape_[FD_WIDTH] = 1;
ComputeTIncrements();
}
// Transposes the width and height dimensions.
void StrideMap::TransposeXY() {
std::swap(shape_[FD_HEIGHT], shape_[FD_WIDTH]);
std::swap(heights_, widths_);
ComputeTIncrements();
}
// Computes t_increments_ from shape_.
void StrideMap::ComputeTIncrements() {
t_increments_[FD_DIMSIZE - 1] = 1;
for (int d = FD_DIMSIZE - 2; d >= 0; --d) {
t_increments_[d] = t_increments_[d + 1] * shape_[d + 1];
}
}
} // namespace tesseract

137
lstm/stridemap.h Normal file
View File

@ -0,0 +1,137 @@
///////////////////////////////////////////////////////////////////////
// File: stridemap.h
// Description: Indexing into a 4-d tensor held in a 2-d Array.
// Author: Ray Smith
// Created: Fri Sep 20 16:00:31 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_STRIDEMAP_H_
#define TESSERACT_LSTM_STRIDEMAP_H_
#include <string.h>
#include <vector>
#include "tprintf.h"
namespace tesseract {
// Enum describing the dimensions of the 'Tensor' in a NetworkIO.
// A NetworkIO is analogous to a TF Tensor, except that the number of dimensions
// is fixed (4), and they always have the same meaning. The underlying
// representation is a 2-D array, for which the product batch*height*width
// is always dim1 and depth is always dim2. FlexDimensions is used only for
// batch, height, width with the StrideMap, and therefore represents the runtime
// shape. The build-time shape is defined by StaticShape.
enum FlexDimensions {
FD_BATCH, // Index of multiple images.
FD_HEIGHT, // y-coordinate in image.
FD_WIDTH, // x-coordinate in image.
FD_DIMSIZE, // Number of flexible non-depth dimensions.
};
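// Addressing sketch (comment only): for a batch of 2 images padded to
// max_height=4 and max_width=5, ComputeTIncrements in stridemap.cpp yields
// t_increments_ = {20, 5, 1}, so element [batch=1][y=2][x=3] sits at
// t = 1*20 + 2*5 + 3*1 = 33 in the first dimension of the underlying
// 2-d array (depth occupies the second dimension).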
// Encapsulation of information relating to the mapping from [batch][y][x] to
// the first index into the 2-d array underlying a NetworkIO.
class StrideMap {
public:
// Class holding the non-depth indices.
class Index {
public:
explicit Index(const StrideMap& stride_map) : stride_map_(&stride_map) {
InitToFirst();
}
Index(const StrideMap& stride_map, int batch, int y, int x)
: stride_map_(&stride_map) {
indices_[FD_BATCH] = batch;
indices_[FD_HEIGHT] = y;
indices_[FD_WIDTH] = x;
SetTFromIndices();
}
// Accesses the index to the underlying array.
int t() const { return t_; }
int index(FlexDimensions dimension) const { return indices_[dimension]; }
// Initializes the indices to the first valid location.
void InitToFirst() {
memset(indices_, 0, sizeof(indices_));
t_ = 0;
}
// Initializes the indices to the last valid location.
void InitToLast() { InitToLastOfBatch(MaxIndexOfDim(FD_BATCH)); }
// Returns true if *this is a valid index.
bool IsValid() const;
// Returns true if the index of the given dimension is the last.
bool IsLast(FlexDimensions dimension) const;
// Given that the dimensions up to and including dim-1 are valid, returns the
// maximum index for dimension dim.
int MaxIndexOfDim(FlexDimensions dim) const;
// Adds the given offset to the given dimension. Returns true if the result
// makes a valid index.
bool AddOffset(int offset, FlexDimensions dimension);
// Increments the index in some encapsulated way that guarantees to remain
// valid until it returns false, meaning that the iteration is complete.
bool Increment();
// Decrements the index in some encapsulated way that guarantees to remain
// valid until it returns false, meaning that the iteration (that started
// with InitToLast()) is complete.
bool Decrement();
private:
// Initializes the indices to the last valid location in the given batch
// index.
void InitToLastOfBatch(int batch);
// Computes and sets t_ from the current indices_.
void SetTFromIndices();
// Map into which *this is an index.
const StrideMap* stride_map_;
// Index to the first dimension of the underlying array.
int t_;
// Indices into the individual dimensions.
int indices_[FD_DIMSIZE];
};
StrideMap() {
memset(shape_, 0, sizeof(shape_));
memset(t_increments_, 0, sizeof(t_increments_));
}
// Default copy constructor and operator= are OK to use here!
// Sets up the stride for the given array of height, width pairs.
void SetStride(const std::vector<std::pair<int, int>>& h_w_pairs);
// Scales width and height dimensions by the given factors.
void ScaleXY(int x_factor, int y_factor);
// Reduces width to 1, across the batch, whatever the input size.
void ReduceWidthTo1();
// Transposes the width and height dimensions.
void TransposeXY();
// Returns the size of the given dimension.
int Size(FlexDimensions dimension) const { return shape_[dimension]; }
// Returns the total width required.
int Width() const { return t_increments_[FD_BATCH] * shape_[FD_BATCH]; }
private:
// Computes t_increments_ from shape_.
void ComputeTIncrements();
// The size of each non-depth dimension.
int shape_[FD_DIMSIZE];
// Precomputed 't' increments for each dimension: the amount by which t
// advances in the packed 3-d array that shape_ represents when the index
// of the given dimension increases by 1.
// Vector of size shape_[FD_BATCH] holds the height of each image in a batch.
std::vector<int> heights_;
// Vector of size shape_[FD_BATCH] holds the width of each image in a batch.
std::vector<int> widths_;
};
} // namespace tesseract
#endif // TESSERACT_LSTM_STRIDEMAP_H_

146
lstm/tfnetwork.cpp Normal file
View File

@ -0,0 +1,146 @@
///////////////////////////////////////////////////////////////////////
// File: tfnetwork.cpp
// Description: Encapsulation of an entire tensorflow graph as a
// Tesseract Network.
// Author: Ray Smith
// Created: Fri Feb 26 09:35:29 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifdef INCLUDE_TENSORFLOW
#include "tfnetwork.h"
#include "allheaders.h"
#include "input.h"
#include "networkscratch.h"
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;
namespace tesseract {
TFNetwork::TFNetwork(const STRING& name) : Network(NT_TENSORFLOW, name, 0, 0) {}
TFNetwork::~TFNetwork() {}
int TFNetwork::InitFromProtoStr(const string& proto_str) {
if (!model_proto_.ParseFromString(proto_str)) return 0;
return InitFromProto();
}
// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
bool TFNetwork::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
string proto_str;
model_proto_.SerializeToString(&proto_str);
GenericVector<char> data;
data.init_to_size(proto_str.size(), 0);
memcpy(&data[0], proto_str.data(), proto_str.size());
if (!data.Serialize(fp)) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
bool TFNetwork::DeSerialize(bool swap, TFile* fp) {
GenericVector<char> data;
if (!data.DeSerialize(swap, fp)) return false;
if (!model_proto_.ParseFromArray(&data[0], data.size())) {
return false;
}
return InitFromProto();
}
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
void TFNetwork::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
vector<std::pair<string, Tensor>> tf_inputs;
int depth = input_shape_.depth();
ASSERT_HOST(depth == input.NumFeatures());
// TODO(rays) Allow batching. For now batch_size = 1.
const StrideMap& stride_map = input.stride_map();
// TF requires a tensor of shape float[batch, height, width, depth].
TensorShape shape{1, stride_map.Size(FD_HEIGHT), stride_map.Size(FD_WIDTH),
depth};
Tensor input_tensor(tensorflow::DT_FLOAT, shape);
// The flat() member gives a 1d array, with a data() member to get the data.
auto eigen_tensor = input_tensor.flat<float>();
memcpy(eigen_tensor.data(), input.f(0),
input.Width() * depth * sizeof(input.f(0)[0]));
// Add the tensor to the vector of inputs.
tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);
// Provide tensors giving the width and/or height of the image if they are
// required. Some tf ops require a separate tensor with knowledge of the
// size of the input as they cannot obtain it from the input tensor. This is
// usually true in the case of ops that process a batch of variable-sized
// objects.
if (!model_proto_.image_widths().empty()) {
TensorShape size_shape{1};
Tensor width_tensor(tensorflow::DT_INT32, size_shape);
auto eigen_wtensor = width_tensor.flat<int32>();
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
}
if (!model_proto_.image_heights().empty()) {
TensorShape size_shape{1};
Tensor height_tensor(tensorflow::DT_INT32, size_shape);
auto eigen_htensor = height_tensor.flat<int32>();
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
}
vector<string> target_layers = {model_proto_.output_layer()};
vector<Tensor> outputs;
Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
ASSERT_HOST(s.ok());
ASSERT_HOST(outputs.size() == 1);
const Tensor& output_tensor = outputs[0];
// Check the dimensions of the output.
ASSERT_HOST(output_tensor.shape().dims() == 2);
int output_dim0 = output_tensor.shape().dim_size(0);
int output_dim1 = output_tensor.shape().dim_size(1);
ASSERT_HOST(output_dim1 == output_shape_.depth());
output->Resize2d(false, output_dim0, output_dim1);
auto eigen_output = output_tensor.flat<float>();
memcpy(output->f(0), eigen_output.data(),
output_dim0 * output_dim1 * sizeof(output->f(0)[0]));
}
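// Shape example (illustrative values, not from the commit): a 48px-high line
// image 600px wide with depth 1 enters as a float tensor of shape
// [1, 48, 600, 1], and the returned softmax tensor has shape
// [sequence_length, num_classes], with sequence_length set by the graph.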
int TFNetwork::InitFromProto() {
spec_ = model_proto_.spec();
input_shape_.SetShape(
model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
std::max(0, model_proto_.x_size()), model_proto_.depth());
output_shape_.SetShape(model_proto_.batch_size(), 1, 0,
model_proto_.num_classes());
output_shape_.set_loss_type(model_proto_.using_ctc() ? LT_CTC : LT_SOFTMAX);
ni_ = input_shape_.height();
no_ = output_shape_.depth();
// Initialize the session_ with the graph. Since we can't get the graph
// back from the session_, we have to keep the proto as well.
tensorflow::SessionOptions options;
session_.reset(NewSession(options));
Status s = session_->Create(model_proto_.graph());
if (s.ok()) return model_proto_.global_step();
tprintf("Session_->Create returned '%s'\n", s.error_message().c_str());
return 0;
}
} // namespace tesseract
#endif // ifdef INCLUDE_TENSORFLOW
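As a usage sketch (not part of the commit; the file path and helper name are hypothetical), a TFNetwork could be initialized directly from a serialized TFNetworkModel on disk:
#include <fstream>
#include <sstream>
#include "tfnetwork.h"
// Returns 0 on failure, otherwise the global_step of the stored graph,
// mirroring InitFromProtoStr's contract.
int LoadTFNetworkFromFile(const char* proto_path, tesseract::TFNetwork* net) {
  std::ifstream in(proto_path, std::ios::binary);
  if (!in) return 0;
  std::ostringstream buffer;
  buffer << in.rdbuf();  // Slurp the whole serialized proto.
  return net->InitFromProtoStr(buffer.str());
}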

91
lstm/tfnetwork.h Normal file
View File

@ -0,0 +1,91 @@
///////////////////////////////////////////////////////////////////////
// File: tfnetwork.h
// Description: Encapsulation of an entire tensorflow graph as a
// Tesseract Network.
// Author: Ray Smith
// Created: Fri Feb 26 09:35:29 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_TFNETWORK_H_
#define TESSERACT_LSTM_TFNETWORK_H_
#ifdef INCLUDE_TENSORFLOW
#include <memory>
#include <string>
#include "network.h"
#include "static_shape.h"
#include "tfnetwork.proto.h"
#include "third_party/tensorflow/core/framework/graph.pb.h"
#include "third_party/tensorflow/core/public/session.h"
namespace tesseract {
class TFNetwork : public Network {
public:
explicit TFNetwork(const STRING& name);
virtual ~TFNetwork();
// Returns the required shape input to the network.
virtual StaticShape InputShape() const { return input_shape_; }
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
virtual StaticShape OutputShape(const StaticShape& input_shape) const {
return output_shape_;
}
virtual STRING spec() const { return spec_.c_str(); }
// Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
// otherwise the global step of the serialized graph.
int InitFromProtoStr(const string& proto_str);
// The number of classes in this network should be equal to those in the
// recoder_ in LSTMRecognizer.
int num_classes() const { return output_shape_.depth(); }
// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
virtual bool DeSerialize(bool swap, TFile* fp);
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output);
private:
int InitFromProto();
// The original network definition for reference.
string spec_;
// Input tensor parameters.
StaticShape input_shape_;
// Output tensor parameters.
StaticShape output_shape_;
// The tensor flow graph is contained in here.
std::unique_ptr<tensorflow::Session> session_;
// The serialized graph is also contained in here.
TFNetworkModel model_proto_;
};
} // namespace tesseract.
#endif // ifdef INCLUDE_TENSORFLOW
#endif // TESSERACT_LSTM_TFNETWORK_H_

61
lstm/tfnetwork.proto Normal file
View File

@ -0,0 +1,61 @@
syntax = "proto3";
package tesseract;
// TODO(rays) How to make this usable both in Google and open source?
import "third_party/tensorflow/core/framework/graph.proto";
// This proto is the interface between a python TF graph builder/trainer and
// the C++ world. The writer of this proto must provide fields as documented
// by the comments below.
// The graph must have a placeholder for NetworkIO, Widths and Heights. The
// following python code creates the appropriate placeholders:
//
// input_layer = tf.placeholder(tf.float32,
// shape=[batch_size, xsize, ysize, depth_dim],
// name='NetworkIO')
// widths = tf.placeholder(tf.int32, shape=[batch_size], name='Widths')
// heights = tf.placeholder(tf.int32, shape=[batch_size], name='Heights')
// # Flip x and y to the TF convention.
// input_layer = tf.transpose(input_layer, [0, 2, 1, 3])
//
// The widths and heights will be set to indicate the post-scaling size of the
// input image(s).
// For now batch_size is ignored and set to 1.
// The graph should return a 2-dimensional float32 tensor called 'softmax' of
// shape [sequence_length, num_classes], where sequence_length is allowed to
// be variable, given by the tensor itself.
// TODO(rays) determine whether it is worth providing for batch_size >1 and if
// so, how.
message TFNetworkModel {
// The TF graph definition. Required.
tensorflow.GraphDef graph = 1;
// The training index. Required to be > 0.
int64 global_step = 2;
// The original network definition for reference. Optional
string spec = 3;
// Input tensor parameters.
// Values per pixel. Required to be 1 or 3. Inputs assumed to be float32.
int32 depth = 4;
// Image size. Required. Zero implies flexible sizes, fixed if non-zero.
// If x_size > 0, images will be cropped/padded to the given size, after
// any scaling required by the y_size.
// If y_size > 0, images will be scaled isotropically to the given height.
int32 x_size = 5;
int32 y_size = 6;
// Number of images in a batch. Optional.
int32 batch_size = 8;
// Output tensor parameters.
// Number of output classes. Required to match the depth of the softmax.
int32 num_classes = 9;
// True if this network needs CTC-like decoding, dropping duplicated labels.
// The decoder always drops the null character.
bool using_ctc = 10;
// Name of input image tensor.
string image_input = 11;
// Name of image height and width tensors.
string image_widths = 12;
string image_heights = 13;
// Name of output (softmax) tensor.
string output_layer = 14;
}
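As a hedged sketch of the writer's side (field values are illustrative; only the field and placeholder names come from this file), the generated C++ protobuf API for TFNetworkModel would be populated along these lines:
tesseract::TFNetworkModel model;
// model.mutable_graph() would receive the trained tensorflow.GraphDef.
model.set_global_step(10000);        // Required to be > 0.
model.set_depth(1);                  // Greyscale input.
model.set_x_size(0);                 // Variable width.
model.set_y_size(48);                // Lines scaled to 48px high.
model.set_num_classes(105);          // Must match the softmax depth.
model.set_using_ctc(true);
model.set_image_input("NetworkIO");  // Placeholder names from the comment
model.set_image_widths("Widths");    // block at the top of this file.
model.set_image_heights("Heights");
model.set_output_layer("softmax");
string proto_str;
model.SerializeToString(&proto_str);  // Feed to TFNetwork::InitFromProtoStr.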

443
lstm/weightmatrix.cpp Normal file
View File

@ -0,0 +1,443 @@
///////////////////////////////////////////////////////////////////////
// File: weightmatrix.cpp
// Description: Hides distinction between float/int implementations.
// Author: Ray Smith
// Created: Tue Jun 17 11:46:20 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "weightmatrix.h"
#undef NONX86_BUILD
#if defined(ANDROID_BUILD) || defined(__PPC__) || defined(_ARCH_PPC64)
#define NONX86_BUILD 1
#endif
#ifndef NONX86_BUILD
#include <cpuid.h>
#endif
#include "dotproductavx.h"
#include "dotproductsse.h"
#include "statistc.h"
#include "svutil.h"
#include "tprintf.h"
namespace tesseract {
// Architecture detector. Add code here to detect any other architectures for
// SIMD-based faster dot product functions. Intended to be a single static
// object, but it does no real harm to have more than one.
class SIMDDetect {
public:
SIMDDetect()
: arch_tested_(false), avx_available_(false), sse_available_(false) {}
// Returns true if AVX is available on this system.
bool IsAVXAvailable() {
if (!arch_tested_) TestArchitecture();
return avx_available_;
}
// Returns true if SSE4.1 is available on this system.
bool IsSSEAvailable() {
if (!arch_tested_) TestArchitecture();
return sse_available_;
}
private:
// Tests the architecture in a system-dependent way to detect AVX, SSE and
// any other available SIMD equipment.
void TestArchitecture() {
SVAutoLock lock(&arch_mutex_);
if (arch_tested_) return;
#if defined(__linux__) && !defined(NONX86_BUILD)
if (__get_cpuid_max(0, NULL) >= 1) {
unsigned int eax, ebx, ecx, edx;
__get_cpuid(1, &eax, &ebx, &ecx, &edx);
sse_available_ = (ecx & 0x00080000) != 0;  // ECX bit 19: SSE4.1.
avx_available_ = (ecx & 0x10000000) != 0;  // ECX bit 28: AVX.
}
#endif
if (avx_available_) tprintf("Found AVX\n");
if (sse_available_) tprintf("Found SSE\n");
arch_tested_ = true;
}
private:
// Detect architecture in only a single thread.
SVMutex arch_mutex_;
// Flag set to true after TestArchitecture has been called.
bool arch_tested_;
// If true, then AVX has been detected.
bool avx_available_;
// If true, then SSE4.1 has been detected.
bool sse_available_;
};
static SIMDDetect detector;
// Copies the whole input transposed, converted to double, into *this.
void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
int width = input.dim1();
int num_features = input.dim2();
ResizeNoInit(num_features, width);
for (int t = 0; t < width; ++t) WriteStrided(t, input[t]);
}
// Sets up the network for training. Initializes weights to random values of
// scale `weight_range`, picked using the random number generator `randomizer`.
int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
float weight_range, TRand* randomizer) {
int_mode_ = false;
use_ada_grad_ = ada_grad;
if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
wf_.Resize(no, ni, 0.0);
if (randomizer != NULL) {
for (int i = 0; i < no; ++i) {
for (int j = 0; j < ni; ++j) {
wf_[i][j] = randomizer->SignedRand(weight_range);
}
}
}
InitBackward();
return ni * no;
}
// Converts a float network to an int network. Each set of input weights that
// corresponds to a single output weight is converted independently:
// Compute the max absolute value of the weight set.
// Scale so the max absolute value becomes MAX_INT8.
// Round to integer.
// Store a multiplicative scale factor (as a double) that will reproduce
// the original value, subject to rounding errors.
void WeightMatrix::ConvertToInt() {
wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
scales_.init_to_size(wi_.dim1(), 0.0);
int dim2 = wi_.dim2();
for (int t = 0; t < wi_.dim1(); ++t) {
double* f_line = wf_[t];
inT8* i_line = wi_[t];
double max_abs = 0.0;
for (int f = 0; f < dim2; ++f) {
double abs_val = fabs(f_line[f]);
if (abs_val > max_abs) max_abs = abs_val;
}
double scale = max_abs / MAX_INT8;
scales_[t] = scale;
if (scale == 0.0) scale = 1.0;
for (int f = 0; f < dim2; ++f) {
i_line[f] = IntCastRounded(f_line[f] / scale);
}
}
wf_.Resize(1, 1, 0.0);
int_mode_ = true;
}
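// Worked example (illustrative values): for a weight row {0.05, -0.1, 0.2},
// max_abs = 0.2, so scale = 0.2 / 127 ~= 0.001575. Rounding w / scale gives
// int weights {32, -64, 127}, and i_line[f] * scale reconstructs
// {0.0504, -0.1008, 0.2} to within rounding error.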
// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void WeightMatrix::InitBackward() {
int no = int_mode_ ? wi_.dim1() : wf_.dim1();
int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
dw_.Resize(no, ni, 0.0);
updates_.Resize(no, ni, 0.0);
wf_t_.Transpose(wf_);
}
// Flag on mode to indicate that this weightmatrix uses inT8.
const int kInt8Flag = 1;
// Flag on mode to indicate that this weightmatrix uses ada grad.
const int kAdaGradFlag = 4;
// Flag on mode to indicate that this weightmatrix uses double. Set
// independently of kInt8Flag as even in int mode the scales can
// be float or double.
const int kDoubleFlag = 128;
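// Example: an int-mode matrix trained with AdaGrad serializes
// mode = kInt8Flag | kAdaGradFlag | kDoubleFlag = 1 + 4 + 128 = 133.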
// Writes to the given file. Returns false in case of error.
bool WeightMatrix::Serialize(bool training, TFile* fp) const {
// For backward compatibility, add kDoubleFlag to mode to indicate the doubles
// format, without errs, so we can detect and read old-format weight matrices.
uinT8 mode = (int_mode_ ? kInt8Flag : 0) |
(use_ada_grad_ ? kAdaGradFlag : 0) | kDoubleFlag;
if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
if (int_mode_) {
if (!wi_.Serialize(fp)) return false;
if (!scales_.Serialize(fp)) return false;
} else {
if (!wf_.Serialize(fp)) return false;
if (training && !updates_.Serialize(fp)) return false;
if (training && use_ada_grad_ && !dw_sq_sum_.Serialize(fp)) return false;
}
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool WeightMatrix::DeSerialize(bool training, bool swap, TFile* fp) {
uinT8 mode = 0;
if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
int_mode_ = (mode & kInt8Flag) != 0;
use_ada_grad_ = (mode & kAdaGradFlag) != 0;
if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, swap, fp);
if (int_mode_) {
if (!wi_.DeSerialize(swap, fp)) return false;
if (!scales_.DeSerialize(swap, fp)) return false;
} else {
if (!wf_.DeSerialize(swap, fp)) return false;
if (training) {
InitBackward();
if (!updates_.DeSerialize(swap, fp)) return false;
if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(swap, fp)) return false;
}
}
return true;
}
// As DeSerialize, but reads an old (float) format WeightMatrix for
// backward compatibility.
bool WeightMatrix::DeSerializeOld(bool training, bool swap, TFile* fp) {
GENERIC_2D_ARRAY<float> float_array;
if (int_mode_) {
if (!wi_.DeSerialize(swap, fp)) return false;
GenericVector<float> old_scales;
if (!old_scales.DeSerialize(swap, fp)) return false;
scales_.init_to_size(old_scales.size(), 0.0);
for (int i = 0; i < old_scales.size(); ++i) scales_[i] = old_scales[i];
} else {
if (!float_array.DeSerialize(swap, fp)) return false;
FloatToDouble(float_array, &wf_);
}
if (training) {
InitBackward();
if (!float_array.DeSerialize(swap, fp)) return false;
FloatToDouble(float_array, &updates_);
// Errs was only used in int training, which is now dead.
if (!float_array.DeSerialize(swap, fp)) return false;
}
return true;
}
// Computes matrix.vector v = Wu.
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
// Asserts that the call matches what we have.
void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
ASSERT_HOST(!int_mode_);
MatrixDotVectorInternal(wf_, true, false, u, v);
}
void WeightMatrix::MatrixDotVector(const inT8* u, double* v) const {
ASSERT_HOST(int_mode_);
int num_out = wi_.dim1();
int num_in = wi_.dim2() - 1;
for (int i = 0; i < num_out; ++i) {
const inT8* Wi = wi_[i];
int total = 0;
if (detector.IsSSEAvailable()) {
total = IntDotProductSSE(u, Wi, num_in);
} else {
for (int j = 0; j < num_in; ++j) total += Wi[j] * u[j];
}
// Add in the bias and correct for integer values.
v[i] = (static_cast<double>(total) / MAX_INT8 + Wi[num_in]) * scales_[i];
}
}
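// Scaling example (assuming, as in NetworkIO's int mode, that inputs are
// quantized so MAX_INT8 represents 1.0): total is in (int weight x int input)
// units, so dividing by MAX_INT8 converts back to input units; the int bias
// Wi[num_in] is added directly since its implicit input is 1; and scales_[i]
// restores this row's float weight magnitude.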
// MatrixDotVector for peep weights, MultiplyAccumulate adds the
// component-wise products of *this[0] and v to inout.
void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(wf_.dim1() == 1);
int n = wf_.dim2();
const double* u = wf_[0];
for (int i = 0; i < n; ++i) {
inout[i] += u[i] * v[i];
}
}
// Computes vector.matrix v = uW.
// u is of size W.dim1() and the output v is of size W.dim2() - 1.
// The last result is discarded, as v is assumed to have an imaginary
// last value of 1, as with MatrixDotVector.
void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
ASSERT_HOST(!int_mode_);
MatrixDotVectorInternal(wf_t_, false, true, u, v);
}
// Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements from
// u and v. In terms of the neural network, u is the gradients and v is the
// inputs.
// Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
// Runs parallel if requested. Note that u and v must be transposed.
void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
const TransposedArray& v,
bool in_parallel) {
ASSERT_HOST(!int_mode_);
int num_outputs = dw_.dim1();
ASSERT_HOST(u.dim1() == num_outputs);
ASSERT_HOST(u.dim2() == v.dim2());
int num_inputs = dw_.dim2() - 1;
int num_samples = u.dim2();
// v is missing the last element in dim1.
ASSERT_HOST(v.dim1() == num_inputs);
#ifdef _OPENMP
#pragma omp parallel for num_threads(4) if (in_parallel)
#endif
for (int i = 0; i < num_outputs; ++i) {
double* dwi = dw_[i];
const double* ui = u[i];
for (int j = 0; j < num_inputs; ++j) {
dwi[j] = DotProduct(ui, v[j], num_samples);
}
// The last element of v is missing, presumed 1.0f.
double total = 0.0;
for (int k = 0; k < num_samples; ++k) total += ui[k];
dwi[num_inputs] = total;
}
}
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void WeightMatrix::Update(double learning_rate, double momentum,
int num_samples) {
ASSERT_HOST(!int_mode_);
if (use_ada_grad_ && num_samples > 0) {
dw_sq_sum_.SumSquares(dw_);
dw_.AdaGradScaling(dw_sq_sum_, num_samples);
}
dw_ *= learning_rate;
updates_ += dw_;
if (momentum > 0.0) wf_ += updates_;
if (momentum >= 0.0) updates_ *= momentum;
wf_t_.Transpose(wf_);
}
// Adds the dw_ in other to the dw_ in *this.
void WeightMatrix::AddDeltas(const WeightMatrix& other) {
ASSERT_HOST(dw_.dim1() == other.dw_.dim1());
ASSERT_HOST(dw_.dim2() == other.dw_.dim2());
dw_ += other.dw_;
}
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void WeightMatrix::CountAlternators(const WeightMatrix& other, double* same,
double* changed) const {
int num_outputs = updates_.dim1();
int num_inputs = updates_.dim2();
ASSERT_HOST(num_outputs == other.updates_.dim1());
ASSERT_HOST(num_inputs == other.updates_.dim2());
for (int i = 0; i < num_outputs; ++i) {
const double* this_i = updates_[i];
const double* other_i = other.updates_[i];
for (int j = 0; j < num_inputs; ++j) {
double product = this_i[j] * other_i[j];
if (product < 0.0)
*changed -= product;
else
*same += product;
}
}
}
// Helper computes an integer histogram bucket for a weight and adds it
// to the histogram.
const int kHistogramBuckets = 16;
static void HistogramWeight(double weight, STATS* histogram) {
int bucket = kHistogramBuckets - 1;
if (weight != 0.0) {
double logval = -log2(fabs(weight));
bucket = ClipToRange(IntCastRounded(logval), 0, kHistogramBuckets - 1);
}
histogram->add(bucket, 1);
}
void WeightMatrix::Debug2D(const char* msg) {
STATS histogram(0, kHistogramBuckets);
if (int_mode_) {
for (int i = 0; i < wi_.dim1(); ++i) {
for (int j = 0; j < wi_.dim2(); ++j) {
HistogramWeight(wi_[i][j] * scales_[i], &histogram);
}
}
} else {
for (int i = 0; i < wf_.dim1(); ++i) {
for (int j = 0; j < wf_.dim2(); ++j) {
HistogramWeight(wf_[i][j], &histogram);
}
}
}
tprintf("%s\n", msg);
histogram.print();
}
// Computes and returns the dot product of the two n-vectors u and v.
/* static */
double WeightMatrix::DotProduct(const double* u, const double* v, int n) {
// Note: because the order of addition is different among the 3 DotProduct
// functions, the results can (and do) vary slightly (although they agree
// to within about 4e-15). This produces different results when running
// training, despite all random inputs being precisely equal.
// To get consistent results, use just one of these DotProduct functions.
// On a test multi-layer network, serial is 57% slower than sse, and avx
// is about 8% faster than sse. This suggests that the time is memory
// bandwidth constrained and could benefit from holding the reused vector
// in AVX registers.
if (detector.IsAVXAvailable()) return DotProductAVX(u, v, n);
if (detector.IsSSEAvailable()) return DotProductSSE(u, v, n);
double total = 0.0;
for (int k = 0; k < n; ++k) total += u[k] * v[k];
return total;
}
// Utility function converts an array of float to the corresponding array
// of double.
/* static */
void WeightMatrix::FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
GENERIC_2D_ARRAY<double>* wd) {
int dim1 = wf.dim1();
int dim2 = wf.dim2();
wd->ResizeNoInit(dim1, dim2);
for (int i = 0; i < dim1; ++i) {
const float* wfi = wf[i];
double* wdi = (*wd)[i];
for (int j = 0; j < dim2; ++j) wdi[j] = static_cast<double>(wfi[j]);
}
}
// Computes matrix.vector v = Wu.
// u is of size W.dim2() - add_bias_fwd and the output v is of size
// W.dim1() - skip_bias_back.
// If add_bias_fwd, u is imagined to have an extra element at the end with value
1, to implement the bias weight.
// If skip_bias_back, we are actually performing the backwards product on a
// transposed matrix, so we need to drop the v output corresponding to the last
// element in dim1.
void WeightMatrix::MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
bool add_bias_fwd,
bool skip_bias_back, const double* u,
double* v) {
int num_results = w.dim1() - skip_bias_back;
int extent = w.dim2() - add_bias_fwd;
for (int i = 0; i < num_results; ++i) {
const double* wi = w[i];
double total = DotProduct(wi, u, extent);
if (add_bias_fwd) total += wi[extent]; // The bias value.
v[i] = total;
}
}
} // namespace tesseract.
#undef NONX86_BUILD
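A minimal forward-pass sketch (not from the commit; sizes and values are assumptions, and ni is passed including the implicit bias column, as MatrixDotVector's contract of u being dim2() - 1 implies):
#include "helpers.h"       // For TRand.
#include "weightmatrix.h"
void WeightMatrixSketch() {
  tesseract::TRand randomizer;
  tesseract::WeightMatrix w;
  // 4 outputs, 3 real inputs + 1 bias column.
  w.InitWeightsFloat(4, 3 + 1, /*ada_grad=*/false, /*weight_range=*/0.1f,
                     &randomizer);
  double u[3] = {0.5, -1.0, 0.25};  // The trailing 1 for the bias is implicit.
  double v[4];
  w.MatrixDotVector(u, v);          // v = W . [u, 1]
}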

183
lstm/weightmatrix.h Normal file
View File

@ -0,0 +1,183 @@
///////////////////////////////////////////////////////////////////////
// File: weightmatrix.h
// Description: Hides distinction between float/int implementations.
// Author: Ray Smith
// Created: Tue Jun 17 09:05:39 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
#define TESSERACT_LSTM_WEIGHTMATRIX_H_
#include "genericvector.h"
#include "matrix.h"
#include "tprintf.h"
namespace tesseract {
// Convenience instantiation of GENERIC_2D_ARRAY<double> with additional
// operations to write a strided vector, so the transposed form of the input
// is memory-contiguous.
class TransposedArray : public GENERIC_2D_ARRAY<double> {
public:
// Copies the whole input transposed, converted to double, into *this.
void Transpose(const GENERIC_2D_ARRAY<double>& input);
// Writes a vector of data representing a timestep (gradients or sources).
// The data is assumed to be dim1() elements long (the strided dimension).
void WriteStrided(int t, const float* data) {
int size1 = dim1();
for (int i = 0; i < size1; ++i) put(i, t, data[i]);
}
void WriteStrided(int t, const double* data) {
int size1 = dim1();
for (int i = 0; i < size1; ++i) put(i, t, data[i]);
}
// Prints the first and last num elements of the un-transposed array.
void PrintUnTransposed(int num) {
int num_features = dim1();
int width = dim2();
for (int y = 0; y < num_features; ++y) {
for (int t = 0; t < width; ++t) {
if (num == 0 || t < num || t + num >= width) {
tprintf(" %g", (*this)(y, t));
}
}
tprintf("\n");
}
}
}; // class TransposedArray
// Generic weight matrix for network layers. Can store the matrix as either
// an array of floats or inT8. Provides functions to compute the forward and
// backward steps with the matrix and updates to the weights.
class WeightMatrix {
public:
WeightMatrix() : int_mode_(false), use_ada_grad_(false) {}
// Sets up the network for training. Initializes weights to random values of
// scale `weight_range`, picked using the random number generator `randomizer`.
// Note the order is outputs, inputs, as this is the order of indices to
// the matrix, so the adjacent elements are multiplied by the input during
// a forward operation.
int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range,
TRand* randomizer);
// Converts a float network to an int network. Each set of input weights that
// corresponds to a single output weight is converted independently:
// Compute the max absolute value of the weight set.
// Scale so the max absolute value becomes MAX_INT8.
// Round to integer.
// Store a multiplicative scale factor (as a double) that will reproduce
// the original value, subject to rounding errors.
void ConvertToInt();
// Accessors.
bool is_int_mode() const {
return int_mode_;
}
int NumOutputs() const { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
// Provides one set of weights. Only used by peep weight maxpool.
const double* GetWeights(int index) const { return wf_[index]; }
// Provides access to the deltas (dw_).
double GetDW(int i, int j) const { return dw_(i, j); }
// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void InitBackward();
// Writes to the given file. Returns false in case of error.
bool Serialize(bool training, TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool DeSerialize(bool training, bool swap, TFile* fp);
// As DeSerialize, but reads an old (float) format WeightMatrix for
// backward compatibility.
bool DeSerializeOld(bool training, bool swap, TFile* fp);
// Computes matrix.vector v = Wu.
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
// Asserts that the call matches what we have.
void MatrixDotVector(const double* u, double* v) const;
void MatrixDotVector(const inT8* u, double* v) const;
// MatrixDotVector for peep weights, MultiplyAccumulate adds the
// component-wise products of *this[0] and v to inout.
void MultiplyAccumulate(const double* v, double* inout);
// Computes vector.matrix v = uW.
// u is of size W.dim1() and the output v is of size W.dim2() - 1.
// The last result is discarded, as v is assumed to have an imaginary
// last value of 1, as with MatrixDotVector.
void VectorDotMatrix(const double* u, double* v) const;
// Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
// from u and v. In terms of the neural network, u is the gradients and v is
// the inputs.
// Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
// Runs parallel if requested. Note that inputs must be transposed.
void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
bool parallel);
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void Update(double learning_rate, double momentum, int num_samples);
// Adds the dw_ in other to the dw_ in *this.
void AddDeltas(const WeightMatrix& other);
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void CountAlternators(const WeightMatrix& other, double* same,
double* changed) const;
void Debug2D(const char* msg);
// Computes and returns the dot product of the two n-vectors u and v.
static double DotProduct(const double* u, const double* v, int n);
// Utility function converts an array of float to the corresponding array
// of double.
static void FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
GENERIC_2D_ARRAY<double>* wd);
private:
// Computes matrix.vector v = Wu.
// u is of size W.dim2() - add_bias_fwd and the output v is of size
// W.dim1() - skip_bias_back.
// If add_bias_fwd, an extra element at the end of w[i] is the bias weight
// and is added to v[i].
// If skip_bias_back, the backwards product is run on a transposed matrix, so
// the v output corresponding to the last element in dim1 is dropped.
static void MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
bool add_bias_fwd, bool skip_bias_back,
const double* u, double* v);
private:
// Choice between float and 8 bit int implementations.
GENERIC_2D_ARRAY<double> wf_;
GENERIC_2D_ARRAY<inT8> wi_;
// Transposed copy of wf_, used only for Backward, and set with each Update.
TransposedArray wf_t_;
// Which of wf_ and wi_ are we actually using.
bool int_mode_;
// True if we are running adagrad in this weight matrix.
bool use_ada_grad_;
// If we are using wi_, then scales_ is a factor to restore the row product
// with a vector to the correct range.
GenericVector<double> scales_;
// Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
// amount to be added to wf_/wi_.
GENERIC_2D_ARRAY<double> dw_;
GENERIC_2D_ARRAY<double> updates_;
// Iff use_ada_grad_, the sum of squares of dw_. The number of samples is
// given to Update(). Serialized iff use_ada_grad_.
GENERIC_2D_ARRAY<double> dw_sq_sum_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_WEIGHTMATRIX_H_

View File

@ -850,6 +850,7 @@ void BaselineDetect::ComputeBaselineSplinesAndXheights(const ICOORD& page_tr,
Pix* pix_spline = pix_debug_ ? pixConvertTo32(pix_debug_) : NULL;
for (int i = 0; i < blocks_.size(); ++i) {
BaselineBlock* bl_block = blocks_[i];
if (enable_splines)
bl_block->PrepareForSplineFitting(page_tr, remove_noise);
bl_block->FitBaselineSplines(enable_splines, show_final_rows, textord);
if (pix_spline) {

View File

@ -1632,6 +1632,10 @@ TO_BLOCK* ColPartition::MakeBlock(const ICOORD& bleft, const ICOORD& tright,
ColPartition_LIST* used_parts) {
if (block_parts->empty())
return NULL; // Nothing to do.
// If the block_parts are not in reading order, then it will make an invalid
// block polygon and bounding_box, so sort by bounding box now just to make
// sure.
block_parts->sort(&ColPartition::SortByBBox);
ColPartition_IT it(block_parts);
ColPartition* part = it.data();
PolyBlockType type = part->type();

View File

@ -704,6 +704,25 @@ class ColPartition : public ELIST2_LINK {
// doing a SideSearch when you want things in the same page column.
bool IsInSameColumnAs(const ColPartition& part) const;
// Sort function to sort by bounding box.
static int SortByBBox(const void* p1, const void* p2) {
const ColPartition* part1 =
*reinterpret_cast<const ColPartition* const*>(p1);
const ColPartition* part2 =
*reinterpret_cast<const ColPartition* const*>(p2);
int mid_y1 = part1->bounding_box_.y_middle();
int mid_y2 = part2->bounding_box_.y_middle();
if ((part2->bounding_box_.bottom() <= mid_y1 &&
mid_y1 <= part2->bounding_box_.top()) ||
(part1->bounding_box_.bottom() <= mid_y2 &&
mid_y2 <= part1->bounding_box_.top())) {
// Sort by increasing x.
return part1->bounding_box_.x_middle() - part2->bounding_box_.x_middle();
}
// Sort by decreasing y.
return mid_y2 - mid_y1;
}
// Sets the column bounds. Primarily used in testing.
void set_first_column(int column) {
first_column_ = column;

View File

@ -251,6 +251,7 @@ void Textord::filter_blobs(ICOORD page_tr, // top right
&block->noise_blobs,
&block->small_blobs,
&block->large_blobs);
if (block->line_size == 0) block->line_size = 1;
block->line_spacing = block->line_size *
(tesseract::CCStruct::kDescenderFraction +
tesseract::CCStruct::kXHeightFraction +
@ -769,6 +770,7 @@ void Textord::TransferDiacriticsToBlockGroups(BLOBNBOX_LIST* diacritic_blobs,
PointerVector<WordWithBox> word_ptrs;
for (int g = 0; g < groups.size(); ++g) {
const BlockGroup* group = groups[g];
if (group->bounding_box.null_box()) continue;
WordGrid word_grid(group->min_xheight, group->bounding_box.botleft(),
group->bounding_box.topright());
for (int b = 0; b < group->blocks.size(); ++b) {

View File

@ -1323,9 +1323,10 @@ BOOL8 Textord::make_a_word_break(
we may need to set PARTICULAR spaces to fuzzy or not. The values will ONLY
be used if the function returns TRUE - ie the word is to be broken.
*/
blanks = (uinT8) (current_gap / row->space_size);
if (blanks < 1)
blanks = 1;
int num_blanks = current_gap;
if (row->space_size > 1.0f)
num_blanks = IntCastRounded(current_gap / row->space_size);
blanks = static_cast<uinT8>(ClipToRange(num_blanks, 1, MAX_UINT8));
fuzzy_sp = FALSE;
fuzzy_non = FALSE;
/*

View File

@ -3,6 +3,7 @@ AM_CPPFLAGS += \
-DUSE_STD_NAMESPACE -DPANGO_ENABLE_ENGINE\
-I$(top_srcdir)/ccmain -I$(top_srcdir)/api \
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \
-I$(top_srcdir)/lstm -I$(top_srcdir)/arch \
-I$(top_srcdir)/viewer \
-I$(top_srcdir)/textord -I$(top_srcdir)/dict \
-I$(top_srcdir)/classify -I$(top_srcdir)/display \
@ -45,7 +46,7 @@ libtesseract_tessopt_la_SOURCES = \
tessopt.cpp
bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \
dawg2wordlist mftraining set_unicharset_properties shapeclustering \
dawg2wordlist lstmtraining mftraining set_unicharset_properties shapeclustering \
text2image unicharset_extractor wordlist2dawg
ambiguous_words_SOURCES = ambiguous_words.cpp
@ -58,6 +59,9 @@ ambiguous_words_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -82,6 +86,9 @@ classifier_tester_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -115,6 +122,9 @@ cntraining_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -136,6 +146,9 @@ if USING_MULTIPLELIBS
dawg2wordlist_LDADD += \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -150,6 +163,33 @@ dawg2wordlist_LDADD += \
../api/libtesseract.la
endif
lstmtraining_SOURCES = lstmtraining.cpp
#lstmtraining_LDFLAGS = -static
lstmtraining_LDADD = \
libtesseract_training.la \
libtesseract_tessopt.la \
$(libicu)
if USING_MULTIPLELIBS
lstmtraining_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
../ccmain/libtesseract_main.la \
../cube/libtesseract_cube.la \
../neural_networks/runtime/libtesseract_neural.la \
../wordrec/libtesseract_wordrec.la \
../ccutil/libtesseract_ccutil.la
else
lstmtraining_LDADD += \
../api/libtesseract.la
endif
mftraining_SOURCES = mftraining.cpp mergenf.cpp
#mftraining_LDFLAGS = -static
mftraining_LDADD = \
@ -160,6 +200,9 @@ mftraining_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -185,6 +228,9 @@ set_unicharset_properties_LDADD += \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../ccstruct/libtesseract_ccstruct.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
../ccmain/libtesseract_main.la \
@ -207,6 +253,9 @@ shapeclustering_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -230,6 +279,9 @@ text2image_LDADD += \
../textord/libtesseract_textord.la \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \
@ -266,6 +318,9 @@ if USING_MULTIPLELIBS
wordlist2dawg_LDADD += \
../classify/libtesseract_classify.la \
../dict/libtesseract_dict.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \
../viewer/libtesseract_viewer.la \

View File

@ -22,10 +22,36 @@
#include <stdlib.h>
#include "allheaders.h" // from leptonica
#include "genericvector.h"
#include "helpers.h" // For TRand.
#include "rect.h"
namespace tesseract {
// A randomized perspective distortion can be applied to synthetic input.
// The perspective distortion comes from leptonica, which uses 2 sets of 4
// corners to determine the distortion. There are random values for each of
// the x numbers x0..x3 and y0..y3, except for x2 and x3 which are instead
// defined in terms of a single shear value. This reduces the degrees of
// freedom enough to make the distortion more realistic than it would otherwise
// be if all 8 coordinates could move independently.
// One additional factor is used for the color of the pixels that don't exist
// in the source image.
// Name for each of the randomizing factors.
enum FactorNames {
FN_INCOLOR,
FN_Y0,
FN_Y1,
FN_Y2,
FN_Y3,
FN_X0,
FN_X1,
FN_SHEAR,
// x2 = x1 - shear
// x3 = x0 + shear
FN_NUM_FACTORS
};
// Rotation is +/- kRotationRange radians.
const float kRotationRange = 0.02f;
// Number of grey levels to shift by for each exposure step.
@ -144,4 +170,141 @@ Pix* DegradeImage(Pix* input, int exposure, TRand* randomizer,
return input;
}
// Creates and returns a Pix distorted by various means according to the bool
// flags. If boxes is not NULL, the boxes are resized/positioned according to
// any spatial distortion and also by the integer reduction factor box_scale
// so they will match what the network will output.
// Returns NULL on error. The returned Pix must be pixDestroyed.
Pix* PrepareDistortedPix(const Pix* pix, bool perspective, bool invert,
bool white_noise, bool smooth_noise, bool blur,
int box_reduction, TRand* randomizer,
GenericVector<TBOX>* boxes) {
Pix* distorted = pixCopy(NULL, const_cast<Pix*>(pix));
// Things to do to synthetic training data.
if (invert && randomizer->SignedRand(1.0) < 0)
pixInvert(distorted, distorted);
if ((white_noise || smooth_noise) && randomizer->SignedRand(1.0) > 0.0) {
// TODO(rays) Cook noise in a more thread-safe manner than rand().
// Attempt to make the sequences reproducible.
srand(randomizer->IntRand());
Pix* pixn = pixAddGaussianNoise(distorted, 8.0);
pixDestroy(&distorted);
if (smooth_noise) {
distorted = pixBlockconv(pixn, 1, 1);
pixDestroy(&pixn);
} else {
distorted = pixn;
}
}
if (blur && randomizer->SignedRand(1.0) > 0.0) {
Pix* blurred = pixBlockconv(distorted, 1, 1);
pixDestroy(&distorted);
distorted = blurred;
}
if (perspective)
GeneratePerspectiveDistortion(0, 0, randomizer, &distorted, boxes);
if (boxes != NULL) {
for (int b = 0; b < boxes->size(); ++b) {
(*boxes)[b].scale(1.0f / box_reduction);
if ((*boxes)[b].width() <= 0)
(*boxes)[b].set_right((*boxes)[b].left() + 1);
}
}
return distorted;
}
// Distorts anything that has a non-null pointer with the same pseudo-random
// perspective distortion. Width and height only need to be set if there
// is no pix. If there is a pix, then they will be taken from there.
void GeneratePerspectiveDistortion(int width, int height, TRand* randomizer,
Pix** pix, GenericVector<TBOX>* boxes) {
if (pix != NULL && *pix != NULL) {
width = pixGetWidth(*pix);
height = pixGetHeight(*pix);
}
float* im_coeffs = NULL;
float* box_coeffs = NULL;
l_int32 incolor =
ProjectiveCoeffs(width, height, randomizer, &im_coeffs, &box_coeffs);
if (pix != NULL && *pix != NULL) {
// Transform the image.
Pix* transformed = pixProjective(*pix, im_coeffs, incolor);
if (transformed == NULL) {
tprintf("Projective transformation failed!!\n");
return;
}
pixDestroy(pix);
*pix = transformed;
}
if (boxes != NULL) {
// Transform the boxes.
for (int b = 0; b < boxes->size(); ++b) {
int x1, y1, x2, y2;
const TBOX& box = (*boxes)[b];
projectiveXformSampledPt(box_coeffs, box.left(), height - box.top(), &x1,
&y1);
projectiveXformSampledPt(box_coeffs, box.right(), height - box.bottom(),
&x2, &y2);
TBOX new_box1(x1, height - y2, x2, height - y1);
projectiveXformSampledPt(box_coeffs, box.left(), height - box.bottom(),
&x1, &y1);
projectiveXformSampledPt(box_coeffs, box.right(), height - box.top(), &x2,
&y2);
TBOX new_box2(x1, height - y1, x2, height - y2);
(*boxes)[b] = new_box1.bounding_union(new_box2);
}
}
free(im_coeffs);
free(box_coeffs);
}
// Computes the coefficients of a randomized projective transformation.
// The image transform requires backward transformation coefficient, and the
// box transform the forward coefficients.
// Returns the incolor arg to pixProjective.
int ProjectiveCoeffs(int width, int height, TRand* randomizer,
float** im_coeffs, float** box_coeffs) {
// Setup "from" points.
Pta* src_pts = ptaCreate(4);
ptaAddPt(src_pts, 0.0f, 0.0f);
ptaAddPt(src_pts, width, 0.0f);
ptaAddPt(src_pts, width, height);
ptaAddPt(src_pts, 0.0f, height);
// Extract factors from pseudo-random sequence.
float factors[FN_NUM_FACTORS];
float shear = 0.0f; // Shear is signed.
for (int i = 0; i < FN_NUM_FACTORS; ++i) {
// Everything is squared to make wild values rarer.
if (i == FN_SHEAR) {
// Shear is signed.
shear = randomizer->SignedRand(0.5 / 3.0);
shear = shear >= 0.0 ? shear * shear : -shear * shear;
// Keep the sheared points within the original rectangle.
if (shear < -factors[FN_X0]) shear = -factors[FN_X0];
if (shear > factors[FN_X1]) shear = factors[FN_X1];
factors[i] = shear;
} else if (i != FN_INCOLOR) {
factors[i] = fabs(randomizer->SignedRand(1.0));
if (i <= FN_Y3)
factors[i] *= 5.0 / 8.0;
else
factors[i] *= 0.5;
factors[i] *= factors[i];
}
}
// Setup "to" points.
Pta* dest_pts = ptaCreate(4);
ptaAddPt(dest_pts, factors[FN_X0] * width, factors[FN_Y0] * height);
ptaAddPt(dest_pts, (1.0f - factors[FN_X1]) * width, factors[FN_Y1] * height);
ptaAddPt(dest_pts, (1.0f - factors[FN_X1] + shear) * width,
(1 - factors[FN_Y2]) * height);
ptaAddPt(dest_pts, (factors[FN_X0] + shear) * width,
(1 - factors[FN_Y3]) * height);
getProjectiveXformCoeffs(dest_pts, src_pts, im_coeffs);
getProjectiveXformCoeffs(src_pts, dest_pts, box_coeffs);
ptaDestroy(&src_pts);
ptaDestroy(&dest_pts);
return factors[FN_INCOLOR] > 0.5f ? L_BRING_IN_WHITE : L_BRING_IN_BLACK;
}
} // namespace tesseract
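A usage sketch (assumed inputs; pix and its boxes are hypothetical outputs of the synthetic renderer) tying the pieces together: distort a rendered line image and keep its boxes in step:
// pix: a rendered line image from the synthetic generator (hypothetical).
tesseract::TRand randomizer;
GenericVector<TBOX> boxes;  // Character boxes matching pix.
Pix* distorted = tesseract::PrepareDistortedPix(
    pix, /*perspective=*/true, /*invert=*/false, /*white_noise=*/true,
    /*smooth_noise=*/false, /*blur=*/true, /*box_reduction=*/1, &randomizer,
    &boxes);
pixDestroy(&distorted);  // The returned Pix must be pixDestroyed.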

View File

@ -20,12 +20,13 @@
#ifndef TESSERACT_TRAINING_DEGRADEIMAGE_H_
#define TESSERACT_TRAINING_DEGRADEIMAGE_H_
struct Pix;
#include "allheaders.h"
#include "genericvector.h"
#include "helpers.h" // For TRand.
#include "rect.h"
namespace tesseract {
class TRand;
// Degrade the pix as if by a print/copy/scan cycle with exposure > 0
// corresponding to darkening on the copier and <0 lighter and 0 not copied.
// If rotation is not NULL, the clockwise rotation in radians is saved there.
@ -34,6 +35,27 @@ class TRand;
struct Pix* DegradeImage(struct Pix* input, int exposure, TRand* randomizer,
float* rotation);
// Creates and returns a Pix distorted by various means according to the bool
// flags. If boxes is not NULL, the boxes are resized/positioned according to
// any spatial distortion and also by the integer reduction factor box_scale
// so they will match what the network will output.
// Returns NULL on error. The returned Pix must be pixDestroyed.
Pix* PrepareDistortedPix(const Pix* pix, bool perspective, bool invert,
bool white_noise, bool smooth_noise, bool blur,
int box_reduction, TRand* randomizer,
GenericVector<TBOX>* boxes);
// Distorts anything that has a non-null pointer with the same pseudo-random
// perspective distortion. Width and height only need to be set if there
// is no pix. If there is a pix, then they will be taken from there.
void GeneratePerspectiveDistortion(int width, int height, TRand* randomizer,
Pix** pix, GenericVector<TBOX>* boxes);
// Computes the coefficients of a randomized projective transformation.
// The image transform requires backward transformation coefficient, and the
// box transform the forward coefficients.
// Returns the incolor arg to pixProjective.
int ProjectiveCoeffs(int width, int height, TRand* randomizer,
float** im_coeffs, float** box_coeffs);
} // namespace tesseract
#endif // TESSERACT_TRAINING_DEGRADEIMAGE_H_

View File

@ -868,6 +868,9 @@ set_lang_specific_parameters() {
AMBIGS_FILTER_DENOMINATOR="100000"
LEADING="32"
MEAN_COUNT="40" # Default for latin script.
# Language to mix with the target language for maximum accuracy. Defaults to
# eng. If no mix is beneficial, set to the target language itself.
MIX_LANG="eng"
case ${lang} in
# Latin languages.
@ -959,11 +962,13 @@ set_lang_specific_parameters() {
WORD_DAWG_SIZE=1000000
test -z "$FONTS" && FONTS=( "${EARLY_LATIN_FONTS[@]}" );;
# Cyrillic script-based languages.
# Cyrillic script-based languages. It is bad to mix Latin with Cyrillic.
rus ) test -z "$FONTS" && FONTS=( "${RUSSIAN_FONTS[@]}" )
MIX_LANG="rus"
NUMBER_DAWG_FACTOR=0.05
WORD_DAWG_SIZE=1000000 ;;
aze_cyrl | bel | bul | kaz | mkd | srp | tgk | ukr | uzb_cyrl )
MIX_LANG="${lang}"
test -z "$FONTS" && FONTS=( "${RUSSIAN_FONTS[@]}" ) ;;
# Special code for performing Cyrillic language-id that is trained on

185
training/lstmtraining.cpp Normal file
View File

@ -0,0 +1,185 @@
///////////////////////////////////////////////////////////////////////
// File: lstmtraining.cpp
// Description: Training program for LSTM-based networks.
// Author: Ray Smith
// Created: Fri May 03 11:05:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef USE_STD_NAMESPACE
#include "base/commandlineflags.h"
#endif
#include "commontraining.h"
#include "lstmtrainer.h"
#include "params.h"
#include "strngs.h"
#include "tprintf.h"
#include "unicharset_training_utils.h"
INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
STRING_PARAM_FLAG(net_spec, "[I1,48Lt1,100O]", "Network specification");
INT_PARAM_FLAG(train_mode, 64, "Controls gross training behavior.");
INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
INT_PARAM_FLAG(perfect_sample_delay, 4,
"How many imperfect samples between perfect ones.");
DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
DOUBLE_PARAM_FLAG(learning_rate, 1.0e-4, "Weight factor for new deltas.");
DOUBLE_PARAM_FLAG(momentum, 0.9, "Decay factor for repeating deltas.");
INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
STRING_PARAM_FLAG(script_dir, "",
"Required to set unicharset properties or"
" use unicharset compression.");
BOOL_PARAM_FLAG(stop_training, false,
"Just convert the training model to a runtime model.");
INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
" attach the new network defined by net_spec");
BOOL_PARAM_FLAG(debug_network, false,
"Get info on distribution of weight values");
INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
DECLARE_STRING_PARAM_FLAG(U);
// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;
// Apart from command-line flags, input is a collection of lstmf files that
// were previously created using tesseract with the lstm.train config file.
// The program iterates over the inputs, feeding the data to the network,
// until the error rate reaches a specified target or max_iterations is reached.
int main(int argc, char **argv) {
ParseArguments(&argc, &argv);
// Purify the model name in case it is based on the network string.
if (FLAGS_model_output.empty()) {
tprintf("Must provide a --model_output!\n");
return 1;
}
STRING model_output = FLAGS_model_output.c_str();
for (int i = 0; i < model_output.length(); ++i) {
if (model_output[i] == '[' || model_output[i] == ']')
model_output[i] = '-';
if (model_output[i] == '(' || model_output[i] == ')')
model_output[i] = '_';
}
// Setup the trainer.
STRING checkpoint_file = FLAGS_model_output.c_str();
checkpoint_file += "_checkpoint";
STRING checkpoint_bak = checkpoint_file + ".bak";
tesseract::LSTMTrainer trainer(
NULL, NULL, NULL, NULL, FLAGS_model_output.c_str(),
checkpoint_file.c_str(), FLAGS_debug_interval,
static_cast<inT64>(FLAGS_max_image_MB) * 1048576);
// Reading something from an existing model doesn't require many flags,
// so do it now and exit.
if (FLAGS_stop_training || FLAGS_debug_network) {
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) {
tprintf("Failed to read continue from: %s\n",
FLAGS_continue_from.c_str());
return 1;
}
if (FLAGS_debug_network) {
trainer.DebugNetwork();
} else {
if (FLAGS_train_mode & tesseract::TF_INT_MODE)
trainer.ConvertToInt();
GenericVector<char> recognizer_data;
trainer.SaveRecognitionDump(&recognizer_data);
if (!tesseract::SaveDataToFile(recognizer_data,
FLAGS_model_output.c_str())) {
tprintf("Failed to write recognition model : %s\n",
FLAGS_model_output.c_str());
}
}
return 0;
}
// Get the list of files to process.
GenericVector<STRING> filenames;
for (int arg = 1; arg < argc; ++arg) {
filenames.push_back(STRING(argv[arg]));
}
UNICHARSET unicharset;
// Checkpoints always take priority if they are available.
if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) ||
trainer.TryLoadingCheckpoint(checkpoint_bak.string())) {
tprintf("Successfully restored trainer from %s\n",
checkpoint_file.string());
} else {
if (!FLAGS_continue_from.empty()) {
// Load a past model file to improve upon.
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) {
tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
return 1;
}
tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
}
if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
// We need a unicharset to start from scratch or append.
string unicharset_str;
// Character coding to be used by the classifier.
if (!unicharset.load_from_file(FLAGS_U.c_str())) {
tprintf("Error: must provide a -U unicharset!\n");
return 1;
}
tesseract::SetupBasicProperties(true, &unicharset);
if (FLAGS_append_index >= 0) {
tprintf("Appending a new network to an old one!!");
if (FLAGS_continue_from.empty()) {
tprintf("Must set --continue_from for appending!\n");
return 1;
}
}
// We are initializing from scratch.
trainer.InitCharSet(unicharset, FLAGS_script_dir.c_str(),
FLAGS_train_mode);
if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
FLAGS_net_mode, FLAGS_weight_range,
FLAGS_learning_rate, FLAGS_momentum)) {
tprintf("Failed to create network from spec: %s\n",
FLAGS_net_spec.c_str());
return 1;
}
trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
}
}
if (!trainer.LoadAllTrainingData(filenames)) {
tprintf("Load of images failed!!\n");
return 1;
}
bool best_dumped = true;
char* best_model_dump = NULL;
size_t best_model_size = 0;
STRING best_model_name;
do {
// Train a few.
int iteration = trainer.training_iteration();
for (int target_iteration = iteration + kNumPagesPerBatch;
iteration < target_iteration;
iteration = trainer.training_iteration()) {
trainer.TrainOnLine(&trainer, false);
}
STRING log_str;
trainer.MaintainCheckpoints(NULL, &log_str);
tprintf("%s\n", log_str.string());
} while (trainer.best_error_rate() > FLAGS_target_error_rate &&
(trainer.training_iteration() < FLAGS_max_iterations ||
FLAGS_max_iterations == 0));
tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
return 0;
} /* main */
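A typical invocation sketch (file names hypothetical; flags and the default net spec as defined above) passes the unicharset, the script directory, and a list of lstmf files:
lstmtraining --U lang.unicharset --script_dir langdata \
  --net_spec "[I1,48Lt1,100O]" --model_output lstmtrain \
  --max_iterations 10000 train1.lstmf train2.lstmf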

View File

@ -0,0 +1,52 @@
///////////////////////////////////////////////////////////////////////
// File: merge_unicharsets.cpp
// Description: Simple tool to merge two or more unicharsets.
// Author: Ray Smith
// Created: Wed Sep 30 16:09:01 PDT 2015
//
// (C) Copyright 2015, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////
#include <stdio.h>
#include "unicharset.h"
int main(int argc, char** argv) {
// Print usage
if (argc < 4) {
printf("Usage: %s unicharset-in-1 ... unicharset-in-n unicharset-out\n",
argv[0]);
exit(1);
}
UNICHARSET input_unicharset, result_unicharset;
for (int arg = 1; arg < argc - 1; ++arg) {
// Load the input unicharset
if (input_unicharset.load_from_file(argv[arg])) {
printf("Loaded unicharset of size %d from file %s\n",
input_unicharset.size(), argv[arg]);
result_unicharset.AppendOtherUnicharset(input_unicharset);
} else {
printf("Failed to load unicharset from file %s!!\n", argv[arg]);
exit(1);
}
}
// Save the combined unicharset.
if (result_unicharset.save_to_file(argv[argc - 1])) {
printf("Wrote unicharset file %s.\n", argv[argc - 1]);
} else {
printf("Cannot save unicharset file %s.\n", argv[argc - 1]);
exit(1);
}
return 0;
}

View File

@ -302,6 +302,9 @@ int main (int argc, char **argv) {
*shape_table, float_classes,
inttemp_file.string(),
pffmtable_file.string());
for (int c = 0; c < unicharset->size(); ++c) {
FreeClassFields(&float_classes[c]);
}
delete [] float_classes;
FreeLabeledClassList(mf_classes);
delete trainer;

Some files were not shown because too many files have changed in this diff.