mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2024-11-27 12:49:35 +08:00
Added new LSTM-based neural network line recognizer
This commit is contained in:
parent
5d21ecfad3
commit
c1c1e426b3
@ -16,7 +16,7 @@ endif
|
||||
|
||||
.PHONY: install-langs ScrollView.jar install-jars training
|
||||
|
||||
SUBDIRS = ccutil viewer cutil opencl ccstruct dict classify wordrec textord
|
||||
SUBDIRS = arch ccutil viewer cutil opencl ccstruct dict classify wordrec textord lstm
|
||||
if !NO_CUBE_BUILD
|
||||
SUBDIRS += neural_networks/runtime cube
|
||||
endif
|
||||
|
@ -1,5 +1,6 @@
|
||||
AM_CPPFLAGS += -DLOCALEDIR=\"$(localedir)\"\
|
||||
-DUSE_STD_NAMESPACE \
|
||||
-I$(top_srcdir)/arch -I$(top_srcdir)/lstm \
|
||||
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct -I$(top_srcdir)/cube \
|
||||
-I$(top_srcdir)/viewer \
|
||||
-I$(top_srcdir)/textord -I$(top_srcdir)/dict \
|
||||
@ -27,6 +28,9 @@ libtesseract_api_la_LIBADD = \
|
||||
../wordrec/libtesseract_wordrec.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -57,6 +61,9 @@ libtesseract_la_LIBADD = \
|
||||
../wordrec/libtesseract_wordrec.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
|
@ -121,7 +121,6 @@ TessBaseAPI::TessBaseAPI()
|
||||
block_list_(NULL),
|
||||
page_res_(NULL),
|
||||
input_file_(NULL),
|
||||
input_image_(NULL),
|
||||
output_file_(NULL),
|
||||
datapath_(NULL),
|
||||
language_(NULL),
|
||||
@ -515,9 +514,7 @@ void TessBaseAPI::ClearAdaptiveClassifier() {
|
||||
|
||||
/**
|
||||
* Provide an image for Tesseract to recognize. Format is as
|
||||
* TesseractRect above. Does not copy the image buffer, or take
|
||||
* ownership. The source image may be destroyed after Recognize is called,
|
||||
* either explicitly or implicitly via one of the Get*Text functions.
|
||||
* TesseractRect above. Copies the image buffer and converts to Pix.
|
||||
* SetImage clears all recognition results, and sets the rectangle to the
|
||||
* full image, so it may be followed immediately by a GetUTF8Text, and it
|
||||
* will automatically perform recognition.
|
||||
@ -525,9 +522,11 @@ void TessBaseAPI::ClearAdaptiveClassifier() {
|
||||
void TessBaseAPI::SetImage(const unsigned char* imagedata,
|
||||
int width, int height,
|
||||
int bytes_per_pixel, int bytes_per_line) {
|
||||
if (InternalSetImage())
|
||||
if (InternalSetImage()) {
|
||||
thresholder_->SetImage(imagedata, width, height,
|
||||
bytes_per_pixel, bytes_per_line);
|
||||
SetInputImage(thresholder_->GetPixRect());
|
||||
}
|
||||
}
|
||||
|
||||
void TessBaseAPI::SetSourceResolution(int ppi) {
|
||||
@ -539,18 +538,17 @@ void TessBaseAPI::SetSourceResolution(int ppi) {
|
||||
|
||||
/**
|
||||
* Provide an image for Tesseract to recognize. As with SetImage above,
|
||||
* Tesseract doesn't take a copy or ownership or pixDestroy the image, so
|
||||
* it must persist until after Recognize.
|
||||
* Tesseract takes its own copy of the image, so it need not persist until
|
||||
* after Recognize.
|
||||
* Pix vs raw, which to use?
|
||||
* Use Pix where possible. A future version of Tesseract may choose to use Pix
|
||||
* as its internal representation and discard IMAGE altogether.
|
||||
* Because of that, an implementation that sources and targets Pix may end up
|
||||
* with less copies than an implementation that does not.
|
||||
* Use Pix where possible. Tesseract uses Pix as its internal representation
|
||||
* and it is therefore more efficient to provide a Pix directly.
|
||||
*/
|
||||
void TessBaseAPI::SetImage(Pix* pix) {
|
||||
if (InternalSetImage())
|
||||
if (InternalSetImage()) {
|
||||
thresholder_->SetImage(pix);
|
||||
SetInputImage(pix);
|
||||
SetInputImage(thresholder_->GetPixRect());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -693,8 +691,8 @@ Boxa* TessBaseAPI::GetComponentImages(PageIteratorLevel level,
|
||||
if (pixa != NULL) {
|
||||
Pix* pix = NULL;
|
||||
if (raw_image) {
|
||||
pix = page_it->GetImage(level, raw_padding, input_image_,
|
||||
&left, &top);
|
||||
pix = page_it->GetImage(level, raw_padding, GetInputImage(), &left,
|
||||
&top);
|
||||
} else {
|
||||
pix = page_it->GetBinaryImage(level);
|
||||
}
|
||||
@ -849,13 +847,17 @@ int TessBaseAPI::Recognize(ETEXT_DESC* monitor) {
|
||||
} else if (tesseract_->tessedit_resegment_from_boxes) {
|
||||
page_res_ = tesseract_->ApplyBoxes(*input_file_, false, block_list_);
|
||||
} else {
|
||||
// TODO(rays) LSTM here.
|
||||
page_res_ = new PAGE_RES(false,
|
||||
page_res_ = new PAGE_RES(tesseract_->AnyLSTMLang(),
|
||||
block_list_, &tesseract_->prev_word_best_choice_);
|
||||
}
|
||||
if (page_res_ == NULL) {
|
||||
return -1;
|
||||
}
|
||||
if (tesseract_->tessedit_train_line_recognizer) {
|
||||
tesseract_->TrainLineRecognizer(*input_file_, *output_file_, block_list_);
|
||||
tesseract_->CorrectClassifyWords(page_res_);
|
||||
return 0;
|
||||
}
|
||||
if (tesseract_->tessedit_make_boxes_from_boxes) {
|
||||
tesseract_->CorrectClassifyWords(page_res_);
|
||||
return 0;
|
||||
@ -938,17 +940,10 @@ int TessBaseAPI::RecognizeForChopTest(ETEXT_DESC* monitor) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void TessBaseAPI::SetInputImage(Pix *pix) {
|
||||
if (input_image_)
|
||||
pixDestroy(&input_image_);
|
||||
input_image_ = NULL;
|
||||
if (pix)
|
||||
input_image_ = pixCopy(NULL, pix);
|
||||
}
|
||||
// Takes ownership of the input pix.
|
||||
void TessBaseAPI::SetInputImage(Pix* pix) { tesseract_->set_pix_original(pix); }
|
||||
|
||||
Pix* TessBaseAPI::GetInputImage() {
|
||||
return input_image_;
|
||||
}
|
||||
Pix* TessBaseAPI::GetInputImage() { return tesseract_->pix_original(); }
|
||||
|
||||
const char * TessBaseAPI::GetInputName() {
|
||||
if (input_file_)
|
||||
@ -992,8 +987,7 @@ bool TessBaseAPI::ProcessPagesFileList(FILE *flist,
|
||||
}
|
||||
|
||||
// Begin producing output
|
||||
const char* kUnknownTitle = "";
|
||||
if (renderer && !renderer->BeginDocument(kUnknownTitle)) {
|
||||
if (renderer && !renderer->BeginDocument(unknown_title_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1105,7 +1099,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
|
||||
const char* retry_config,
|
||||
int timeout_millisec,
|
||||
TessResultRenderer* renderer) {
|
||||
#ifndef ANDROID_BUILD
|
||||
PERF_COUNT_START("ProcessPages")
|
||||
bool stdInput = !strcmp(filename, "stdin") || !strcmp(filename, "-");
|
||||
if (stdInput) {
|
||||
@ -1162,8 +1155,7 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
|
||||
}
|
||||
|
||||
// Begin the output
|
||||
const char* kUnknownTitle = "";
|
||||
if (renderer && !renderer->BeginDocument(kUnknownTitle)) {
|
||||
if (renderer && !renderer->BeginDocument(unknown_title_)) {
|
||||
pixDestroy(&pix);
|
||||
return false;
|
||||
}
|
||||
@ -1185,9 +1177,6 @@ bool TessBaseAPI::ProcessPagesInternal(const char* filename,
|
||||
}
|
||||
PERF_COUNT_END
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool TessBaseAPI::ProcessPage(Pix* pix, int page_index, const char* filename,
|
||||
@ -2107,10 +2096,6 @@ void TessBaseAPI::End() {
|
||||
delete input_file_;
|
||||
input_file_ = NULL;
|
||||
}
|
||||
if (input_image_ != NULL) {
|
||||
pixDestroy(&input_image_);
|
||||
input_image_ = NULL;
|
||||
}
|
||||
if (output_file_ != NULL) {
|
||||
delete output_file_;
|
||||
output_file_ = NULL;
|
||||
|
@ -20,8 +20,8 @@
|
||||
#ifndef TESSERACT_API_BASEAPI_H__
|
||||
#define TESSERACT_API_BASEAPI_H__
|
||||
|
||||
#define TESSERACT_VERSION_STR "3.05.00dev"
|
||||
#define TESSERACT_VERSION 0x030500
|
||||
#define TESSERACT_VERSION_STR "4.00.00alpha"
|
||||
#define TESSERACT_VERSION 0x040000
|
||||
#define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \
|
||||
(patch))
|
||||
|
||||
@ -142,6 +142,7 @@ class TESS_API TessBaseAPI {
|
||||
* is stored in the PDF so we need that as well.
|
||||
*/
|
||||
const char* GetInputName();
|
||||
// Takes ownership of the input pix.
|
||||
void SetInputImage(Pix *pix);
|
||||
Pix* GetInputImage();
|
||||
int GetSourceYResolution();
|
||||
@ -333,9 +334,7 @@ class TESS_API TessBaseAPI {
|
||||
|
||||
/**
|
||||
* Provide an image for Tesseract to recognize. Format is as
|
||||
* TesseractRect above. Does not copy the image buffer, or take
|
||||
* ownership. The source image may be destroyed after Recognize is called,
|
||||
* either explicitly or implicitly via one of the Get*Text functions.
|
||||
* TesseractRect above. Copies the image buffer and converts to Pix.
|
||||
* SetImage clears all recognition results, and sets the rectangle to the
|
||||
* full image, so it may be followed immediately by a GetUTF8Text, and it
|
||||
* will automatically perform recognition.
|
||||
@ -345,13 +344,11 @@ class TESS_API TessBaseAPI {
|
||||
|
||||
/**
|
||||
* Provide an image for Tesseract to recognize. As with SetImage above,
|
||||
* Tesseract doesn't take a copy or ownership or pixDestroy the image, so
|
||||
* it must persist until after Recognize.
|
||||
* Tesseract takes its own copy of the image, so it need not persist until
|
||||
* after Recognize.
|
||||
* Pix vs raw, which to use?
|
||||
* Use Pix where possible. A future version of Tesseract may choose to use Pix
|
||||
* as its internal representation and discard IMAGE altogether.
|
||||
* Because of that, an implementation that sources and targets Pix may end up
|
||||
* with less copies than an implementation that does not.
|
||||
* Use Pix where possible. Tesseract uses Pix as its internal representation
|
||||
* and it is therefore more efficient to provide a Pix directly.
|
||||
*/
|
||||
void SetImage(Pix* pix);
|
||||
|
||||
@ -866,7 +863,6 @@ class TESS_API TessBaseAPI {
|
||||
BLOCK_LIST* block_list_; ///< The page layout.
|
||||
PAGE_RES* page_res_; ///< The page-level data.
|
||||
STRING* input_file_; ///< Name used by training code.
|
||||
Pix* input_image_; ///< Image used for searchable PDF
|
||||
STRING* output_file_; ///< Name used by debug code.
|
||||
STRING* datapath_; ///< Current location of tessdata.
|
||||
STRING* language_; ///< Last initialized language.
|
||||
@ -902,6 +898,12 @@ class TESS_API TessBaseAPI {
|
||||
int timeout_millisec,
|
||||
TessResultRenderer* renderer,
|
||||
int tessedit_page_number);
|
||||
// There's currently no way to pass a document title from the
|
||||
// Tesseract command line, and we have multiple places that choose
|
||||
// to set the title to an empty string. Using a single named
|
||||
// variable will hopefully reduce confusion if the situation changes
|
||||
// in the future.
|
||||
const char *unknown_title_ = "";
|
||||
}; // class TessBaseAPI.
|
||||
|
||||
/** Escape a char string - remove &<>"' with HTML codes. */
|
||||
|
@ -620,7 +620,6 @@ bool TessPDFRenderer::BeginDocumentHandler() {
|
||||
AppendPDFObject(buf);
|
||||
|
||||
// FONT DESCRIPTOR
|
||||
const int kCharHeight = 2; // Effect: highlights are half height
|
||||
n = snprintf(buf, sizeof(buf),
|
||||
"7 0 obj\n"
|
||||
"<<\n"
|
||||
@ -636,10 +635,10 @@ bool TessPDFRenderer::BeginDocumentHandler() {
|
||||
" /Type /FontDescriptor\n"
|
||||
">>\n"
|
||||
"endobj\n",
|
||||
1000 / kCharHeight,
|
||||
1000 / kCharHeight,
|
||||
1000,
|
||||
1000,
|
||||
1000 / kCharWidth,
|
||||
1000 / kCharHeight,
|
||||
1000,
|
||||
8L // Font data
|
||||
);
|
||||
if (n >= sizeof(buf)) return false;
|
||||
|
@ -77,7 +77,7 @@ class TESS_API TessResultRenderer {
|
||||
bool EndDocument();
|
||||
|
||||
const char* file_extension() const { return file_extension_; }
|
||||
const char* title() const { return title_; }
|
||||
const char* title() const { return title_.c_str(); }
|
||||
|
||||
/**
|
||||
* Returns the index of the last image given to AddImage
|
||||
@ -126,7 +126,7 @@ class TESS_API TessResultRenderer {
|
||||
|
||||
private:
|
||||
const char* file_extension_; // standard extension for generated output
|
||||
const char* title_; // title of document being renderered
|
||||
STRING title_; // title of document being renderered
|
||||
int imagenum_; // index of last image added
|
||||
|
||||
FILE* fout_; // output file pointer
|
||||
|
29
arch/Makefile.am
Normal file
29
arch/Makefile.am
Normal file
@ -0,0 +1,29 @@
|
||||
AM_CPPFLAGS += -I$(top_srcdir)/ccutil
|
||||
AUTOMAKE_OPTIONS = subdir-objects
|
||||
SUBDIRS =
|
||||
AM_CXXFLAGS =
|
||||
|
||||
if VISIBILITY
|
||||
AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
|
||||
AM_CPPFLAGS += -DTESS_EXPORTS
|
||||
endif
|
||||
|
||||
include_HEADERS = \
|
||||
dotproductavx.h dotproductsse.h
|
||||
|
||||
noinst_HEADERS =
|
||||
|
||||
if !USING_MULTIPLELIBS
|
||||
noinst_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la
|
||||
else
|
||||
lib_LTLIBRARIES = libtesseract_avx.la libtesseract_sse.la
|
||||
libtesseract_avx_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
|
||||
libtesseract_sse_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
|
||||
endif
|
||||
libtesseract_avx_la_CXXFLAGS = -mavx
|
||||
libtesseract_sse_la_CXXFLAGS = -msse4.1
|
||||
|
||||
libtesseract_avx_la_SOURCES = dotproductavx.cpp
|
||||
|
||||
libtesseract_sse_la_SOURCES = dotproductsse.cpp
|
||||
|
103
arch/dotproductavx.cpp
Normal file
103
arch/dotproductavx.cpp
Normal file
@ -0,0 +1,103 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: dotproductavx.cpp
|
||||
// Description: Architecture-specific dot-product function.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 22 10:48:05 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__AVX__)
|
||||
// Implementation for non-avx archs.
|
||||
|
||||
#include "dotproductavx.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace tesseract {
|
||||
double DotProductAVX(const double* u, const double* v, int n) {
|
||||
fprintf(stderr, "DotProductAVX can't be used on Android\n");
|
||||
abort();
|
||||
}
|
||||
} // namespace tesseract
|
||||
|
||||
#else // !defined(__AVX__)
|
||||
// Implementation for avx capable archs.
|
||||
#include <immintrin.h>
|
||||
#include <stdint.h>
|
||||
#include "dotproductavx.h"
|
||||
#include "host.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Computes and returns the dot product of the n-vectors u and v.
|
||||
// Uses Intel AVX intrinsics to access the SIMD instruction set.
|
||||
double DotProductAVX(const double* u, const double* v, int n) {
|
||||
int max_offset = n - 4;
|
||||
int offset = 0;
|
||||
// Accumulate a set of 4 sums in sum, by loading pairs of 4 values from u and
|
||||
// v, and multiplying them together in parallel.
|
||||
__m256d sum = _mm256_setzero_pd();
|
||||
if (offset <= max_offset) {
|
||||
offset = 4;
|
||||
// Aligned load is reputedly faster but requires 32 byte aligned input.
|
||||
if ((reinterpret_cast<const uintptr_t>(u) & 31) == 0 &&
|
||||
(reinterpret_cast<const uintptr_t>(v) & 31) == 0) {
|
||||
// Use aligned load.
|
||||
__m256d floats1 = _mm256_load_pd(u);
|
||||
__m256d floats2 = _mm256_load_pd(v);
|
||||
// Multiply.
|
||||
sum = _mm256_mul_pd(floats1, floats2);
|
||||
while (offset <= max_offset) {
|
||||
floats1 = _mm256_load_pd(u + offset);
|
||||
floats2 = _mm256_load_pd(v + offset);
|
||||
offset += 4;
|
||||
__m256d product = _mm256_mul_pd(floats1, floats2);
|
||||
sum = _mm256_add_pd(sum, product);
|
||||
}
|
||||
} else {
|
||||
// Use unaligned load.
|
||||
__m256d floats1 = _mm256_loadu_pd(u);
|
||||
__m256d floats2 = _mm256_loadu_pd(v);
|
||||
// Multiply.
|
||||
sum = _mm256_mul_pd(floats1, floats2);
|
||||
while (offset <= max_offset) {
|
||||
floats1 = _mm256_loadu_pd(u + offset);
|
||||
floats2 = _mm256_loadu_pd(v + offset);
|
||||
offset += 4;
|
||||
__m256d product = _mm256_mul_pd(floats1, floats2);
|
||||
sum = _mm256_add_pd(sum, product);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the 4 product sums together horizontally. Not so easy as with sse, as
|
||||
// there is no add across the upper/lower 128 bit boundary, so permute to
|
||||
// move the upper 128 bits to lower in another register.
|
||||
__m256d sum2 = _mm256_permute2f128_pd(sum, sum, 1);
|
||||
sum = _mm256_hadd_pd(sum, sum2);
|
||||
sum = _mm256_hadd_pd(sum, sum);
|
||||
double result;
|
||||
// _mm256_extract_f64 doesn't exist, but resist the temptation to use an sse
|
||||
// instruction, as that introduces a 70 cycle delay. All this casting is to
|
||||
// fool the instrinsics into thinking we are extracting the bottom int64.
|
||||
*(reinterpret_cast<inT64*>(&result)) =
|
||||
_mm256_extract_epi64(_mm256_castpd_si256(sum), 0);
|
||||
while (offset < n) {
|
||||
result += u[offset] * v[offset];
|
||||
++offset;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // ANDROID_BUILD
|
30
arch/dotproductavx.h
Normal file
30
arch/dotproductavx.h
Normal file
@ -0,0 +1,30 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: dotproductavx.h
|
||||
// Description: Architecture-specific dot-product function.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 22 10:51:05 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_ARCH_DOTPRODUCTAVX_H_
|
||||
#define TESSERACT_ARCH_DOTPRODUCTAVX_H_
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Computes and returns the dot product of the n-vectors u and v.
|
||||
// Uses Intel AVX intrinsics to access the SIMD instruction set.
|
||||
double DotProductAVX(const double* u, const double* v, int n);
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_ARCH_DOTPRODUCTAVX_H_
|
141
arch/dotproductsse.cpp
Normal file
141
arch/dotproductsse.cpp
Normal file
@ -0,0 +1,141 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: dotproductsse.cpp
|
||||
// Description: Architecture-specific dot-product function.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 22 10:57:45 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__SSE4_1__)
|
||||
// This code can't compile with "-msse4.1", so use dummy stubs.
|
||||
|
||||
#include "dotproductsse.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace tesseract {
|
||||
double DotProductSSE(const double* u, const double* v, int n) {
|
||||
fprintf(stderr, "DotProductSSE can't be used on Android\n");
|
||||
abort();
|
||||
}
|
||||
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) {
|
||||
fprintf(stderr, "IntDotProductSSE can't be used on Android\n");
|
||||
abort();
|
||||
}
|
||||
} // namespace tesseract
|
||||
|
||||
#else // !defined(__SSE4_1__)
|
||||
// Non-Android code here
|
||||
|
||||
#include <emmintrin.h>
|
||||
#include <smmintrin.h>
|
||||
#include <stdint.h>
|
||||
#include "dotproductsse.h"
|
||||
#include "host.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Computes and returns the dot product of the n-vectors u and v.
|
||||
// Uses Intel SSE intrinsics to access the SIMD instruction set.
|
||||
double DotProductSSE(const double* u, const double* v, int n) {
|
||||
int max_offset = n - 2;
|
||||
int offset = 0;
|
||||
// Accumulate a set of 2 sums in sum, by loading pairs of 2 values from u and
|
||||
// v, and multiplying them together in parallel.
|
||||
__m128d sum = _mm_setzero_pd();
|
||||
if (offset <= max_offset) {
|
||||
offset = 2;
|
||||
// Aligned load is reputedly faster but requires 16 byte aligned input.
|
||||
if ((reinterpret_cast<const uintptr_t>(u) & 15) == 0 &&
|
||||
(reinterpret_cast<const uintptr_t>(v) & 15) == 0) {
|
||||
// Use aligned load.
|
||||
sum = _mm_load_pd(u);
|
||||
__m128d floats2 = _mm_load_pd(v);
|
||||
// Multiply.
|
||||
sum = _mm_mul_pd(sum, floats2);
|
||||
while (offset <= max_offset) {
|
||||
__m128d floats1 = _mm_load_pd(u + offset);
|
||||
floats2 = _mm_load_pd(v + offset);
|
||||
offset += 2;
|
||||
floats1 = _mm_mul_pd(floats1, floats2);
|
||||
sum = _mm_add_pd(sum, floats1);
|
||||
}
|
||||
} else {
|
||||
// Use unaligned load.
|
||||
sum = _mm_loadu_pd(u);
|
||||
__m128d floats2 = _mm_loadu_pd(v);
|
||||
// Multiply.
|
||||
sum = _mm_mul_pd(sum, floats2);
|
||||
while (offset <= max_offset) {
|
||||
__m128d floats1 = _mm_loadu_pd(u + offset);
|
||||
floats2 = _mm_loadu_pd(v + offset);
|
||||
offset += 2;
|
||||
floats1 = _mm_mul_pd(floats1, floats2);
|
||||
sum = _mm_add_pd(sum, floats1);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the 2 sums in sum horizontally.
|
||||
sum = _mm_hadd_pd(sum, sum);
|
||||
// Extract the low result.
|
||||
double result = _mm_cvtsd_f64(sum);
|
||||
// Add on any left-over products.
|
||||
while (offset < n) {
|
||||
result += u[offset] * v[offset];
|
||||
++offset;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Computes and returns the dot product of the n-vectors u and v.
|
||||
// Uses Intel SSE intrinsics to access the SIMD instruction set.
|
||||
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n) {
|
||||
int max_offset = n - 8;
|
||||
int offset = 0;
|
||||
// Accumulate a set of 4 32-bit sums in sum, by loading 8 pairs of 8-bit
|
||||
// values, extending to 16 bit, multiplying to make 32 bit results.
|
||||
__m128i sum = _mm_setzero_si128();
|
||||
if (offset <= max_offset) {
|
||||
offset = 8;
|
||||
__m128i packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u));
|
||||
__m128i packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v));
|
||||
sum = _mm_cvtepi8_epi16(packed1);
|
||||
packed2 = _mm_cvtepi8_epi16(packed2);
|
||||
// The magic _mm_add_epi16 is perfect here. It multiplies 8 pairs of 16 bit
|
||||
// ints to make 32 bit results, which are then horizontally added in pairs
|
||||
// to make 4 32 bit results that still fit in a 128 bit register.
|
||||
sum = _mm_madd_epi16(sum, packed2);
|
||||
while (offset <= max_offset) {
|
||||
packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u + offset));
|
||||
packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v + offset));
|
||||
offset += 8;
|
||||
packed1 = _mm_cvtepi8_epi16(packed1);
|
||||
packed2 = _mm_cvtepi8_epi16(packed2);
|
||||
packed1 = _mm_madd_epi16(packed1, packed2);
|
||||
sum = _mm_add_epi32(sum, packed1);
|
||||
}
|
||||
}
|
||||
// Sum the 4 packed 32 bit sums and extract the low result.
|
||||
sum = _mm_hadd_epi32(sum, sum);
|
||||
sum = _mm_hadd_epi32(sum, sum);
|
||||
inT32 result = _mm_cvtsi128_si32(sum);
|
||||
while (offset < n) {
|
||||
result += u[offset] * v[offset];
|
||||
++offset;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // ANDROID_BUILD
|
35
arch/dotproductsse.h
Normal file
35
arch/dotproductsse.h
Normal file
@ -0,0 +1,35 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: dotproductsse.h
|
||||
// Description: Architecture-specific dot-product function.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 22 10:57:05 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_ARCH_DOTPRODUCTSSE_H_
|
||||
#define TESSERACT_ARCH_DOTPRODUCTSSE_H_
|
||||
|
||||
#include "host.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Computes and returns the dot product of the n-vectors u and v.
|
||||
// Uses Intel SSE intrinsics to access the SIMD instruction set.
|
||||
double DotProductSSE(const double* u, const double* v, int n);
|
||||
// Computes and returns the dot product of the n-vectors u and v.
|
||||
// Uses Intel SSE intrinsics to access the SIMD instruction set.
|
||||
inT32 IntDotProductSSE(const inT8* u, const inT8* v, int n);
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_ARCH_DOTPRODUCTSSE_H_
|
@ -1,6 +1,7 @@
|
||||
AM_CPPFLAGS += \
|
||||
-DUSE_STD_NAMESPACE \
|
||||
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \
|
||||
-I$(top_srcdir)/arch -I$(top_srcdir)/lstm \
|
||||
-I$(top_srcdir)/viewer \
|
||||
-I$(top_srcdir)/classify -I$(top_srcdir)/dict \
|
||||
-I$(top_srcdir)/wordrec -I$(top_srcdir)/cutil \
|
||||
@ -33,6 +34,9 @@ libtesseract_main_la_LIBADD = \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../opencl/libtesseract_opencl.la
|
||||
@ -44,7 +48,7 @@ endif
|
||||
libtesseract_main_la_SOURCES = \
|
||||
adaptions.cpp applybox.cpp control.cpp \
|
||||
docqual.cpp equationdetect.cpp fixspace.cpp fixxht.cpp \
|
||||
ltrresultiterator.cpp \
|
||||
linerec.cpp ltrresultiterator.cpp \
|
||||
osdetect.cpp output.cpp pageiterator.cpp pagesegmain.cpp \
|
||||
pagewalk.cpp par_control.cpp paragraphs.cpp paramsd.cpp pgedit.cpp recogtraining.cpp \
|
||||
reject.cpp resultiterator.cpp superscript.cpp \
|
||||
|
@ -84,7 +84,12 @@ BOOL8 Tesseract::recog_interactive(PAGE_RES_IT* pr_it) {
|
||||
|
||||
WordData word_data(*pr_it);
|
||||
SetupWordPassN(2, &word_data);
|
||||
classify_word_and_language(2, pr_it, &word_data);
|
||||
// LSTM doesn't run on pass2, but we want to run pass2 for tesseract.
|
||||
if (lstm_recognizer_ == NULL) {
|
||||
classify_word_and_language(2, pr_it, &word_data);
|
||||
} else {
|
||||
classify_word_and_language(1, pr_it, &word_data);
|
||||
}
|
||||
if (tessedit_debug_quality_metrics) {
|
||||
WERD_RES* word_res = pr_it->word();
|
||||
word_char_quality(word_res, pr_it->row()->row, &char_qual, &good_char_qual);
|
||||
@ -218,16 +223,14 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor,
|
||||
if (pass_n == 1) {
|
||||
monitor->progress = 70 * w / words->size();
|
||||
if (monitor->progress_callback != NULL) {
|
||||
TBOX box = pr_it->word()->word->bounding_box();
|
||||
(*monitor->progress_callback)(monitor->progress,
|
||||
box.left(), box.right(),
|
||||
box.top(), box.bottom());
|
||||
TBOX box = pr_it->word()->word->bounding_box();
|
||||
(*monitor->progress_callback)(monitor->progress, box.left(),
|
||||
box.right(), box.top(), box.bottom());
|
||||
}
|
||||
} else {
|
||||
monitor->progress = 70 + 30 * w / words->size();
|
||||
if (monitor->progress_callback!=NULL) {
|
||||
(*monitor->progress_callback)(monitor->progress,
|
||||
0, 0, 0, 0);
|
||||
if (monitor->progress_callback != NULL) {
|
||||
(*monitor->progress_callback)(monitor->progress, 0, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
if (monitor->deadline_exceeded() ||
|
||||
@ -252,7 +255,8 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC* monitor,
|
||||
pr_it->forward();
|
||||
ASSERT_HOST(pr_it->word() != NULL);
|
||||
bool make_next_word_fuzzy = false;
|
||||
if (ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
|
||||
if (!AnyLSTMLang() &&
|
||||
ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
|
||||
// Needs to be setup again to see the new outlines in the chopped_word.
|
||||
SetupWordPassN(pass_n, word);
|
||||
}
|
||||
@ -297,6 +301,16 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res,
|
||||
const TBOX* target_word_box,
|
||||
const char* word_config,
|
||||
int dopasses) {
|
||||
// PSM_RAW_LINE is a special-case mode in which the layout analysis is
|
||||
// completely ignored and LSTM is run on the raw image. There is no hope
|
||||
// of running normal tesseract in this situation or of integrating output.
|
||||
#ifndef ANDROID_BUILD
|
||||
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY &&
|
||||
tessedit_pageseg_mode == PSM_RAW_LINE) {
|
||||
RecogRawLine(page_res);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
PAGE_RES_IT page_res_it(page_res);
|
||||
|
||||
if (tessedit_minimal_rej_pass1) {
|
||||
@ -385,7 +399,7 @@ bool Tesseract::recog_all_words(PAGE_RES* page_res,
|
||||
|
||||
// The next passes can only be run if tesseract has been used, as cube
|
||||
// doesn't set all the necessary outputs in WERD_RES.
|
||||
if (AnyTessLang()) {
|
||||
if (AnyTessLang() && !AnyLSTMLang()) {
|
||||
// ****************** Pass 3 *******************
|
||||
// Fix fuzzy spaces.
|
||||
set_global_loc_code(LOC_FUZZY_SPACE);
|
||||
@ -1362,6 +1376,19 @@ void Tesseract::classify_word_pass1(const WordData& word_data,
|
||||
cube_word_pass1(block, row, *in_word);
|
||||
return;
|
||||
}
|
||||
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
|
||||
if (!(*in_word)->odd_size) {
|
||||
LSTMRecognizeWord(*block, row, *in_word, out_words);
|
||||
if (!out_words->empty())
|
||||
return; // Successful lstm recognition.
|
||||
}
|
||||
// Fall back to tesseract for failed words or odd words.
|
||||
(*in_word)->SetupForRecognition(unicharset, this, BestPix(),
|
||||
OEM_TESSERACT_ONLY, NULL,
|
||||
classify_bln_numeric_mode,
|
||||
textord_use_cjk_fp_model,
|
||||
poly_allow_detailed_fx, row, block);
|
||||
}
|
||||
#endif
|
||||
WERD_RES* word = *in_word;
|
||||
match_word_pass_n(1, word, row, block);
|
||||
@ -1496,10 +1523,6 @@ void Tesseract::classify_word_pass2(const WordData& word_data,
|
||||
WERD_RES** in_word,
|
||||
PointerVector<WERD_RES>* out_words) {
|
||||
// Return if we do not want to run Tesseract.
|
||||
if (tessedit_ocr_engine_mode != OEM_TESSERACT_ONLY &&
|
||||
tessedit_ocr_engine_mode != OEM_TESSERACT_CUBE_COMBINED &&
|
||||
word_data.word->best_choice != NULL)
|
||||
return;
|
||||
if (tessedit_ocr_engine_mode == OEM_CUBE_ONLY) {
|
||||
return;
|
||||
}
|
||||
|
332
ccmain/linerec.cpp
Normal file
332
ccmain/linerec.cpp
Normal file
@ -0,0 +1,332 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: linerec.cpp
|
||||
// Description: Top-level line-based recognition module for Tesseract.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 09:47:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "tesseractclass.h"
|
||||
|
||||
#include "allheaders.h"
|
||||
#include "boxread.h"
|
||||
#include "imagedata.h"
|
||||
#ifndef ANDROID_BUILD
|
||||
#include "lstmrecognizer.h"
|
||||
#include "recodebeam.h"
|
||||
#endif
|
||||
#include "ndminx.h"
|
||||
#include "pageres.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Arbitarary penalty for non-dictionary words.
|
||||
// TODO(rays) How to learn this?
|
||||
const float kNonDictionaryPenalty = 5.0f;
|
||||
// Scale factor to make certainty more comparable to Tesseract.
|
||||
const float kCertaintyScale = 7.0f;
|
||||
// Worst acceptable certainty for a dictionary word.
|
||||
const float kWorstDictCertainty = -25.0f;
|
||||
|
||||
// Generates training data for training a line recognizer, eg LSTM.
|
||||
// Breaks the page into lines, according to the boxes, and writes them to a
|
||||
// serialized DocumentData based on output_basename.
|
||||
void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
|
||||
const STRING& output_basename,
|
||||
BLOCK_LIST *block_list) {
|
||||
STRING lstmf_name = output_basename + ".lstmf";
|
||||
DocumentData images(lstmf_name);
|
||||
if (applybox_page > 0) {
|
||||
// Load existing document for the previous pages.
|
||||
if (!images.LoadDocument(lstmf_name.string(), "eng", 0, 0, NULL)) {
|
||||
tprintf("Failed to read training data from %s!\n", lstmf_name.string());
|
||||
return;
|
||||
}
|
||||
}
|
||||
GenericVector<TBOX> boxes;
|
||||
GenericVector<STRING> texts;
|
||||
// Get the boxes for this page, if there are any.
|
||||
if (!ReadAllBoxes(applybox_page, false, input_imagename, &boxes, &texts, NULL,
|
||||
NULL) ||
|
||||
boxes.empty()) {
|
||||
tprintf("Failed to read boxes from %s\n", input_imagename.string());
|
||||
return;
|
||||
}
|
||||
TrainFromBoxes(boxes, texts, block_list, &images);
|
||||
if (!images.SaveDocument(lstmf_name.string(), NULL)) {
|
||||
tprintf("Failed to write training data to %s!\n", lstmf_name.string());
|
||||
}
|
||||
}
|
||||
|
||||
// Generates training data for training a line recognizer, eg LSTM.
|
||||
// Breaks the boxes into lines, normalizes them, converts to ImageData and
|
||||
// appends them to the given training_data.
|
||||
void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
|
||||
const GenericVector<STRING>& texts,
|
||||
BLOCK_LIST *block_list,
|
||||
DocumentData* training_data) {
|
||||
int box_count = boxes.size();
|
||||
// Process all the text lines in this page, as defined by the boxes.
|
||||
int end_box = 0;
|
||||
for (int start_box = 0; start_box < box_count; start_box = end_box) {
|
||||
// Find the textline of boxes starting at start and their bounding box.
|
||||
TBOX line_box = boxes[start_box];
|
||||
STRING line_str = texts[start_box];
|
||||
for (end_box = start_box + 1; end_box < box_count && texts[end_box] != "\t";
|
||||
++end_box) {
|
||||
line_box += boxes[end_box];
|
||||
line_str += texts[end_box];
|
||||
}
|
||||
// Find the most overlapping block.
|
||||
BLOCK* best_block = NULL;
|
||||
int best_overlap = 0;
|
||||
BLOCK_IT b_it(block_list);
|
||||
for (b_it.mark_cycle_pt(); !b_it.cycled_list(); b_it.forward()) {
|
||||
BLOCK* block = b_it.data();
|
||||
if (block->poly_block() != NULL && !block->poly_block()->IsText())
|
||||
continue; // Not a text block.
|
||||
TBOX block_box = block->bounding_box();
|
||||
block_box.rotate(block->re_rotation());
|
||||
if (block_box.major_overlap(line_box)) {
|
||||
TBOX overlap_box = line_box.intersection(block_box);
|
||||
if (overlap_box.area() > best_overlap) {
|
||||
best_overlap = overlap_box.area();
|
||||
best_block = block;
|
||||
}
|
||||
}
|
||||
}
|
||||
ImageData* imagedata = NULL;
|
||||
if (best_block == NULL) {
|
||||
tprintf("No block overlapping textline: %s\n", line_str.string());
|
||||
} else {
|
||||
imagedata = GetLineData(line_box, boxes, texts, start_box, end_box,
|
||||
*best_block);
|
||||
}
|
||||
if (imagedata != NULL)
|
||||
training_data->AddPageToDocument(imagedata);
|
||||
if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns an Imagedata containing the image of the given box,
|
||||
// and ground truth boxes/truth text if available in the input.
|
||||
// The image is not normalized in any way.
|
||||
ImageData* Tesseract::GetLineData(const TBOX& line_box,
|
||||
const GenericVector<TBOX>& boxes,
|
||||
const GenericVector<STRING>& texts,
|
||||
int start_box, int end_box,
|
||||
const BLOCK& block) {
|
||||
TBOX revised_box;
|
||||
ImageData* image_data = GetRectImage(line_box, block, kImagePadding,
|
||||
&revised_box);
|
||||
if (image_data == NULL) return NULL;
|
||||
image_data->set_page_number(applybox_page);
|
||||
// Copy the boxes and shift them so they are relative to the image.
|
||||
FCOORD block_rotation(block.re_rotation().x(), -block.re_rotation().y());
|
||||
ICOORD shift = -revised_box.botleft();
|
||||
GenericVector<TBOX> line_boxes;
|
||||
GenericVector<STRING> line_texts;
|
||||
for (int b = start_box; b < end_box; ++b) {
|
||||
TBOX box = boxes[b];
|
||||
box.rotate(block_rotation);
|
||||
box.move(shift);
|
||||
line_boxes.push_back(box);
|
||||
line_texts.push_back(texts[b]);
|
||||
}
|
||||
GenericVector<int> page_numbers;
|
||||
page_numbers.init_to_size(line_boxes.size(), applybox_page);
|
||||
image_data->AddBoxes(line_boxes, line_texts, page_numbers);
|
||||
return image_data;
|
||||
}
|
||||
|
||||
// Helper gets the image of a rectangle, using the block.re_rotation() if
|
||||
// needed to get to the image, and rotating the result back to horizontal
|
||||
// layout. (CJK characters will be on their left sides) The vertical text flag
|
||||
// is set in the returned ImageData if the text was originally vertical, which
|
||||
// can be used to invoke a different CJK recognition engine. The revised_box
|
||||
// is also returned to enable calculation of output bounding boxes.
|
||||
ImageData* Tesseract::GetRectImage(const TBOX& box, const BLOCK& block,
|
||||
int padding, TBOX* revised_box) const {
|
||||
TBOX wbox = box;
|
||||
wbox.pad(padding, padding);
|
||||
*revised_box = wbox;
|
||||
// Number of clockwise 90 degree rotations needed to get back to tesseract
|
||||
// coords from the clipped image.
|
||||
int num_rotations = 0;
|
||||
if (block.re_rotation().y() > 0.0f)
|
||||
num_rotations = 1;
|
||||
else if (block.re_rotation().x() < 0.0f)
|
||||
num_rotations = 2;
|
||||
else if (block.re_rotation().y() < 0.0f)
|
||||
num_rotations = 3;
|
||||
// Handle two cases automatically: 1 the box came from the block, 2 the box
|
||||
// came from a box file, and refers to the image, which the block may not.
|
||||
if (block.bounding_box().major_overlap(*revised_box))
|
||||
revised_box->rotate(block.re_rotation());
|
||||
// Now revised_box always refers to the image.
|
||||
// BestPix is never colormapped, but may be of any depth.
|
||||
Pix* pix = BestPix();
|
||||
int width = pixGetWidth(pix);
|
||||
int height = pixGetHeight(pix);
|
||||
TBOX image_box(0, 0, width, height);
|
||||
// Clip to image bounds;
|
||||
*revised_box &= image_box;
|
||||
if (revised_box->null_box()) return NULL;
|
||||
Box* clip_box = boxCreate(revised_box->left(), height - revised_box->top(),
|
||||
revised_box->width(), revised_box->height());
|
||||
Pix* box_pix = pixClipRectangle(pix, clip_box, NULL);
|
||||
if (box_pix == NULL) return NULL;
|
||||
boxDestroy(&clip_box);
|
||||
if (num_rotations > 0) {
|
||||
Pix* rot_pix = pixRotateOrth(box_pix, num_rotations);
|
||||
pixDestroy(&box_pix);
|
||||
box_pix = rot_pix;
|
||||
}
|
||||
// Convert sub-8-bit images to 8 bit.
|
||||
int depth = pixGetDepth(box_pix);
|
||||
if (depth < 8) {
|
||||
Pix* grey;
|
||||
grey = pixConvertTo8(box_pix, false);
|
||||
pixDestroy(&box_pix);
|
||||
box_pix = grey;
|
||||
}
|
||||
bool vertical_text = false;
|
||||
if (num_rotations > 0) {
|
||||
// Rotated the clipped revised box back to internal coordinates.
|
||||
FCOORD rotation(block.re_rotation().x(), -block.re_rotation().y());
|
||||
revised_box->rotate(rotation);
|
||||
if (num_rotations != 2)
|
||||
vertical_text = true;
|
||||
}
|
||||
return new ImageData(vertical_text, box_pix);
|
||||
}
|
||||
|
||||
#ifndef ANDROID_BUILD
|
||||
// Top-level function recognizes a single raw line.
|
||||
void Tesseract::RecogRawLine(PAGE_RES* page_res) {
|
||||
PAGE_RES_IT it(page_res);
|
||||
PointerVector<WERD_RES> words;
|
||||
LSTMRecognizeWord(*it.block()->block, it.row()->row, it.word(), &words);
|
||||
if (getDict().stopper_debug_level >= 1) {
|
||||
for (int w = 0; w < words.size(); ++w) {
|
||||
words[w]->DebugWordChoices(true, NULL);
|
||||
}
|
||||
}
|
||||
it.ReplaceCurrentWord(&words);
|
||||
}
|
||||
|
||||
// Recognizes a word or group of words, converting to WERD_RES in *words.
|
||||
// Analogous to classify_word_pass1, but can handle a group of words as well.
|
||||
void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
|
||||
PointerVector<WERD_RES>* words) {
|
||||
TBOX word_box = word->word->bounding_box();
|
||||
// Get the word image - no frills.
|
||||
if (tessedit_pageseg_mode == PSM_SINGLE_WORD ||
|
||||
tessedit_pageseg_mode == PSM_RAW_LINE) {
|
||||
// In single word mode, use the whole image without any other row/word
|
||||
// interpretation.
|
||||
word_box = TBOX(0, 0, ImageWidth(), ImageHeight());
|
||||
} else {
|
||||
float baseline = row->base_line((word_box.left() + word_box.right()) / 2);
|
||||
if (baseline + row->descenders() < word_box.bottom())
|
||||
word_box.set_bottom(baseline + row->descenders());
|
||||
if (baseline + row->x_height() + row->ascenders() > word_box.top())
|
||||
word_box.set_top(baseline + row->x_height() + row->ascenders());
|
||||
}
|
||||
ImageData* im_data = GetRectImage(word_box, block, kImagePadding, &word_box);
|
||||
if (im_data == NULL) return;
|
||||
lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0,
|
||||
kWorstDictCertainty / kCertaintyScale,
|
||||
lstm_use_matrix, &unicharset, word_box, 2.0,
|
||||
false, words);
|
||||
delete im_data;
|
||||
SearchWords(words);
|
||||
}
|
||||
|
||||
// Apply segmentation search to the given set of words, within the constraints
|
||||
// of the existing ratings matrix. If there is already a best_choice on a word
|
||||
// leaves it untouched and just sets the done/accepted etc flags.
|
||||
void Tesseract::SearchWords(PointerVector<WERD_RES>* words) {
|
||||
// Run the segmentation search on the network outputs and make a BoxWord
|
||||
// for each of the output words.
|
||||
// If we drop a word as junk, then there is always a space in front of the
|
||||
// next.
|
||||
bool deleted_prev = false;
|
||||
for (int w = 0; w < words->size(); ++w) {
|
||||
WERD_RES* word = (*words)[w];
|
||||
if (word->best_choice == NULL) {
|
||||
// If we are using the beam search, the unicharset had better match!
|
||||
word->SetupWordScript(unicharset);
|
||||
WordSearch(word);
|
||||
} else if (word->best_choice->unicharset() == &unicharset &&
|
||||
!lstm_recognizer_->IsRecoding()) {
|
||||
// We set up the word without using the dictionary, so set the permuter
|
||||
// now, but we can only do it because the unicharsets match.
|
||||
word->best_choice->set_permuter(
|
||||
getDict().valid_word(*word->best_choice, true));
|
||||
}
|
||||
if (word->best_choice == NULL) {
|
||||
// It is a dud.
|
||||
words->remove(w);
|
||||
--w;
|
||||
deleted_prev = true;
|
||||
} else {
|
||||
// Set the best state.
|
||||
for (int i = 0; i < word->best_choice->length(); ++i) {
|
||||
int length = word->best_choice->state(i);
|
||||
word->best_state.push_back(length);
|
||||
}
|
||||
word->tess_failed = false;
|
||||
word->tess_accepted = true;
|
||||
word->tess_would_adapt = false;
|
||||
word->done = true;
|
||||
word->tesseract = this;
|
||||
float word_certainty = MIN(word->space_certainty,
|
||||
word->best_choice->certainty());
|
||||
word_certainty *= kCertaintyScale;
|
||||
// Arbitrary ding factor for non-dictionary words.
|
||||
if (!lstm_recognizer_->IsRecoding() &&
|
||||
!Dict::valid_word_permuter(word->best_choice->permuter(), true))
|
||||
word_certainty -= kNonDictionaryPenalty;
|
||||
if (getDict().stopper_debug_level >= 1) {
|
||||
tprintf("Best choice certainty=%g, space=%g, scaled=%g, final=%g\n",
|
||||
word->best_choice->certainty(), word->space_certainty,
|
||||
MIN(word->space_certainty, word->best_choice->certainty()) *
|
||||
kCertaintyScale,
|
||||
word_certainty);
|
||||
word->best_choice->print();
|
||||
}
|
||||
// Discard words that are impossibly bad, but allow a bit more for
|
||||
// dictionary words.
|
||||
if (word_certainty >= RecodeBeamSearch::kMinCertainty ||
|
||||
(word_certainty >= kWorstDictCertainty &&
|
||||
Dict::valid_word_permuter(word->best_choice->permuter(), true))) {
|
||||
word->best_choice->set_certainty(word_certainty);
|
||||
if (deleted_prev) word->word->set_blanks(1);
|
||||
} else {
|
||||
if (getDict().stopper_debug_level >= 1) {
|
||||
tprintf("Deleting word with certainty %g\n", word_certainty);
|
||||
word->best_choice->print();
|
||||
}
|
||||
// It is a dud.
|
||||
words->remove(w);
|
||||
--w;
|
||||
deleted_prev = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // ANDROID_BUILD
|
||||
|
||||
} // namespace tesseract.
|
@ -220,6 +220,12 @@ bool LTRResultIterator::WordIsFromDictionary() const {
|
||||
permuter == USER_DAWG_PERM;
|
||||
}
|
||||
|
||||
// Returns the number of blanks before the current word.
|
||||
int LTRResultIterator::BlanksBeforeWord() const {
|
||||
if (it_->word() == NULL) return 1;
|
||||
return it_->word()->word->space();
|
||||
}
|
||||
|
||||
// Returns true if the current word is numeric.
|
||||
bool LTRResultIterator::WordIsNumeric() const {
|
||||
if (it_->word() == NULL) return false; // Already at the end!
|
||||
|
@ -124,6 +124,9 @@ class TESS_API LTRResultIterator : public PageIterator {
|
||||
// Returns true if the current word was found in a dictionary.
|
||||
bool WordIsFromDictionary() const;
|
||||
|
||||
// Returns the number of blanks before the current word.
|
||||
int BlanksBeforeWord() const;
|
||||
|
||||
// Returns true if the current word is numeric.
|
||||
bool WordIsNumeric() const;
|
||||
|
||||
|
@ -40,6 +40,9 @@
|
||||
#include "efio.h"
|
||||
#include "danerror.h"
|
||||
#include "globals.h"
|
||||
#ifndef ANDROID_BUILD
|
||||
#include "lstmrecognizer.h"
|
||||
#endif
|
||||
#include "tesseractclass.h"
|
||||
#include "params.h"
|
||||
|
||||
@ -214,6 +217,18 @@ bool Tesseract::init_tesseract_lang_data(
|
||||
ASSERT_HOST(init_cube_objects(true, &tessdata_manager));
|
||||
if (tessdata_manager_debug_level)
|
||||
tprintf("Loaded Cube with combiner\n");
|
||||
} else if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
|
||||
if (tessdata_manager.SeekToStart(TESSDATA_LSTM)) {
|
||||
lstm_recognizer_ = new LSTMRecognizer;
|
||||
TFile fp;
|
||||
fp.Open(tessdata_manager.GetDataFilePtr(), -1);
|
||||
ASSERT_HOST(lstm_recognizer_->DeSerialize(tessdata_manager.swap(), &fp));
|
||||
if (lstm_use_matrix)
|
||||
lstm_recognizer_->LoadDictionary(tessdata_path.string(), language);
|
||||
} else {
|
||||
tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n");
|
||||
tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// Init ParamsModel.
|
||||
@ -409,8 +424,7 @@ int Tesseract::init_tesseract_internal(
|
||||
// If only Cube will be used, skip loading Tesseract classifier's
|
||||
// pre-trained templates.
|
||||
bool init_tesseract_classifier =
|
||||
(tessedit_ocr_engine_mode == OEM_TESSERACT_ONLY ||
|
||||
tessedit_ocr_engine_mode == OEM_TESSERACT_CUBE_COMBINED);
|
||||
tessedit_ocr_engine_mode != OEM_CUBE_ONLY;
|
||||
// If only Cube will be used and if it has its own Unicharset,
|
||||
// skip initializing permuter and loading Tesseract Dawgs.
|
||||
bool init_dict =
|
||||
@ -468,7 +482,9 @@ int Tesseract::init_tesseract_lm(const char *arg0,
|
||||
if (!init_tesseract_lang_data(arg0, textbase, language, OEM_TESSERACT_ONLY,
|
||||
NULL, 0, NULL, NULL, false))
|
||||
return -1;
|
||||
getDict().Load(Dict::GlobalDawgCache());
|
||||
getDict().SetupForLoad(Dict::GlobalDawgCache());
|
||||
getDict().Load(tessdata_manager.GetDataFileName().string(), lang);
|
||||
getDict().FinishLoad();
|
||||
tessdata_manager.End();
|
||||
return 0;
|
||||
}
|
||||
|
@ -49,6 +49,7 @@
|
||||
#include "equationdetect.h"
|
||||
#include "globals.h"
|
||||
#ifndef NO_CUBE_BUILD
|
||||
#include "lstmrecognizer.h"
|
||||
#include "tesseract_cube_combiner.h"
|
||||
#endif
|
||||
|
||||
@ -65,6 +66,9 @@ Tesseract::Tesseract()
|
||||
"Generate training data from boxed chars", this->params()),
|
||||
BOOL_MEMBER(tessedit_make_boxes_from_boxes, false,
|
||||
"Generate more boxes from boxed chars", this->params()),
|
||||
BOOL_MEMBER(tessedit_train_line_recognizer, false,
|
||||
"Break input into lines and remap boxes if present",
|
||||
this->params()),
|
||||
BOOL_MEMBER(tessedit_dump_pageseg_images, false,
|
||||
"Dump intermediate images made during page segmentation",
|
||||
this->params()),
|
||||
@ -222,6 +226,8 @@ Tesseract::Tesseract()
|
||||
"(more accurate)",
|
||||
this->params()),
|
||||
INT_MEMBER(cube_debug_level, 0, "Print cube debug info.", this->params()),
|
||||
BOOL_MEMBER(lstm_use_matrix, 1,
|
||||
"Use ratings matrix/beam search with lstm", this->params()),
|
||||
STRING_MEMBER(outlines_odd, "%| ", "Non standard number of outlines",
|
||||
this->params()),
|
||||
STRING_MEMBER(outlines_2, "ij!?%\":;", "Non standard number of outlines",
|
||||
@ -605,6 +611,7 @@ Tesseract::Tesseract()
|
||||
pix_binary_(NULL),
|
||||
cube_binary_(NULL),
|
||||
pix_grey_(NULL),
|
||||
pix_original_(NULL),
|
||||
pix_thresholds_(NULL),
|
||||
source_resolution_(0),
|
||||
textord_(this),
|
||||
@ -619,11 +626,16 @@ Tesseract::Tesseract()
|
||||
cube_cntxt_(NULL),
|
||||
tess_cube_combiner_(NULL),
|
||||
#endif
|
||||
equ_detect_(NULL) {
|
||||
equ_detect_(NULL),
|
||||
#ifndef ANDROID_BUILD
|
||||
lstm_recognizer_(NULL),
|
||||
#endif
|
||||
train_line_page_num_(0) {
|
||||
}
|
||||
|
||||
Tesseract::~Tesseract() {
|
||||
Clear();
|
||||
pixDestroy(&pix_original_);
|
||||
end_tesseract();
|
||||
sub_langs_.delete_data_pointers();
|
||||
#ifndef NO_CUBE_BUILD
|
||||
@ -636,6 +648,8 @@ Tesseract::~Tesseract() {
|
||||
delete tess_cube_combiner_;
|
||||
tess_cube_combiner_ = NULL;
|
||||
}
|
||||
delete lstm_recognizer_;
|
||||
lstm_recognizer_ = NULL;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,10 @@ class CubeLineObject;
|
||||
class CubeObject;
|
||||
class CubeRecoContext;
|
||||
#endif
|
||||
class DocumentData;
|
||||
class EquationDetect;
|
||||
class ImageData;
|
||||
class LSTMRecognizer;
|
||||
class Tesseract;
|
||||
#ifndef NO_CUBE_BUILD
|
||||
class TesseractCubeCombiner;
|
||||
@ -189,7 +192,7 @@ class Tesseract : public Wordrec {
|
||||
}
|
||||
// Destroy any existing pix and return a pointer to the pointer.
|
||||
Pix** mutable_pix_binary() {
|
||||
Clear();
|
||||
pixDestroy(&pix_binary_);
|
||||
return &pix_binary_;
|
||||
}
|
||||
Pix* pix_binary() const {
|
||||
@ -202,16 +205,20 @@ class Tesseract : public Wordrec {
|
||||
pixDestroy(&pix_grey_);
|
||||
pix_grey_ = grey_pix;
|
||||
}
|
||||
// Returns a pointer to a Pix representing the best available image of the
|
||||
// page. The image will be 8-bit grey if the input was grey or color. Note
|
||||
// that in grey 0 is black and 255 is white. If the input was binary, then
|
||||
// the returned Pix will be binary. Note that here black is 1 and white is 0.
|
||||
// To tell the difference pixGetDepth() will return 8 or 1.
|
||||
// In either case, the return value is a borrowed Pix, and should not be
|
||||
// deleted or pixDestroyed.
|
||||
Pix* BestPix() const {
|
||||
return pix_grey_ != NULL ? pix_grey_ : pix_binary_;
|
||||
Pix* pix_original() const { return pix_original_; }
|
||||
// Takes ownership of the given original_pix.
|
||||
void set_pix_original(Pix* original_pix) {
|
||||
pixDestroy(&pix_original_);
|
||||
pix_original_ = original_pix;
|
||||
}
|
||||
// Returns a pointer to a Pix representing the best available (original) image
|
||||
// of the page. Can be of any bit depth, but never color-mapped, as that has
|
||||
// always been dealt with. Note that in grey and color, 0 is black and 255 is
|
||||
// white. If the input was binary, then black is 1 and white is 0.
|
||||
// To tell the difference pixGetDepth() will return 32, 8 or 1.
|
||||
// In any case, the return value is a borrowed Pix, and should not be
|
||||
// deleted or pixDestroyed.
|
||||
Pix* BestPix() const { return pix_original_; }
|
||||
void set_pix_thresholds(Pix* thresholds) {
|
||||
pixDestroy(&pix_thresholds_);
|
||||
pix_thresholds_ = thresholds;
|
||||
@ -263,6 +270,15 @@ class Tesseract : public Wordrec {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// Returns true if any language uses the LSTM.
|
||||
bool AnyLSTMLang() const {
|
||||
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) return true;
|
||||
for (int i = 0; i < sub_langs_.size(); ++i) {
|
||||
if (sub_langs_[i]->tessedit_ocr_engine_mode == OEM_LSTM_ONLY)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SetBlackAndWhitelist();
|
||||
|
||||
@ -293,6 +309,48 @@ class Tesseract : public Wordrec {
|
||||
// par_control.cpp
|
||||
void PrerecAllWordsPar(const GenericVector<WordData>& words);
|
||||
|
||||
//// linerec.cpp
|
||||
// Generates training data for training a line recognizer, eg LSTM.
|
||||
// Breaks the page into lines, according to the boxes, and writes them to a
|
||||
// serialized DocumentData based on output_basename.
|
||||
void TrainLineRecognizer(const STRING& input_imagename,
|
||||
const STRING& output_basename,
|
||||
BLOCK_LIST *block_list);
|
||||
// Generates training data for training a line recognizer, eg LSTM.
|
||||
// Breaks the boxes into lines, normalizes them, converts to ImageData and
|
||||
// appends them to the given training_data.
|
||||
void TrainFromBoxes(const GenericVector<TBOX>& boxes,
|
||||
const GenericVector<STRING>& texts,
|
||||
BLOCK_LIST *block_list,
|
||||
DocumentData* training_data);
|
||||
|
||||
// Returns an Imagedata containing the image of the given textline,
|
||||
// and ground truth boxes/truth text if available in the input.
|
||||
// The image is not normalized in any way.
|
||||
ImageData* GetLineData(const TBOX& line_box,
|
||||
const GenericVector<TBOX>& boxes,
|
||||
const GenericVector<STRING>& texts,
|
||||
int start_box, int end_box,
|
||||
const BLOCK& block);
|
||||
// Helper gets the image of a rectangle, using the block.re_rotation() if
|
||||
// needed to get to the image, and rotating the result back to horizontal
|
||||
// layout. (CJK characters will be on their left sides) The vertical text flag
|
||||
// is set in the returned ImageData if the text was originally vertical, which
|
||||
// can be used to invoke a different CJK recognition engine. The revised_box
|
||||
// is also returned to enable calculation of output bounding boxes.
|
||||
ImageData* GetRectImage(const TBOX& box, const BLOCK& block, int padding,
|
||||
TBOX* revised_box) const;
|
||||
// Top-level function recognizes a single raw line.
|
||||
void RecogRawLine(PAGE_RES* page_res);
|
||||
// Recognizes a word or group of words, converting to WERD_RES in *words.
|
||||
// Analogous to classify_word_pass1, but can handle a group of words as well.
|
||||
void LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
|
||||
PointerVector<WERD_RES>* words);
|
||||
// Apply segmentation search to the given set of words, within the constraints
|
||||
// of the existing ratings matrix. If there is already a best_choice on a word
|
||||
// leaves it untouched and just sets the done/accepted etc flags.
|
||||
void SearchWords(PointerVector<WERD_RES>* words);
|
||||
|
||||
//// control.h /////////////////////////////////////////////////////////
|
||||
bool ProcessTargetWord(const TBOX& word_box, const TBOX& target_word_box,
|
||||
const char* word_config, int pass);
|
||||
@ -783,6 +841,8 @@ class Tesseract : public Wordrec {
|
||||
"Generate training data from boxed chars");
|
||||
BOOL_VAR_H(tessedit_make_boxes_from_boxes, false,
|
||||
"Generate more boxes from boxed chars");
|
||||
BOOL_VAR_H(tessedit_train_line_recognizer, false,
|
||||
"Break input into lines and remap boxes if present");
|
||||
BOOL_VAR_H(tessedit_dump_pageseg_images, false,
|
||||
"Dump intermediate images made during page segmentation");
|
||||
INT_VAR_H(tessedit_pageseg_mode, PSM_SINGLE_BLOCK,
|
||||
@ -891,6 +951,7 @@ class Tesseract : public Wordrec {
|
||||
"Run paragraph detection on the post-text-recognition "
|
||||
"(more accurate)");
|
||||
INT_VAR_H(cube_debug_level, 1, "Print cube debug info.");
|
||||
BOOL_VAR_H(lstm_use_matrix, 1, "Use ratings matrix/beam searct with lstm");
|
||||
STRING_VAR_H(outlines_odd, "%| ", "Non standard number of outlines");
|
||||
STRING_VAR_H(outlines_2, "ij!?%\":;", "Non standard number of outlines");
|
||||
BOOL_VAR_H(docqual_excuse_outline_errs, false,
|
||||
@ -1174,6 +1235,8 @@ class Tesseract : public Wordrec {
|
||||
Pix* cube_binary_;
|
||||
// Grey-level input image if the input was not binary, otherwise NULL.
|
||||
Pix* pix_grey_;
|
||||
// Original input image. Color if the input was color.
|
||||
Pix* pix_original_;
|
||||
// Thresholds that were used to generate the thresholded image from grey.
|
||||
Pix* pix_thresholds_;
|
||||
// Input image resolution after any scaling. The resolution is not well
|
||||
@ -1205,6 +1268,10 @@ class Tesseract : public Wordrec {
|
||||
#endif
|
||||
// Equation detector. Note: this pointer is NOT owned by the class.
|
||||
EquationDetect* equ_detect_;
|
||||
// LSTM recognizer, if available.
|
||||
LSTMRecognizer* lstm_recognizer_;
|
||||
// Output "page" number (actually line number) using TrainLineRecognizer.
|
||||
int train_line_page_num_;
|
||||
};
|
||||
|
||||
} // namespace tesseract
|
||||
|
@ -152,19 +152,27 @@ void ImageThresholder::SetImage(const Pix* pix) {
|
||||
int depth;
|
||||
pixGetDimensions(src, &image_width_, &image_height_, &depth);
|
||||
// Convert the image as necessary so it is one of binary, plain RGB, or
|
||||
// 8 bit with no colormap.
|
||||
if (depth > 1 && depth < 8) {
|
||||
// 8 bit with no colormap. Guarantee that we always end up with our own copy,
|
||||
// not just a clone of the input.
|
||||
if (pixGetColormap(src)) {
|
||||
Pix* tmp = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC);
|
||||
depth = pixGetDepth(tmp);
|
||||
if (depth > 1 && depth < 8) {
|
||||
pix_ = pixConvertTo8(tmp, false);
|
||||
pixDestroy(&tmp);
|
||||
} else {
|
||||
pix_ = tmp;
|
||||
}
|
||||
} else if (depth > 1 && depth < 8) {
|
||||
pix_ = pixConvertTo8(src, false);
|
||||
} else if (pixGetColormap(src)) {
|
||||
pix_ = pixRemoveColormap(src, REMOVE_CMAP_BASED_ON_SRC);
|
||||
} else {
|
||||
pix_ = pixClone(src);
|
||||
pix_ = pixCopy(NULL, src);
|
||||
}
|
||||
depth = pixGetDepth(pix_);
|
||||
pix_channels_ = depth / 8;
|
||||
pix_wpl_ = pixGetWpl(pix_);
|
||||
scale_ = 1;
|
||||
estimated_res_ = yres_ = pixGetYRes(src);
|
||||
estimated_res_ = yres_ = pixGetYRes(pix_);
|
||||
Init();
|
||||
}
|
||||
|
||||
|
@ -24,12 +24,18 @@
|
||||
|
||||
#include "imagedata.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include "allheaders.h"
|
||||
#include "boxread.h"
|
||||
#include "callcpp.h"
|
||||
#include "helpers.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
// Number of documents to read ahead while training. Doesn't need to be very
|
||||
// large.
|
||||
const int kMaxReadAhead = 8;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
WordFeature::WordFeature() : x_(0), y_(0), dir_(0) {
|
||||
@ -182,6 +188,19 @@ bool ImageData::DeSerialize(bool swap, TFile* fp) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// As DeSerialize, but only seeks past the data - hence a static method.
|
||||
bool ImageData::SkipDeSerialize(bool swap, TFile* fp) {
|
||||
if (!STRING::SkipDeSerialize(swap, fp)) return false;
|
||||
inT32 page_number;
|
||||
if (fp->FRead(&page_number, sizeof(page_number), 1) != 1) return false;
|
||||
if (!GenericVector<char>::SkipDeSerialize(swap, fp)) return false;
|
||||
if (!STRING::SkipDeSerialize(swap, fp)) return false;
|
||||
if (!GenericVector<TBOX>::SkipDeSerialize(swap, fp)) return false;
|
||||
if (!GenericVector<STRING>::SkipDeSerializeClasses(swap, fp)) return false;
|
||||
inT8 vertical = 0;
|
||||
return fp->FRead(&vertical, sizeof(vertical), 1) == 1;
|
||||
}
|
||||
|
||||
// Saves the given Pix as a PNG-encoded string and destroys it.
|
||||
void ImageData::SetPix(Pix* pix) {
|
||||
SetPixInternal(pix, &image_data_);
|
||||
@ -195,11 +214,12 @@ Pix* ImageData::GetPix() const {
|
||||
// Gets anything and everything with a non-NULL pointer, prescaled to a
|
||||
// given target_height (if 0, then the original image height), and aligned.
|
||||
// Also returns (if not NULL) the width and height of the scaled image.
|
||||
// The return value is the scale factor that was applied to the image to
|
||||
// achieve the target_height.
|
||||
float ImageData::PreScale(int target_height, Pix** pix,
|
||||
int* scaled_width, int* scaled_height,
|
||||
GenericVector<TBOX>* boxes) const {
|
||||
// The return value is the scaled Pix, which must be pixDestroyed after use,
|
||||
// and scale_factor (if not NULL) is set to the scale factor that was applied
|
||||
// to the image to achieve the target_height.
|
||||
Pix* ImageData::PreScale(int target_height, float* scale_factor,
|
||||
int* scaled_width, int* scaled_height,
|
||||
GenericVector<TBOX>* boxes) const {
|
||||
int input_width = 0;
|
||||
int input_height = 0;
|
||||
Pix* src_pix = GetPix();
|
||||
@ -213,19 +233,14 @@ float ImageData::PreScale(int target_height, Pix** pix,
|
||||
*scaled_width = IntCastRounded(im_factor * input_width);
|
||||
if (scaled_height != NULL)
|
||||
*scaled_height = target_height;
|
||||
if (pix != NULL) {
|
||||
// Get the scaled image.
|
||||
pixDestroy(pix);
|
||||
*pix = pixScale(src_pix, im_factor, im_factor);
|
||||
if (*pix == NULL) {
|
||||
tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n",
|
||||
input_width, input_height, im_factor);
|
||||
}
|
||||
if (scaled_width != NULL)
|
||||
*scaled_width = pixGetWidth(*pix);
|
||||
if (scaled_height != NULL)
|
||||
*scaled_height = pixGetHeight(*pix);
|
||||
// Get the scaled image.
|
||||
Pix* pix = pixScale(src_pix, im_factor, im_factor);
|
||||
if (pix == NULL) {
|
||||
tprintf("Scaling pix of size %d, %d by factor %g made null pix!!\n",
|
||||
input_width, input_height, im_factor);
|
||||
}
|
||||
if (scaled_width != NULL) *scaled_width = pixGetWidth(pix);
|
||||
if (scaled_height != NULL) *scaled_height = pixGetHeight(pix);
|
||||
pixDestroy(&src_pix);
|
||||
if (boxes != NULL) {
|
||||
// Get the boxes.
|
||||
@ -241,7 +256,8 @@ float ImageData::PreScale(int target_height, Pix** pix,
|
||||
boxes->push_back(box);
|
||||
}
|
||||
}
|
||||
return im_factor;
|
||||
if (scale_factor != NULL) *scale_factor = im_factor;
|
||||
return pix;
|
||||
}
|
||||
|
||||
int ImageData::MemoryUsed() const {
|
||||
@ -266,19 +282,20 @@ void ImageData::Display() const {
|
||||
// Draw the boxes.
|
||||
win->Pen(ScrollView::RED);
|
||||
win->Brush(ScrollView::NONE);
|
||||
win->TextAttributes("Arial", kTextSize, false, false, false);
|
||||
for (int b = 0; b < boxes_.size(); ++b) {
|
||||
boxes_[b].plot(win);
|
||||
win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string());
|
||||
TBOX scaled(boxes_[b]);
|
||||
scaled.scale(256.0 / height);
|
||||
scaled.plot(win);
|
||||
int text_size = kTextSize;
|
||||
if (!boxes_.empty() && boxes_[0].height() * 2 < text_size)
|
||||
text_size = boxes_[0].height() * 2;
|
||||
win->TextAttributes("Arial", text_size, false, false, false);
|
||||
if (!boxes_.empty()) {
|
||||
for (int b = 0; b < boxes_.size(); ++b) {
|
||||
boxes_[b].plot(win);
|
||||
win->Text(boxes_[b].left(), height + kTextSize, box_texts_[b].string());
|
||||
}
|
||||
} else {
|
||||
// The full transcription.
|
||||
win->Pen(ScrollView::CYAN);
|
||||
win->Text(0, height + kTextSize * 2, transcription_.string());
|
||||
}
|
||||
// The full transcription.
|
||||
win->Pen(ScrollView::CYAN);
|
||||
win->Text(0, height + kTextSize * 2, transcription_.string());
|
||||
// Add the features.
|
||||
win->Pen(ScrollView::GREEN);
|
||||
win->Update();
|
||||
window_wait(win);
|
||||
#endif
|
||||
@ -340,27 +357,51 @@ bool ImageData::AddBoxes(const char* box_text) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DocumentData::DocumentData(const STRING& name)
|
||||
: document_name_(name), pages_offset_(0), total_pages_(0),
|
||||
memory_used_(0), max_memory_(0), reader_(NULL) {}
|
||||
// Thread function to call ReCachePages.
|
||||
void* ReCachePagesFunc(void* data) {
|
||||
DocumentData* document_data = reinterpret_cast<DocumentData*>(data);
|
||||
document_data->ReCachePages();
|
||||
return NULL;
|
||||
}
|
||||
|
||||
DocumentData::~DocumentData() {}
|
||||
DocumentData::DocumentData(const STRING& name)
|
||||
: document_name_(name),
|
||||
pages_offset_(-1),
|
||||
total_pages_(-1),
|
||||
memory_used_(0),
|
||||
max_memory_(0),
|
||||
reader_(NULL) {}
|
||||
|
||||
DocumentData::~DocumentData() {
|
||||
SVAutoLock lock_p(&pages_mutex_);
|
||||
SVAutoLock lock_g(&general_mutex_);
|
||||
}
|
||||
|
||||
// Reads all the pages in the given lstmf filename to the cache. The reader
|
||||
// is used to read the file.
|
||||
bool DocumentData::LoadDocument(const char* filename, const char* lang,
|
||||
int start_page, inT64 max_memory,
|
||||
FileReader reader) {
|
||||
SetDocument(filename, lang, max_memory, reader);
|
||||
pages_offset_ = start_page;
|
||||
return ReCachePages();
|
||||
}
|
||||
|
||||
// Sets up the document, without actually loading it.
|
||||
void DocumentData::SetDocument(const char* filename, const char* lang,
|
||||
inT64 max_memory, FileReader reader) {
|
||||
SVAutoLock lock_p(&pages_mutex_);
|
||||
SVAutoLock lock(&general_mutex_);
|
||||
document_name_ = filename;
|
||||
lang_ = lang;
|
||||
pages_offset_ = start_page;
|
||||
pages_offset_ = -1;
|
||||
max_memory_ = max_memory;
|
||||
reader_ = reader;
|
||||
return ReCachePages();
|
||||
}
|
||||
|
||||
// Writes all the pages to the given filename. Returns false on error.
|
||||
bool DocumentData::SaveDocument(const char* filename, FileWriter writer) {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
TFile fp;
|
||||
fp.OpenWrite(NULL);
|
||||
if (!pages_.Serialize(&fp) || !fp.CloseWrite(filename, writer)) {
|
||||
@ -370,112 +411,166 @@ bool DocumentData::SaveDocument(const char* filename, FileWriter writer) {
|
||||
return true;
|
||||
}
|
||||
bool DocumentData::SaveToBuffer(GenericVector<char>* buffer) {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
TFile fp;
|
||||
fp.OpenWrite(buffer);
|
||||
return pages_.Serialize(&fp);
|
||||
}
|
||||
|
||||
// Returns a pointer to the page with the given index, modulo the total
|
||||
// number of pages, recaching if needed.
|
||||
const ImageData* DocumentData::GetPage(int index) {
|
||||
index = Modulo(index, total_pages_);
|
||||
if (index < pages_offset_ || index >= pages_offset_ + pages_.size()) {
|
||||
pages_offset_ = index;
|
||||
if (!ReCachePages()) return NULL;
|
||||
}
|
||||
return pages_[index - pages_offset_];
|
||||
// Adds the given page data to this document, counting up memory.
|
||||
void DocumentData::AddPageToDocument(ImageData* page) {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
pages_.push_back(page);
|
||||
set_memory_used(memory_used() + page->MemoryUsed());
|
||||
}
|
||||
|
||||
// Loads as many pages can fit in max_memory_ starting at index pages_offset_.
|
||||
// If the given index is not currently loaded, loads it using a separate
|
||||
// thread.
|
||||
void DocumentData::LoadPageInBackground(int index) {
|
||||
ImageData* page = NULL;
|
||||
if (IsPageAvailable(index, &page)) return;
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
if (pages_offset_ == index) return;
|
||||
pages_offset_ = index;
|
||||
pages_.clear();
|
||||
SVSync::StartThread(ReCachePagesFunc, this);
|
||||
}
|
||||
|
||||
// Returns a pointer to the page with the given index, modulo the total
|
||||
// number of pages. Blocks until the background load is completed.
|
||||
const ImageData* DocumentData::GetPage(int index) {
|
||||
ImageData* page = NULL;
|
||||
while (!IsPageAvailable(index, &page)) {
|
||||
// If there is no background load scheduled, schedule one now.
|
||||
pages_mutex_.Lock();
|
||||
bool needs_loading = pages_offset_ != index;
|
||||
pages_mutex_.Unlock();
|
||||
if (needs_loading) LoadPageInBackground(index);
|
||||
// We can't directly load the page, or the background load will delete it
|
||||
// while the caller is using it, so give it a chance to work.
|
||||
sleep(1);
|
||||
}
|
||||
return page;
|
||||
}
|
||||
|
||||
// Returns true if the requested page is available, and provides a pointer,
|
||||
// which may be NULL if the document is empty. May block, even though it
|
||||
// doesn't guarantee to return true.
|
||||
bool DocumentData::IsPageAvailable(int index, ImageData** page) {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
int num_pages = NumPages();
|
||||
if (num_pages == 0 || index < 0) {
|
||||
*page = NULL; // Empty Document.
|
||||
return true;
|
||||
}
|
||||
if (num_pages > 0) {
|
||||
index = Modulo(index, num_pages);
|
||||
if (pages_offset_ <= index && index < pages_offset_ + pages_.size()) {
|
||||
*page = pages_[index - pages_offset_]; // Page is available already.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Removes all pages from memory and frees the memory, but does not forget
|
||||
// the document metadata.
|
||||
inT64 DocumentData::UnCache() {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
inT64 memory_saved = memory_used();
|
||||
pages_.clear();
|
||||
pages_offset_ = -1;
|
||||
set_total_pages(-1);
|
||||
set_memory_used(0);
|
||||
tprintf("Unloaded document %s, saving %d memory\n", document_name_.string(),
|
||||
memory_saved);
|
||||
return memory_saved;
|
||||
}
|
||||
|
||||
// Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
|
||||
// starting at index pages_offset_.
|
||||
bool DocumentData::ReCachePages() {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
// Read the file.
|
||||
set_total_pages(0);
|
||||
set_memory_used(0);
|
||||
int loaded_pages = 0;
|
||||
pages_.truncate(0);
|
||||
TFile fp;
|
||||
if (!fp.Open(document_name_, reader_)) return false;
|
||||
memory_used_ = 0;
|
||||
if (!pages_.DeSerialize(false, &fp)) {
|
||||
tprintf("Deserialize failed: %s\n", document_name_.string());
|
||||
pages_.truncate(0);
|
||||
if (!fp.Open(document_name_, reader_) ||
|
||||
!PointerVector<ImageData>::DeSerializeSize(false, &fp, &loaded_pages) ||
|
||||
loaded_pages <= 0) {
|
||||
tprintf("Deserialize header failed: %s\n", document_name_.string());
|
||||
return false;
|
||||
}
|
||||
total_pages_ = pages_.size();
|
||||
pages_offset_ %= total_pages_;
|
||||
// Delete pages before the first one we want, and relocate the rest.
|
||||
pages_offset_ %= loaded_pages;
|
||||
// Skip pages before the first one we want, and load the rest until max
|
||||
// memory and skip the rest after that.
|
||||
int page;
|
||||
for (page = 0; page < pages_.size(); ++page) {
|
||||
if (page < pages_offset_) {
|
||||
delete pages_[page];
|
||||
pages_[page] = NULL;
|
||||
for (page = 0; page < loaded_pages; ++page) {
|
||||
if (page < pages_offset_ ||
|
||||
(max_memory_ > 0 && memory_used() > max_memory_)) {
|
||||
if (!PointerVector<ImageData>::DeSerializeSkip(false, &fp)) break;
|
||||
} else {
|
||||
ImageData* image_data = pages_[page];
|
||||
if (max_memory_ > 0 && page > pages_offset_ &&
|
||||
memory_used_ + image_data->MemoryUsed() > max_memory_)
|
||||
break; // Don't go over memory quota unless the first image.
|
||||
if (!pages_.DeSerializeElement(false, &fp)) break;
|
||||
ImageData* image_data = pages_.back();
|
||||
if (image_data->imagefilename().length() == 0) {
|
||||
image_data->set_imagefilename(document_name_);
|
||||
image_data->set_page_number(page);
|
||||
}
|
||||
image_data->set_language(lang_);
|
||||
memory_used_ += image_data->MemoryUsed();
|
||||
if (pages_offset_ != 0) {
|
||||
pages_[page - pages_offset_] = image_data;
|
||||
pages_[page] = NULL;
|
||||
}
|
||||
set_memory_used(memory_used() + image_data->MemoryUsed());
|
||||
}
|
||||
}
|
||||
pages_.truncate(page - pages_offset_);
|
||||
tprintf("Loaded %d/%d pages (%d-%d) of document %s\n",
|
||||
pages_.size(), total_pages_, pages_offset_,
|
||||
pages_offset_ + pages_.size(), document_name_.string());
|
||||
if (page < loaded_pages) {
|
||||
tprintf("Deserialize failed: %s read %d/%d pages\n",
|
||||
document_name_.string(), page, loaded_pages);
|
||||
pages_.truncate(0);
|
||||
} else {
|
||||
tprintf("Loaded %d/%d pages (%d-%d) of document %s\n", pages_.size(),
|
||||
loaded_pages, pages_offset_, pages_offset_ + pages_.size(),
|
||||
document_name_.string());
|
||||
}
|
||||
set_total_pages(loaded_pages);
|
||||
return !pages_.empty();
|
||||
}
|
||||
|
||||
// Adds the given page data to this document, counting up memory.
|
||||
void DocumentData::AddPageToDocument(ImageData* page) {
|
||||
pages_.push_back(page);
|
||||
memory_used_ += page->MemoryUsed();
|
||||
}
|
||||
|
||||
// A collection of DocumentData that knows roughly how much memory it is using.
|
||||
DocumentCache::DocumentCache(inT64 max_memory)
|
||||
: total_pages_(0), memory_used_(0), max_memory_(max_memory) {}
|
||||
: num_pages_per_doc_(0), max_memory_(max_memory) {}
|
||||
DocumentCache::~DocumentCache() {}
|
||||
|
||||
// Adds all the documents in the list of filenames, counting memory.
|
||||
// The reader is used to read the files.
|
||||
bool DocumentCache::LoadDocuments(const GenericVector<STRING>& filenames,
|
||||
const char* lang, FileReader reader) {
|
||||
inT64 fair_share_memory = max_memory_ / filenames.size();
|
||||
const char* lang,
|
||||
CachingStrategy cache_strategy,
|
||||
FileReader reader) {
|
||||
cache_strategy_ = cache_strategy;
|
||||
inT64 fair_share_memory = 0;
|
||||
// In the round-robin case, each DocumentData handles restricting its content
|
||||
// to its fair share of memory. In the sequential case, DocumentCache
|
||||
// determines which DocumentDatas are held entirely in memory.
|
||||
if (cache_strategy_ == CS_ROUND_ROBIN)
|
||||
fair_share_memory = max_memory_ / filenames.size();
|
||||
for (int arg = 0; arg < filenames.size(); ++arg) {
|
||||
STRING filename = filenames[arg];
|
||||
DocumentData* document = new DocumentData(filename);
|
||||
if (document->LoadDocument(filename.string(), lang, 0,
|
||||
fair_share_memory, reader)) {
|
||||
AddToCache(document);
|
||||
} else {
|
||||
tprintf("Failed to load image %s!\n", filename.string());
|
||||
delete document;
|
||||
}
|
||||
document->SetDocument(filename.string(), lang, fair_share_memory, reader);
|
||||
AddToCache(document);
|
||||
}
|
||||
tprintf("Loaded %d pages, total %gMB\n",
|
||||
total_pages_, memory_used_ / 1048576.0);
|
||||
return total_pages_ > 0;
|
||||
if (!documents_.empty()) {
|
||||
// Try to get the first page now to verify the list of filenames.
|
||||
if (GetPageBySerial(0) != NULL) return true;
|
||||
tprintf("Load of page 0 failed!\n");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Adds document to the cache, throwing out other documents if needed.
|
||||
// Adds document to the cache.
|
||||
bool DocumentCache::AddToCache(DocumentData* data) {
|
||||
inT64 new_memory = data->memory_used();
|
||||
memory_used_ += new_memory;
|
||||
documents_.push_back(data);
|
||||
total_pages_ += data->NumPages();
|
||||
// Delete the first item in the array, and other pages of the same name
|
||||
// while memory is full.
|
||||
while (memory_used_ >= max_memory_ && max_memory_ > 0) {
|
||||
tprintf("Memory used=%lld vs max=%lld, discarding doc of size %lld\n",
|
||||
memory_used_ , max_memory_, documents_[0]->memory_used());
|
||||
memory_used_ -= documents_[0]->memory_used();
|
||||
total_pages_ -= documents_[0]->NumPages();
|
||||
documents_.remove(0);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -488,11 +583,104 @@ DocumentData* DocumentCache::FindDocument(const STRING& document_name) const {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache
|
||||
// strategy, could take a long time.
|
||||
int DocumentCache::TotalPages() {
|
||||
if (cache_strategy_ == CS_SEQUENTIAL) {
|
||||
// In sequential mode, we assume each doc has the same number of pages
|
||||
// whether it is true or not.
|
||||
if (num_pages_per_doc_ == 0) GetPageSequential(0);
|
||||
return num_pages_per_doc_ * documents_.size();
|
||||
}
|
||||
int total_pages = 0;
|
||||
int num_docs = documents_.size();
|
||||
for (int d = 0; d < num_docs; ++d) {
|
||||
// We have to load a page to make NumPages() valid.
|
||||
documents_[d]->GetPage(0);
|
||||
total_pages += documents_[d]->NumPages();
|
||||
}
|
||||
return total_pages;
|
||||
}
|
||||
|
||||
// Returns a page by serial number, selecting them in a round-robin fashion
|
||||
// from all the documents.
|
||||
const ImageData* DocumentCache::GetPageBySerial(int serial) {
|
||||
int document_index = serial % documents_.size();
|
||||
return documents_[document_index]->GetPage(serial / documents_.size());
|
||||
// from all the documents. Highly disk-intensive, but doesn't need samples
|
||||
// to be shuffled between files to begin with.
|
||||
const ImageData* DocumentCache::GetPageRoundRobin(int serial) {
|
||||
int num_docs = documents_.size();
|
||||
int doc_index = serial % num_docs;
|
||||
const ImageData* doc = documents_[doc_index]->GetPage(serial / num_docs);
|
||||
for (int offset = 1; offset <= kMaxReadAhead && offset < num_docs; ++offset) {
|
||||
doc_index = (serial + offset) % num_docs;
|
||||
int page = (serial + offset) / num_docs;
|
||||
documents_[doc_index]->LoadPageInBackground(page);
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
// Returns a page by serial number, selecting them in sequence from each file.
|
||||
// Requires the samples to be shuffled between the files to give a random or
|
||||
// uniform distribution of data. Less disk-intensive than GetPageRoundRobin.
|
||||
const ImageData* DocumentCache::GetPageSequential(int serial) {
|
||||
int num_docs = documents_.size();
|
||||
ASSERT_HOST(num_docs > 0);
|
||||
if (num_pages_per_doc_ == 0) {
|
||||
// Use the pages in the first doc as the number of pages in each doc.
|
||||
documents_[0]->GetPage(0);
|
||||
num_pages_per_doc_ = documents_[0]->NumPages();
|
||||
if (num_pages_per_doc_ == 0) {
|
||||
tprintf("First document cannot be empty!!\n");
|
||||
ASSERT_HOST(num_pages_per_doc_ > 0);
|
||||
}
|
||||
// Get rid of zero now if we don't need it.
|
||||
if (serial / num_pages_per_doc_ % num_docs > 0) documents_[0]->UnCache();
|
||||
}
|
||||
int doc_index = serial / num_pages_per_doc_ % num_docs;
|
||||
const ImageData* doc =
|
||||
documents_[doc_index]->GetPage(serial % num_pages_per_doc_);
|
||||
// Count up total memory. Background loading makes it more complicated to
|
||||
// keep a running count.
|
||||
inT64 total_memory = 0;
|
||||
for (int d = 0; d < num_docs; ++d) {
|
||||
total_memory += documents_[d]->memory_used();
|
||||
}
|
||||
if (total_memory >= max_memory_) {
|
||||
// Find something to un-cache.
|
||||
// If there are more than 3 in front, then serial is from the back reader
|
||||
// of a pair of readers. If we un-cache from in-front-2 to 2-ahead, then
|
||||
// we create a hole between them and then un-caching the backmost occupied
|
||||
// will work for both.
|
||||
int num_in_front = CountNeighbourDocs(doc_index, 1);
|
||||
for (int offset = num_in_front - 2;
|
||||
offset > 1 && total_memory >= max_memory_; --offset) {
|
||||
int next_index = (doc_index + offset) % num_docs;
|
||||
total_memory -= documents_[next_index]->UnCache();
|
||||
}
|
||||
// If that didn't work, the best solution is to un-cache from the back. If
|
||||
// we take away the document that a 2nd reader is using, it will put it
|
||||
// back and make a hole between.
|
||||
int num_behind = CountNeighbourDocs(doc_index, -1);
|
||||
for (int offset = num_behind; offset < 0 && total_memory >= max_memory_;
|
||||
++offset) {
|
||||
int next_index = (doc_index + offset + num_docs) % num_docs;
|
||||
total_memory -= documents_[next_index]->UnCache();
|
||||
}
|
||||
}
|
||||
int next_index = (doc_index + 1) % num_docs;
|
||||
if (!documents_[next_index]->IsCached() && total_memory < max_memory_) {
|
||||
documents_[next_index]->LoadPageInBackground(0);
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
// Helper counts the number of adjacent cached neighbours of index looking in
|
||||
// direction dir, ie index+dir, index+2*dir etc.
|
||||
int DocumentCache::CountNeighbourDocs(int index, int dir) {
|
||||
int num_docs = documents_.size();
|
||||
for (int offset = dir; abs(offset) < num_docs; offset += dir) {
|
||||
int offset_index = (index + offset + num_docs) % num_docs;
|
||||
if (!documents_[offset_index]->IsCached()) return offset - dir;
|
||||
}
|
||||
return num_docs;
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "normalis.h"
|
||||
#include "rect.h"
|
||||
#include "strngs.h"
|
||||
#include "svutil.h"
|
||||
|
||||
struct Pix;
|
||||
|
||||
@ -34,8 +35,22 @@ namespace tesseract {
|
||||
const int kFeaturePadding = 2;
|
||||
// Number of pixels to pad around text boxes.
|
||||
const int kImagePadding = 4;
|
||||
// Number of training images to combine into a mini-batch for training.
|
||||
const int kNumPagesPerMiniBatch = 100;
|
||||
|
||||
// Enum to determine the caching and data sequencing strategy.
|
||||
enum CachingStrategy {
|
||||
// Reads all of one file before moving on to the next. Requires samples to be
|
||||
// shuffled across files. Uses the count of samples in the first file as
|
||||
// the count in all the files to achieve high-speed random access. As a
|
||||
// consequence, if subsequent files are smaller, they get entries used more
|
||||
// than once, and if subsequent files are larger, some entries are not used.
|
||||
// Best for larger data sets that don't fit in memory.
|
||||
CS_SEQUENTIAL,
|
||||
// Reads one sample from each file in rotation. Does not require shuffled
|
||||
// samples, but is extremely disk-intensive. Samples in smaller files also
|
||||
// get used more often than samples in larger files.
|
||||
// Best for smaller data sets that mostly fit in memory.
|
||||
CS_ROUND_ROBIN,
|
||||
};
|
||||
|
||||
class WordFeature {
|
||||
public:
|
||||
@ -103,6 +118,8 @@ class ImageData {
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, TFile* fp);
|
||||
// As DeSerialize, but only seeks past the data - hence a static method.
|
||||
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
|
||||
|
||||
// Other accessors.
|
||||
const STRING& imagefilename() const {
|
||||
@ -145,11 +162,11 @@ class ImageData {
|
||||
// Gets anything and everything with a non-NULL pointer, prescaled to a
|
||||
// given target_height (if 0, then the original image height), and aligned.
|
||||
// Also returns (if not NULL) the width and height of the scaled image.
|
||||
// The return value is the scale factor that was applied to the image to
|
||||
// achieve the target_height.
|
||||
float PreScale(int target_height, Pix** pix,
|
||||
int* scaled_width, int* scaled_height,
|
||||
GenericVector<TBOX>* boxes) const;
|
||||
// The return value is the scaled Pix, which must be pixDestroyed after use,
|
||||
// and scale_factor (if not NULL) is set to the scale factor that was applied
|
||||
// to the image to achieve the target_height.
|
||||
Pix* PreScale(int target_height, float* scale_factor, int* scaled_width,
|
||||
int* scaled_height, GenericVector<TBOX>* boxes) const;
|
||||
|
||||
int MemoryUsed() const;
|
||||
|
||||
@ -184,6 +201,8 @@ class ImageData {
|
||||
|
||||
// A collection of ImageData that knows roughly how much memory it is using.
|
||||
class DocumentData {
|
||||
friend void* ReCachePagesFunc(void* data);
|
||||
|
||||
public:
|
||||
explicit DocumentData(const STRING& name);
|
||||
~DocumentData();
|
||||
@ -192,6 +211,9 @@ class DocumentData {
|
||||
// is used to read the file.
|
||||
bool LoadDocument(const char* filename, const char* lang, int start_page,
|
||||
inT64 max_memory, FileReader reader);
|
||||
// Sets up the document, without actually loading it.
|
||||
void SetDocument(const char* filename, const char* lang, inT64 max_memory,
|
||||
FileReader reader);
|
||||
// Writes all the pages to the given filename. Returns false on error.
|
||||
bool SaveDocument(const char* filename, FileWriter writer);
|
||||
bool SaveToBuffer(GenericVector<char>* buffer);
|
||||
@ -200,26 +222,62 @@ class DocumentData {
|
||||
void AddPageToDocument(ImageData* page);
|
||||
|
||||
const STRING& document_name() const {
|
||||
SVAutoLock lock(&general_mutex_);
|
||||
return document_name_;
|
||||
}
|
||||
int NumPages() const {
|
||||
SVAutoLock lock(&general_mutex_);
|
||||
return total_pages_;
|
||||
}
|
||||
inT64 memory_used() const {
|
||||
SVAutoLock lock(&general_mutex_);
|
||||
return memory_used_;
|
||||
}
|
||||
// If the given index is not currently loaded, loads it using a separate
|
||||
// thread. Note: there are 4 cases:
|
||||
// Document uncached: IsCached() returns false, total_pages_ < 0.
|
||||
// Required page is available: IsPageAvailable returns true. In this case,
|
||||
// total_pages_ > 0 and
|
||||
// pages_offset_ <= index%total_pages_ <= pages_offset_+pages_.size()
|
||||
// Pages are loaded, but the required one is not.
|
||||
// The requested page is being loaded by LoadPageInBackground. In this case,
|
||||
// index == pages_offset_. Once the loading starts, the pages lock is held
|
||||
// until it completes, at which point IsPageAvailable will unblock and return
|
||||
// true.
|
||||
void LoadPageInBackground(int index);
|
||||
// Returns a pointer to the page with the given index, modulo the total
|
||||
// number of pages, recaching if needed.
|
||||
// number of pages. Blocks until the background load is completed.
|
||||
const ImageData* GetPage(int index);
|
||||
// Returns true if the requested page is available, and provides a pointer,
|
||||
// which may be NULL if the document is empty. May block, even though it
|
||||
// doesn't guarantee to return true.
|
||||
bool IsPageAvailable(int index, ImageData** page);
|
||||
// Takes ownership of the given page index. The page is made NULL in *this.
|
||||
ImageData* TakePage(int index) {
|
||||
SVAutoLock lock(&pages_mutex_);
|
||||
ImageData* page = pages_[index];
|
||||
pages_[index] = NULL;
|
||||
return page;
|
||||
}
|
||||
// Returns true if the document is currently loaded or in the process of
|
||||
// loading.
|
||||
bool IsCached() const { return NumPages() >= 0; }
|
||||
// Removes all pages from memory and frees the memory, but does not forget
|
||||
// the document metadata. Returns the memory saved.
|
||||
inT64 UnCache();
|
||||
|
||||
private:
|
||||
// Loads as many pages can fit in max_memory_ starting at index pages_offset_.
|
||||
// Sets the value of total_pages_ behind a mutex.
|
||||
void set_total_pages(int total) {
|
||||
SVAutoLock lock(&general_mutex_);
|
||||
total_pages_ = total;
|
||||
}
|
||||
void set_memory_used(inT64 memory_used) {
|
||||
SVAutoLock lock(&general_mutex_);
|
||||
memory_used_ = memory_used;
|
||||
}
|
||||
// Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
|
||||
// starting at index pages_offset_.
|
||||
bool ReCachePages();
|
||||
|
||||
private:
|
||||
@ -239,43 +297,77 @@ class DocumentData {
|
||||
inT64 max_memory_;
|
||||
// Saved reader from LoadDocument to allow re-caching.
|
||||
FileReader reader_;
|
||||
// Mutex that protects pages_ and pages_offset_ against multiple parallel
|
||||
// loads, and provides a wait for page.
|
||||
SVMutex pages_mutex_;
|
||||
// Mutex that protects other data members that callers want to access without
|
||||
// waiting for a load operation.
|
||||
mutable SVMutex general_mutex_;
|
||||
};
|
||||
|
||||
// A collection of DocumentData that knows roughly how much memory it is using.
|
||||
// Note that while it supports background read-ahead, it assumes that a single
|
||||
// thread is accessing documents, ie it is not safe for multiple threads to
|
||||
// access different documents in parallel, as one may de-cache the other's
|
||||
// content.
|
||||
class DocumentCache {
|
||||
public:
|
||||
explicit DocumentCache(inT64 max_memory);
|
||||
~DocumentCache();
|
||||
|
||||
// Deletes all existing documents from the cache.
|
||||
void Clear() {
|
||||
documents_.clear();
|
||||
num_pages_per_doc_ = 0;
|
||||
}
|
||||
// Adds all the documents in the list of filenames, counting memory.
|
||||
// The reader is used to read the files.
|
||||
bool LoadDocuments(const GenericVector<STRING>& filenames, const char* lang,
|
||||
FileReader reader);
|
||||
CachingStrategy cache_strategy, FileReader reader);
|
||||
|
||||
// Adds document to the cache, throwing out other documents if needed.
|
||||
// Adds document to the cache.
|
||||
bool AddToCache(DocumentData* data);
|
||||
|
||||
// Finds and returns a document by name.
|
||||
DocumentData* FindDocument(const STRING& document_name) const;
|
||||
|
||||
// Returns a page by serial number, selecting them in a round-robin fashion
|
||||
// from all the documents.
|
||||
const ImageData* GetPageBySerial(int serial);
|
||||
// Returns a page by serial number using the current cache_strategy_ to
|
||||
// determine the mapping from serial number to page.
|
||||
const ImageData* GetPageBySerial(int serial) {
|
||||
if (cache_strategy_ == CS_SEQUENTIAL)
|
||||
return GetPageSequential(serial);
|
||||
else
|
||||
return GetPageRoundRobin(serial);
|
||||
}
|
||||
|
||||
const PointerVector<DocumentData>& documents() const {
|
||||
return documents_;
|
||||
}
|
||||
int total_pages() const {
|
||||
return total_pages_;
|
||||
}
|
||||
// Returns the total number of pages in an epoch. For CS_ROUND_ROBIN cache
|
||||
// strategy, could take a long time.
|
||||
int TotalPages();
|
||||
|
||||
private:
|
||||
// Returns a page by serial number, selecting them in a round-robin fashion
|
||||
// from all the documents. Highly disk-intensive, but doesn't need samples
|
||||
// to be shuffled between files to begin with.
|
||||
const ImageData* GetPageRoundRobin(int serial);
|
||||
// Returns a page by serial number, selecting them in sequence from each file.
|
||||
// Requires the samples to be shuffled between the files to give a random or
|
||||
// uniform distribution of data. Less disk-intensive than GetPageRoundRobin.
|
||||
const ImageData* GetPageSequential(int serial);
|
||||
|
||||
// Helper counts the number of adjacent cached neighbour documents_ of index
|
||||
// looking in direction dir, ie index+dir, index+2*dir etc.
|
||||
int CountNeighbourDocs(int index, int dir);
|
||||
|
||||
// A group of pages that corresponds in some loose way to a document.
|
||||
PointerVector<DocumentData> documents_;
|
||||
// Total of all pages.
|
||||
int total_pages_;
|
||||
// Total of all memory used by the cache.
|
||||
inT64 memory_used_;
|
||||
// Strategy to use for caching and serializing data samples.
|
||||
CachingStrategy cache_strategy_;
|
||||
// Number of pages in the first document, used as a divisor in
|
||||
// GetPageSequential to determine the document index.
|
||||
int num_pages_per_doc_;
|
||||
// Max memory allowed in this cache.
|
||||
inT64 max_memory_;
|
||||
};
|
||||
|
@ -1,8 +1,12 @@
|
||||
/* -*-C-*-
|
||||
******************************************************************************
|
||||
* File: matrix.h (Formerly matrix.h)
|
||||
* Description: Generic 2-d array/matrix and banded triangular matrix class.
|
||||
* Author: Ray Smith
|
||||
* TODO(rays) Separate from ratings matrix, which it also contains:
|
||||
*
|
||||
* File: matrix.h (Formerly matrix.h)
|
||||
* Description: Ratings matrix code. (Used by associator)
|
||||
* Descrition: Ratings matrix class (specialization of banded matrix).
|
||||
* Segmentation search matrix of lists of BLOB_CHOICE.
|
||||
* Author: Mark Seaman, OCR Technology
|
||||
* Created: Wed May 16 13:22:06 1990
|
||||
* Modified: Tue Mar 19 16:00:20 1991 (Mark Seaman) marks@hpgrlt
|
||||
@ -25,9 +29,13 @@
|
||||
#ifndef TESSERACT_CCSTRUCT_MATRIX_H__
|
||||
#define TESSERACT_CCSTRUCT_MATRIX_H__
|
||||
|
||||
#include <math.h>
|
||||
#include "kdpair.h"
|
||||
#include "points.h"
|
||||
#include "serialis.h"
|
||||
#include "unicharset.h"
|
||||
|
||||
class BLOB_CHOICE;
|
||||
class BLOB_CHOICE_LIST;
|
||||
|
||||
#define NOT_CLASSIFIED reinterpret_cast<BLOB_CHOICE_LIST*>(0)
|
||||
@ -44,34 +52,60 @@ class GENERIC_2D_ARRAY {
|
||||
// either pass the memory in, or allocate after by calling Resize().
|
||||
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty, T* array)
|
||||
: empty_(empty), dim1_(dim1), dim2_(dim2), array_(array) {
|
||||
size_allocated_ = dim1 * dim2;
|
||||
}
|
||||
// Original constructor for a full rectangular matrix DOES allocate memory
|
||||
// and initialize it to empty.
|
||||
GENERIC_2D_ARRAY(int dim1, int dim2, const T& empty)
|
||||
: empty_(empty), dim1_(dim1), dim2_(dim2) {
|
||||
array_ = new T[dim1_ * dim2_];
|
||||
for (int x = 0; x < dim1_; x++)
|
||||
for (int y = 0; y < dim2_; y++)
|
||||
this->put(x, y, empty_);
|
||||
int new_size = dim1 * dim2;
|
||||
array_ = new T[new_size];
|
||||
size_allocated_ = new_size;
|
||||
for (int i = 0; i < size_allocated_; ++i)
|
||||
array_[i] = empty_;
|
||||
}
|
||||
// Default constructor for array allocation. Use Resize to set the size.
|
||||
GENERIC_2D_ARRAY()
|
||||
: array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0),
|
||||
size_allocated_(0) {
|
||||
}
|
||||
GENERIC_2D_ARRAY(const GENERIC_2D_ARRAY<T>& src)
|
||||
: array_(NULL), empty_(static_cast<T>(0)), dim1_(0), dim2_(0),
|
||||
size_allocated_(0) {
|
||||
*this = src;
|
||||
}
|
||||
virtual ~GENERIC_2D_ARRAY() { delete[] array_; }
|
||||
|
||||
void operator=(const GENERIC_2D_ARRAY<T>& src) {
|
||||
ResizeNoInit(src.dim1(), src.dim2());
|
||||
memcpy(array_, src.array_, num_elements() * sizeof(array_[0]));
|
||||
}
|
||||
|
||||
// Reallocate the array to the given size. Does not keep old data, but does
|
||||
// not initialize the array either.
|
||||
void ResizeNoInit(int size1, int size2) {
|
||||
int new_size = size1 * size2;
|
||||
if (new_size > size_allocated_) {
|
||||
delete [] array_;
|
||||
array_ = new T[new_size];
|
||||
size_allocated_ = new_size;
|
||||
}
|
||||
dim1_ = size1;
|
||||
dim2_ = size2;
|
||||
}
|
||||
|
||||
// Reallocate the array to the given size. Does not keep old data.
|
||||
void Resize(int size1, int size2, const T& empty) {
|
||||
empty_ = empty;
|
||||
if (size1 != dim1_ || size2 != dim2_) {
|
||||
dim1_ = size1;
|
||||
dim2_ = size2;
|
||||
delete [] array_;
|
||||
array_ = new T[dim1_ * dim2_];
|
||||
}
|
||||
ResizeNoInit(size1, size2);
|
||||
Clear();
|
||||
}
|
||||
|
||||
// Reallocate the array to the given size, keeping old data.
|
||||
void ResizeWithCopy(int size1, int size2) {
|
||||
if (size1 != dim1_ || size2 != dim2_) {
|
||||
T* new_array = new T[size1 * size2];
|
||||
int new_size = size1 * size2;
|
||||
T* new_array = new T[new_size];
|
||||
for (int col = 0; col < size1; ++col) {
|
||||
for (int row = 0; row < size2; ++row) {
|
||||
int old_index = col * dim2() + row;
|
||||
@ -87,6 +121,7 @@ class GENERIC_2D_ARRAY {
|
||||
array_ = new_array;
|
||||
dim1_ = size1;
|
||||
dim2_ = size2;
|
||||
size_allocated_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,9 +141,16 @@ class GENERIC_2D_ARRAY {
|
||||
if (fwrite(array_, sizeof(*array_), size, fp) != size) return false;
|
||||
return true;
|
||||
}
|
||||
bool Serialize(tesseract::TFile* fp) const {
|
||||
if (!SerializeSize(fp)) return false;
|
||||
if (fp->FWrite(&empty_, sizeof(empty_), 1) != 1) return false;
|
||||
int size = num_elements();
|
||||
if (fp->FWrite(array_, sizeof(*array_), size) != size) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// Only works with bitwise-serializeable typ
|
||||
// Only works with bitwise-serializeable types!
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, FILE* fp) {
|
||||
if (!DeSerializeSize(swap, fp)) return false;
|
||||
@ -122,6 +164,18 @@ class GENERIC_2D_ARRAY {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool DeSerialize(bool swap, tesseract::TFile* fp) {
|
||||
if (!DeSerializeSize(swap, fp)) return false;
|
||||
if (fp->FRead(&empty_, sizeof(empty_), 1) != 1) return false;
|
||||
if (swap) ReverseN(&empty_, sizeof(empty_));
|
||||
int size = num_elements();
|
||||
if (fp->FRead(array_, sizeof(*array_), size) != size) return false;
|
||||
if (swap) {
|
||||
for (int i = 0; i < size; ++i)
|
||||
ReverseN(&array_[i], sizeof(array_[i]));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
// Assumes a T::Serialize(FILE*) const function.
|
||||
@ -163,11 +217,17 @@ class GENERIC_2D_ARRAY {
|
||||
}
|
||||
|
||||
// Put a list element into the matrix at a specific location.
|
||||
void put(ICOORD pos, const T& thing) {
|
||||
array_[this->index(pos.x(), pos.y())] = thing;
|
||||
}
|
||||
void put(int column, int row, const T& thing) {
|
||||
array_[this->index(column, row)] = thing;
|
||||
}
|
||||
|
||||
// Get the item at a specified location from the matrix.
|
||||
T get(ICOORD pos) const {
|
||||
return array_[this->index(pos.x(), pos.y())];
|
||||
}
|
||||
T get(int column, int row) const {
|
||||
return array_[this->index(column, row)];
|
||||
}
|
||||
@ -187,6 +247,207 @@ class GENERIC_2D_ARRAY {
|
||||
return &array_[this->index(column, 0)];
|
||||
}
|
||||
|
||||
// Adds addend to *this, element-by-element.
|
||||
void operator+=(const GENERIC_2D_ARRAY<T>& addend) {
|
||||
if (dim2_ == addend.dim2_) {
|
||||
// Faster if equal size in the major dimension.
|
||||
int size = MIN(num_elements(), addend.num_elements());
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] += addend.array_[i];
|
||||
}
|
||||
} else {
|
||||
for (int x = 0; x < dim1_; x++) {
|
||||
for (int y = 0; y < dim2_; y++) {
|
||||
(*this)(x, y) += addend(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Subtracts minuend from *this, element-by-element.
|
||||
void operator-=(const GENERIC_2D_ARRAY<T>& minuend) {
|
||||
if (dim2_ == minuend.dim2_) {
|
||||
// Faster if equal size in the major dimension.
|
||||
int size = MIN(num_elements(), minuend.num_elements());
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] -= minuend.array_[i];
|
||||
}
|
||||
} else {
|
||||
for (int x = 0; x < dim1_; x++) {
|
||||
for (int y = 0; y < dim2_; y++) {
|
||||
(*this)(x, y) -= minuend(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Adds addend to all elements.
|
||||
void operator+=(const T& addend) {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] += addend;
|
||||
}
|
||||
}
|
||||
// Multiplies *this by factor, element-by-element.
|
||||
void operator*=(const T& factor) {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] *= factor;
|
||||
}
|
||||
}
|
||||
// Clips *this to the given range.
|
||||
void Clip(const T& rangemin, const T& rangemax) {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] = ClipToRange(array_[i], rangemin, rangemax);
|
||||
}
|
||||
}
|
||||
// Returns true if all elements of *this are within the given range.
|
||||
// Only uses operator<
|
||||
bool WithinBounds(const T& rangemin, const T& rangemax) const {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const T& value = array_[i];
|
||||
if (value < rangemin || rangemax < value)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Normalize the whole array.
|
||||
double Normalize() {
|
||||
int size = num_elements();
|
||||
if (size <= 0) return 0.0;
|
||||
// Compute the mean.
|
||||
double mean = 0.0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
mean += array_[i];
|
||||
}
|
||||
mean /= size;
|
||||
// Subtract the mean and compute the standard deviation.
|
||||
double sd = 0.0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
double normed = array_[i] - mean;
|
||||
array_[i] = normed;
|
||||
sd += normed * normed;
|
||||
}
|
||||
sd = sqrt(sd / size);
|
||||
if (sd > 0.0) {
|
||||
// Divide by the sd.
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] /= sd;
|
||||
}
|
||||
}
|
||||
return sd;
|
||||
}
|
||||
|
||||
// Returns the maximum value of the array.
|
||||
T Max() const {
|
||||
int size = num_elements();
|
||||
if (size <= 0) return empty_;
|
||||
// Compute the max.
|
||||
T max_value = array_[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
const T& value = array_[i];
|
||||
if (value > max_value) max_value = value;
|
||||
}
|
||||
return max_value;
|
||||
}
|
||||
|
||||
// Returns the maximum absolute value of the array.
|
||||
T MaxAbs() const {
|
||||
int size = num_elements();
|
||||
if (size <= 0) return empty_;
|
||||
// Compute the max.
|
||||
T max_abs = static_cast<T>(0);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
T value = static_cast<T>(fabs(array_[i]));
|
||||
if (value > max_abs) max_abs = value;
|
||||
}
|
||||
return max_abs;
|
||||
}
|
||||
|
||||
// Accumulates the element-wise sums of squares of src into *this.
|
||||
void SumSquares(const GENERIC_2D_ARRAY<T>& src) {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] += src.array_[i] * src.array_[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Scales each element using the ada-grad algorithm, ie array_[i] by
|
||||
// sqrt(num_samples/max(1,sqsum[i])).
|
||||
void AdaGradScaling(const GENERIC_2D_ARRAY<T>& sqsum, int num_samples) {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
array_[i] *= sqrt(num_samples / MAX(1.0, sqsum.array_[i]));
|
||||
}
|
||||
}
|
||||
|
||||
void AssertFinite() const {
|
||||
int size = num_elements();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ASSERT_HOST(isfinite(array_[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// REGARDLESS OF THE CURRENT DIMENSIONS, treats the data as a
|
||||
// num_dims-dimensional array/tensor with dimensions given by dims, (ordered
|
||||
// from most significant to least significant, the same as standard C arrays)
|
||||
// and moves src_dim to dest_dim, with the initial dest_dim and any dimensions
|
||||
// in between shifted towards the hole left by src_dim. Example:
|
||||
// Current data content: array_=[0, 1, 2, ....119]
|
||||
// perhaps *this may be of dim[40, 3], with values [[0, 1, 2][3, 4, 5]...
|
||||
// but the current dimensions are irrelevant.
|
||||
// num_dims = 4, dims=[5, 4, 3, 2]
|
||||
// src_dim=3, dest_dim=1
|
||||
// tensor=[[[[0, 1][2, 3][4, 5]]
|
||||
// [[6, 7][8, 9][10, 11]]
|
||||
// [[12, 13][14, 15][16, 17]]
|
||||
// [[18, 19][20, 21][22, 23]]]
|
||||
// [[[24, 25]...
|
||||
// output dims =[5, 2, 4, 3]
|
||||
// output tensor=[[[[0, 2, 4][6, 8, 10][12, 14, 16][18, 20, 22]]
|
||||
// [[1, 3, 5][7, 9, 11][13, 15, 17][19, 21, 23]]]
|
||||
// [[[24, 26, 28]...
|
||||
// which is stored in the array_ as:
|
||||
// [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 1, 3, 5, 7, 9, 11, 13...]
|
||||
// NOTE: the 2 stored matrix dimensions are simply copied from *this. To
|
||||
// change the dimensions after the transpose, use ResizeNoInit.
|
||||
// Higher dimensions above 2 are strictly the responsibility of the caller.
|
||||
void RotatingTranspose(const int* dims, int num_dims, int src_dim,
|
||||
int dest_dim, GENERIC_2D_ARRAY<T>* result) const {
|
||||
int max_d = MAX(src_dim, dest_dim);
|
||||
int min_d = MIN(src_dim, dest_dim);
|
||||
// In a tensor of shape [d0, d1... min_d, ... max_d, ... dn-2, dn-1], the
|
||||
// ends outside of min_d and max_d are unaffected, with [max_d +1, dn-1]
|
||||
// being contiguous blocks of data that will move together, and
|
||||
// [d0, min_d -1] being replicas of the transpose operation.
|
||||
// num_replicas represents the large dimensions unchanged by the operation.
|
||||
// move_size represents the small dimensions unchanged by the operation.
|
||||
// src_step represents the stride in the src between each adjacent group
|
||||
// in the destination.
|
||||
int num_replicas = 1, move_size = 1, src_step = 1;
|
||||
for (int d = 0; d < min_d; ++d) num_replicas *= dims[d];
|
||||
for (int d = max_d + 1; d < num_dims; ++d) move_size *= dims[d];
|
||||
for (int d = src_dim + 1; d < num_dims; ++d) src_step *= dims[d];
|
||||
if (src_dim > dest_dim) src_step *= dims[src_dim];
|
||||
// wrap_size is the size of a single replica, being the amount that is
|
||||
// handled num_replicas times.
|
||||
int wrap_size = move_size;
|
||||
for (int d = min_d; d <= max_d; ++d) wrap_size *= dims[d];
|
||||
result->ResizeNoInit(dim1_, dim2_);
|
||||
result->empty_ = empty_;
|
||||
const T* src = array_;
|
||||
T* dest = result->array_;
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
for (int start = 0; start < src_step; start += move_size) {
|
||||
for (int pos = start; pos < wrap_size; pos += src_step) {
|
||||
memcpy(dest, src + pos, sizeof(*dest) * move_size);
|
||||
dest += move_size;
|
||||
}
|
||||
}
|
||||
src += wrap_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Delete objects pointed to by array_[i].
|
||||
void delete_matrix_pointers() {
|
||||
int size = num_elements();
|
||||
@ -206,6 +467,13 @@ class GENERIC_2D_ARRAY {
|
||||
if (fwrite(&size, sizeof(size), 1, fp) != 1) return false;
|
||||
return true;
|
||||
}
|
||||
bool SerializeSize(tesseract::TFile* fp) const {
|
||||
inT32 size = dim1_;
|
||||
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
|
||||
size = dim2_;
|
||||
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
|
||||
return true;
|
||||
}
|
||||
// Factored helper to deserialize the size.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerializeSize(bool swap, FILE* fp) {
|
||||
@ -219,11 +487,26 @@ class GENERIC_2D_ARRAY {
|
||||
Resize(size1, size2, empty_);
|
||||
return true;
|
||||
}
|
||||
bool DeSerializeSize(bool swap, tesseract::TFile* fp) {
|
||||
inT32 size1, size2;
|
||||
if (fp->FRead(&size1, sizeof(size1), 1) != 1) return false;
|
||||
if (fp->FRead(&size2, sizeof(size2), 1) != 1) return false;
|
||||
if (swap) {
|
||||
ReverseN(&size1, sizeof(size1));
|
||||
ReverseN(&size2, sizeof(size2));
|
||||
}
|
||||
Resize(size1, size2, empty_);
|
||||
return true;
|
||||
}
|
||||
|
||||
T* array_;
|
||||
T empty_; // The unused cell.
|
||||
int dim1_; // Size of the 1st dimension in indexing functions.
|
||||
int dim2_; // Size of the 2nd dimension in indexing functions.
|
||||
// The total size to which the array can be expanded before a realloc is
|
||||
// needed. If Resize is used, memory is retained so it can be re-expanded
|
||||
// without a further alloc, and this stores the allocated size.
|
||||
int size_allocated_;
|
||||
};
|
||||
|
||||
// A generic class to store a banded triangular matrix with entries of type T.
|
||||
|
@ -304,6 +304,7 @@ bool WERD_RES::SetupForRecognition(const UNICHARSET& unicharset_in,
|
||||
tesseract = tess;
|
||||
POLY_BLOCK* pb = block != NULL ? block->poly_block() : NULL;
|
||||
if ((norm_mode_hint != tesseract::OEM_CUBE_ONLY &&
|
||||
norm_mode_hint != tesseract::OEM_LSTM_ONLY &&
|
||||
word->cblob_list()->empty()) || (pb != NULL && !pb->IsText())) {
|
||||
// Empty words occur when all the blobs have been moved to the rej_blobs
|
||||
// list, which seems to occur frequently in junk.
|
||||
@ -882,17 +883,17 @@ void WERD_RES::FakeClassifyWord(int blob_count, BLOB_CHOICE** choices) {
|
||||
choice_it.add_after_then_move(choices[c]);
|
||||
ratings->put(c, c, choice_list);
|
||||
}
|
||||
FakeWordFromRatings();
|
||||
FakeWordFromRatings(TOP_CHOICE_PERM);
|
||||
reject_map.initialise(blob_count);
|
||||
done = true;
|
||||
}
|
||||
|
||||
// Creates a WERD_CHOICE for the word using the top choices from the leading
|
||||
// diagonal of the ratings matrix.
|
||||
void WERD_RES::FakeWordFromRatings() {
|
||||
void WERD_RES::FakeWordFromRatings(PermuterType permuter) {
|
||||
int num_blobs = ratings->dimension();
|
||||
WERD_CHOICE* word_choice = new WERD_CHOICE(uch_set, num_blobs);
|
||||
word_choice->set_permuter(TOP_CHOICE_PERM);
|
||||
word_choice->set_permuter(permuter);
|
||||
for (int b = 0; b < num_blobs; ++b) {
|
||||
UNICHAR_ID unichar_id = UNICHAR_SPACE;
|
||||
float rating = MAX_INT32;
|
||||
@ -1105,6 +1106,7 @@ void WERD_RES::InitNonPointers() {
|
||||
x_height = 0.0;
|
||||
caps_height = 0.0;
|
||||
baseline_shift = 0.0f;
|
||||
space_certainty = 0.0f;
|
||||
guessed_x_ht = TRUE;
|
||||
guessed_caps_ht = TRUE;
|
||||
combination = FALSE;
|
||||
|
@ -295,6 +295,9 @@ class WERD_RES : public ELIST_LINK {
|
||||
float x_height; // post match estimate
|
||||
float caps_height; // post match estimate
|
||||
float baseline_shift; // post match estimate.
|
||||
// Certainty score for the spaces either side of this word (LSTM mode).
|
||||
// MIN this value with the actual word certainty.
|
||||
float space_certainty;
|
||||
|
||||
/*
|
||||
To deal with fuzzy spaces we need to be able to combine "words" to form
|
||||
@ -590,7 +593,7 @@ class WERD_RES : public ELIST_LINK {
|
||||
|
||||
// Creates a WERD_CHOICE for the word using the top choices from the leading
|
||||
// diagonal of the ratings matrix.
|
||||
void FakeWordFromRatings();
|
||||
void FakeWordFromRatings(PermuterType permuter);
|
||||
|
||||
// Copies the best_choice strings to the correct_text for adaption/training.
|
||||
void BestChoiceToCorrectText();
|
||||
|
@ -257,13 +257,21 @@ enum OcrEngineMode {
|
||||
OEM_TESSERACT_ONLY, // Run Tesseract only - fastest
|
||||
OEM_CUBE_ONLY, // Run Cube only - better accuracy, but slower
|
||||
OEM_TESSERACT_CUBE_COMBINED, // Run both and combine results - best accuracy
|
||||
OEM_DEFAULT // Specify this mode when calling init_*(),
|
||||
OEM_DEFAULT, // Specify this mode when calling init_*(),
|
||||
// to indicate that any of the above modes
|
||||
// should be automatically inferred from the
|
||||
// variables in the language-specific config,
|
||||
// command-line configs, or if not specified
|
||||
// in any of the above should be set to the
|
||||
// default OEM_TESSERACT_ONLY.
|
||||
// OEM_LSTM_ONLY will fall back (with a warning) to OEM_TESSERACT_ONLY where
|
||||
// there is no network model available. This allows use of a mix of languages,
|
||||
// some of which contain a network model, and some of which do not. Since the
|
||||
// tesseract model is required for the LSTM to fall back to for "difficult"
|
||||
// words anyway, this seems like a reasonable approach, but leaves the danger
|
||||
// of not noticing that it is using the wrong engine if the warning is
|
||||
// ignored.
|
||||
OEM_LSTM_ONLY, // Run just the LSTM line recognizer.
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
@ -14,7 +14,7 @@ endif
|
||||
include_HEADERS = \
|
||||
basedir.h errcode.h fileerr.h genericvector.h helpers.h host.h memry.h \
|
||||
ndminx.h params.h ocrclass.h platform.h serialis.h strngs.h \
|
||||
tesscallback.h unichar.h unicharmap.h unicharset.h
|
||||
tesscallback.h unichar.h unicharcompress.h unicharmap.h unicharset.h
|
||||
|
||||
noinst_HEADERS = \
|
||||
ambigs.h bits16.h bitvector.h ccutil.h clst.h doubleptr.h elst2.h \
|
||||
@ -38,7 +38,7 @@ libtesseract_ccutil_la_SOURCES = \
|
||||
mainblk.cpp memry.cpp \
|
||||
serialis.cpp strngs.cpp scanutils.cpp \
|
||||
tessdatamanager.cpp tprintf.cpp \
|
||||
unichar.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \
|
||||
unichar.cpp unicharcompress.cpp unicharmap.cpp unicharset.cpp unicodes.cpp \
|
||||
params.cpp universalambigs.cpp
|
||||
|
||||
if T_WIN
|
||||
|
@ -108,6 +108,8 @@ class GenericHeap {
|
||||
const Pair& PeekTop() const {
|
||||
return heap_[0];
|
||||
}
|
||||
// Get the value of the worst (largest, defined by operator< ) element.
|
||||
const Pair& PeekWorst() const { return heap_[IndexOfWorst()]; }
|
||||
|
||||
// Removes the top element of the heap. If entry is not NULL, the element
|
||||
// is copied into *entry, otherwise it is discarded.
|
||||
@ -136,22 +138,12 @@ class GenericHeap {
|
||||
// not NULL, the element is copied into *entry, otherwise it is discarded.
|
||||
// Time = O(n). Returns false if the heap was already empty.
|
||||
bool PopWorst(Pair* entry) {
|
||||
int heap_size = heap_.size();
|
||||
if (heap_size == 0) return false; // It cannot be empty!
|
||||
|
||||
// Find the maximum element. Its index is guaranteed to be greater than
|
||||
// the index of the parent of the last element, since by the heap invariant
|
||||
// the parent must be less than or equal to the children.
|
||||
int worst_index = heap_size - 1;
|
||||
int end_parent = ParentNode(worst_index);
|
||||
for (int i = worst_index - 1; i > end_parent; --i) {
|
||||
if (heap_[worst_index] < heap_[i])
|
||||
worst_index = i;
|
||||
}
|
||||
int worst_index = IndexOfWorst();
|
||||
if (worst_index < 0) return false; // It cannot be empty!
|
||||
// Extract the worst element from the heap, leaving a hole at worst_index.
|
||||
if (entry != NULL)
|
||||
*entry = heap_[worst_index];
|
||||
--heap_size;
|
||||
int heap_size = heap_.size() - 1;
|
||||
if (heap_size > 0) {
|
||||
// Sift the hole upwards to match the last element of the heap_
|
||||
Pair hole_pair = heap_[heap_size];
|
||||
@ -162,6 +154,22 @@ class GenericHeap {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the index of the worst element. Time = O(n/2).
|
||||
int IndexOfWorst() const {
|
||||
int heap_size = heap_.size();
|
||||
if (heap_size == 0) return -1; // It cannot be empty!
|
||||
|
||||
// Find the maximum element. Its index is guaranteed to be greater than
|
||||
// the index of the parent of the last element, since by the heap invariant
|
||||
// the parent must be less than or equal to the children.
|
||||
int worst_index = heap_size - 1;
|
||||
int end_parent = ParentNode(worst_index);
|
||||
for (int i = worst_index - 1; i > end_parent; --i) {
|
||||
if (heap_[worst_index] < heap_[i]) worst_index = i;
|
||||
}
|
||||
return worst_index;
|
||||
}
|
||||
|
||||
// The pointed-to Pair has changed its key value, so the location of pair
|
||||
// is reshuffled to maintain the heap invariant.
|
||||
// Must be a valid pointer to an element of the heap_!
|
||||
|
@ -174,6 +174,8 @@ class GenericVector {
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, FILE* fp);
|
||||
bool DeSerialize(bool swap, tesseract::TFile* fp);
|
||||
// Skips the deserialization of the vector.
|
||||
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
|
||||
// Writes a vector of classes to the given file. Assumes the existence of
|
||||
// bool T::Serialize(FILE* fp) const that returns false in case of error.
|
||||
// Returns false in case of error.
|
||||
@ -186,6 +188,8 @@ class GenericVector {
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerializeClasses(bool swap, FILE* fp);
|
||||
bool DeSerializeClasses(bool swap, tesseract::TFile* fp);
|
||||
// Calls SkipDeSerialize on the elements of the vector.
|
||||
static bool SkipDeSerializeClasses(bool swap, tesseract::TFile* fp);
|
||||
|
||||
// Allocates a new array of double the current_size, copies over the
|
||||
// information from data to the new location, deletes data and returns
|
||||
@ -238,14 +242,13 @@ class GenericVector {
|
||||
int binary_search(const T& target) const {
|
||||
int bottom = 0;
|
||||
int top = size_used_;
|
||||
do {
|
||||
while (top - bottom > 1) {
|
||||
int middle = (bottom + top) / 2;
|
||||
if (data_[middle] > target)
|
||||
top = middle;
|
||||
else
|
||||
bottom = middle;
|
||||
}
|
||||
while (top - bottom > 1);
|
||||
return bottom;
|
||||
}
|
||||
|
||||
@ -361,7 +364,7 @@ inline bool LoadDataFromFile(const STRING& filename,
|
||||
size_t size = ftell(fp);
|
||||
fseek(fp, 0, SEEK_SET);
|
||||
// Pad with a 0, just in case we treat the result as a string.
|
||||
data->init_to_size((int)size + 1, 0);
|
||||
data->init_to_size(static_cast<int>(size) + 1, 0);
|
||||
bool result = fread(&(*data)[0], 1, size, fp) == size;
|
||||
fclose(fp);
|
||||
return result;
|
||||
@ -556,34 +559,54 @@ class PointerVector : public GenericVector<T*> {
|
||||
}
|
||||
bool DeSerialize(bool swap, TFile* fp) {
|
||||
inT32 reserved;
|
||||
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
|
||||
if (swap) Reverse32(&reserved);
|
||||
if (!DeSerializeSize(swap, fp, &reserved)) return false;
|
||||
GenericVector<T*>::reserve(reserved);
|
||||
truncate(0);
|
||||
for (int i = 0; i < reserved; ++i) {
|
||||
inT8 non_null;
|
||||
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
|
||||
T* item = NULL;
|
||||
if (non_null) {
|
||||
item = new T;
|
||||
if (!item->DeSerialize(swap, fp)) {
|
||||
delete item;
|
||||
return false;
|
||||
}
|
||||
this->push_back(item);
|
||||
} else {
|
||||
// Null elements should keep their place in the vector.
|
||||
this->push_back(NULL);
|
||||
if (!DeSerializeElement(swap, fp)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Enables deserialization of a selection of elements. Note that in order to
|
||||
// retain the integrity of the stream, the caller must call some combination
|
||||
// of DeSerializeElement and DeSerializeSkip of the exact number returned in
|
||||
// *size, assuming a true return.
|
||||
static bool DeSerializeSize(bool swap, TFile* fp, inT32* size) {
|
||||
if (fp->FRead(size, sizeof(*size), 1) != 1) return false;
|
||||
if (swap) Reverse32(size);
|
||||
return true;
|
||||
}
|
||||
// Reads and appends to the vector the next element of the serialization.
|
||||
bool DeSerializeElement(bool swap, TFile* fp) {
|
||||
inT8 non_null;
|
||||
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
|
||||
T* item = NULL;
|
||||
if (non_null) {
|
||||
item = new T;
|
||||
if (!item->DeSerialize(swap, fp)) {
|
||||
delete item;
|
||||
return false;
|
||||
}
|
||||
this->push_back(item);
|
||||
} else {
|
||||
// Null elements should keep their place in the vector.
|
||||
this->push_back(NULL);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Skips the next element of the serialization.
|
||||
static bool DeSerializeSkip(bool swap, TFile* fp) {
|
||||
inT8 non_null;
|
||||
if (fp->FRead(&non_null, sizeof(non_null), 1) != 1) return false;
|
||||
if (non_null) {
|
||||
if (!T::SkipDeSerialize(swap, fp)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Sorts the items pointed to by the members of this vector using
|
||||
// t::operator<().
|
||||
void sort() {
|
||||
sort(&sort_ptr_cmp<T>);
|
||||
}
|
||||
void sort() { this->GenericVector<T*>::sort(&sort_ptr_cmp<T>); }
|
||||
};
|
||||
|
||||
} // namespace tesseract
|
||||
@ -926,6 +949,13 @@ bool GenericVector<T>::DeSerialize(bool swap, tesseract::TFile* fp) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
template <typename T>
|
||||
bool GenericVector<T>::SkipDeSerialize(bool swap, tesseract::TFile* fp) {
|
||||
inT32 reserved;
|
||||
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
|
||||
if (swap) Reverse32(&reserved);
|
||||
return fp->FRead(NULL, sizeof(T), reserved) == reserved;
|
||||
}
|
||||
|
||||
// Writes a vector of classes to the given file. Assumes the existence of
|
||||
// bool T::Serialize(FILE* fp) const that returns false in case of error.
|
||||
@ -976,6 +1006,16 @@ bool GenericVector<T>::DeSerializeClasses(bool swap, tesseract::TFile* fp) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
template <typename T>
|
||||
bool GenericVector<T>::SkipDeSerializeClasses(bool swap, tesseract::TFile* fp) {
|
||||
uinT32 reserved;
|
||||
if (fp->FRead(&reserved, sizeof(reserved), 1) != 1) return false;
|
||||
if (swap) Reverse32(&reserved);
|
||||
for (int i = 0; i < reserved; ++i) {
|
||||
if (!T::SkipDeSerialize(swap, fp)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// This method clear the current object, then, does a shallow copy of
|
||||
// its argument, and finally invalidates its argument.
|
||||
|
@ -95,7 +95,7 @@ int TFile::FRead(void* buffer, int size, int count) {
|
||||
char* char_buffer = reinterpret_cast<char*>(buffer);
|
||||
if (data_->size() - offset_ < required_size)
|
||||
required_size = data_->size() - offset_;
|
||||
if (required_size > 0)
|
||||
if (required_size > 0 && char_buffer != NULL)
|
||||
memcpy(char_buffer, &(*data_)[offset_], required_size);
|
||||
offset_ += required_size;
|
||||
return required_size / size;
|
||||
|
@ -181,6 +181,14 @@ bool STRING::DeSerialize(bool swap, TFile* fp) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// As DeSerialize, but only seeks past the data - hence a static method.
|
||||
bool STRING::SkipDeSerialize(bool swap, tesseract::TFile* fp) {
|
||||
inT32 len;
|
||||
if (fp->FRead(&len, sizeof(len), 1) != 1) return false;
|
||||
if (swap) ReverseN(&len, sizeof(len));
|
||||
return fp->FRead(NULL, 1, len) == len;
|
||||
}
|
||||
|
||||
BOOL8 STRING::contains(const char c) const {
|
||||
return (c != '\0') && (strchr (GetCStr(), c) != NULL);
|
||||
}
|
||||
|
@ -60,6 +60,8 @@ class TESS_API STRING
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, tesseract::TFile* fp);
|
||||
// As DeSerialize, but only seeks past the data - hence a static method.
|
||||
static bool SkipDeSerialize(bool swap, tesseract::TFile* fp);
|
||||
|
||||
BOOL8 contains(const char c) const;
|
||||
inT32 length() const;
|
||||
|
@ -47,6 +47,10 @@ static const char kShapeTableFileSuffix[] = "shapetable";
|
||||
static const char kBigramDawgFileSuffix[] = "bigram-dawg";
|
||||
static const char kUnambigDawgFileSuffix[] = "unambig-dawg";
|
||||
static const char kParamsModelFileSuffix[] = "params-model";
|
||||
static const char kLSTMModelFileSuffix[] = "lstm";
|
||||
static const char kLSTMPuncDawgFileSuffix[] = "lstm-punc-dawg";
|
||||
static const char kLSTMSystemDawgFileSuffix[] = "lstm-word-dawg";
|
||||
static const char kLSTMNumberDawgFileSuffix[] = "lstm-number-dawg";
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
@ -68,6 +72,10 @@ enum TessdataType {
|
||||
TESSDATA_BIGRAM_DAWG, // 14
|
||||
TESSDATA_UNAMBIG_DAWG, // 15
|
||||
TESSDATA_PARAMS_MODEL, // 16
|
||||
TESSDATA_LSTM, // 17
|
||||
TESSDATA_LSTM_PUNC_DAWG, // 18
|
||||
TESSDATA_LSTM_SYSTEM_DAWG, // 19
|
||||
TESSDATA_LSTM_NUMBER_DAWG, // 20
|
||||
|
||||
TESSDATA_NUM_ENTRIES
|
||||
};
|
||||
@ -76,24 +84,28 @@ enum TessdataType {
|
||||
* kTessdataFileSuffixes[i] indicates the file suffix for
|
||||
* tessdata of type i (from TessdataType enum).
|
||||
*/
|
||||
static const char * const kTessdataFileSuffixes[] = {
|
||||
kLangConfigFileSuffix, // 0
|
||||
kUnicharsetFileSuffix, // 1
|
||||
kAmbigsFileSuffix, // 2
|
||||
kBuiltInTemplatesFileSuffix, // 3
|
||||
kBuiltInCutoffsFileSuffix, // 4
|
||||
kNormProtoFileSuffix, // 5
|
||||
kPuncDawgFileSuffix, // 6
|
||||
kSystemDawgFileSuffix, // 7
|
||||
kNumberDawgFileSuffix, // 8
|
||||
kFreqDawgFileSuffix, // 9
|
||||
kFixedLengthDawgsFileSuffix, // 10 // deprecated
|
||||
kCubeUnicharsetFileSuffix, // 11
|
||||
kCubeSystemDawgFileSuffix, // 12
|
||||
kShapeTableFileSuffix, // 13
|
||||
kBigramDawgFileSuffix, // 14
|
||||
kUnambigDawgFileSuffix, // 15
|
||||
kParamsModelFileSuffix, // 16
|
||||
static const char *const kTessdataFileSuffixes[] = {
|
||||
kLangConfigFileSuffix, // 0
|
||||
kUnicharsetFileSuffix, // 1
|
||||
kAmbigsFileSuffix, // 2
|
||||
kBuiltInTemplatesFileSuffix, // 3
|
||||
kBuiltInCutoffsFileSuffix, // 4
|
||||
kNormProtoFileSuffix, // 5
|
||||
kPuncDawgFileSuffix, // 6
|
||||
kSystemDawgFileSuffix, // 7
|
||||
kNumberDawgFileSuffix, // 8
|
||||
kFreqDawgFileSuffix, // 9
|
||||
kFixedLengthDawgsFileSuffix, // 10 // deprecated
|
||||
kCubeUnicharsetFileSuffix, // 11
|
||||
kCubeSystemDawgFileSuffix, // 12
|
||||
kShapeTableFileSuffix, // 13
|
||||
kBigramDawgFileSuffix, // 14
|
||||
kUnambigDawgFileSuffix, // 15
|
||||
kParamsModelFileSuffix, // 16
|
||||
kLSTMModelFileSuffix, // 17
|
||||
kLSTMPuncDawgFileSuffix, // 18
|
||||
kLSTMSystemDawgFileSuffix, // 19
|
||||
kLSTMNumberDawgFileSuffix, // 20
|
||||
};
|
||||
|
||||
/**
|
||||
@ -101,23 +113,27 @@ static const char * const kTessdataFileSuffixes[] = {
|
||||
* of type i (from TessdataType enum) is text, and is binary otherwise.
|
||||
*/
|
||||
/**
 * kTessdataFileIsText[i] is true if tessdata of type i (from the
 * TessdataType enum) is a text file, and false if it is binary.
 * Must stay in sync (by index) with TessdataType and
 * kTessdataFileSuffixes.
 */
static const bool kTessdataFileIsText[] = {
    true,   // 0
    true,   // 1
    true,   // 2
    false,  // 3
    true,   // 4
    true,   // 5
    false,  // 6
    false,  // 7
    false,  // 8
    false,  // 9
    false,  // 10 // deprecated
    true,   // 11
    false,  // 12
    false,  // 13
    false,  // 14
    false,  // 15
    true,   // 16
    false,  // 17  LSTM model is binary
    false,  // 18
    false,  // 19
    false,  // 20
};
|
||||
|
||||
/**
|
||||
|
439
ccutil/unicharcompress.cpp
Normal file
439
ccutil/unicharcompress.cpp
Normal file
@ -0,0 +1,439 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: unicharcompress.cpp
|
||||
// Description: Unicode re-encoding using a sequence of smaller numbers in
|
||||
// place of a single large code for CJK, similarly for Indic,
|
||||
// and dissection of ligatures for other scripts.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Mar 04 14:45:01 PST 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "unicharcompress.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// String used to represent the null_id in direct_set.
|
||||
const char* kNullChar = "<nul>";
|
||||
|
||||
// Local struct used only for processing the radical-stroke table.
|
||||
struct RadicalStroke {
|
||||
RadicalStroke() : num_strokes(0) {}
|
||||
RadicalStroke(const STRING& r, int s) : radical(r), num_strokes(s) {}
|
||||
|
||||
bool operator==(const RadicalStroke& other) const {
|
||||
return radical == other.radical && num_strokes == other.num_strokes;
|
||||
}
|
||||
|
||||
// The radical is encoded as a string because its format is of an int with
|
||||
// an optional ' mark to indicate a simplified shape. To treat these as
|
||||
// distinct, we use a string and a UNICHARSET to do the integer mapping.
|
||||
STRING radical;
|
||||
// The number of strokes we treat as dense and just take the face value from
|
||||
// the table.
|
||||
int num_strokes;
|
||||
};
|
||||
|
||||
// Hash functor for RadicalStroke.
|
||||
struct RadicalStrokedHash {
|
||||
size_t operator()(const RadicalStroke& rs) const {
|
||||
size_t result = rs.num_strokes;
|
||||
for (int i = 0; i < rs.radical.length(); ++i) {
|
||||
result ^= rs.radical[i] << (6 * i + 8);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// A hash map to convert unicodes to radical,stroke pair.
|
||||
typedef TessHashMap<int, RadicalStroke> RSMap;
|
||||
// A hash map to count occurrences of each radical,stroke pair.
|
||||
typedef TessHashMap<RadicalStroke, int, RadicalStrokedHash> RSCounts;
|
||||
|
||||
// Helper function builds the RSMap from the radical-stroke file, which has
|
||||
// already been read into a STRING. Returns false on error.
|
||||
// The radical_stroke_table is non-const because it gets split and the caller
|
||||
// is unlikely to want to use it again.
|
||||
static bool DecodeRadicalStrokeTable(STRING* radical_stroke_table,
|
||||
RSMap* radical_map) {
|
||||
GenericVector<STRING> lines;
|
||||
radical_stroke_table->split('\n', &lines);
|
||||
for (int i = 0; i < lines.size(); ++i) {
|
||||
if (lines[i].length() == 0 || lines[i][0] == '#') continue;
|
||||
int unicode, radical, strokes;
|
||||
STRING str_radical;
|
||||
if (sscanf(lines[i].string(), "%x\t%d.%d", &unicode, &radical, &strokes) ==
|
||||
3) {
|
||||
str_radical.add_str_int("", radical);
|
||||
} else if (sscanf(lines[i].string(), "%x\t%d'.%d", &unicode, &radical,
|
||||
&strokes) == 3) {
|
||||
str_radical.add_str_int("'", radical);
|
||||
} else {
|
||||
tprintf("Invalid format in radical stroke table at line %d: %s\n", i,
|
||||
lines[i].string());
|
||||
return false;
|
||||
}
|
||||
(*radical_map)[unicode] = RadicalStroke(str_radical, strokes);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
UnicharCompress::UnicharCompress() : code_range_(0) {}
UnicharCompress::UnicharCompress(const UnicharCompress& src) { *this = src; }
UnicharCompress::~UnicharCompress() { Cleanup(); }
UnicharCompress& UnicharCompress::operator=(const UnicharCompress& src) {
  // Drop any existing derived state, copy only the serialized members, then
  // rebuild the decoding structures from scratch.
  Cleanup();
  encoder_ = src.encoder_;
  code_range_ = src.code_range_;
  SetupDecoder();
  return *this;
}
|
||||
|
||||
// Computes the encoding for the given unicharset. It is a requirement that
|
||||
// the file training/langdata/radical-stroke.txt have been read into the
|
||||
// input string radical_stroke_table.
|
||||
// Returns false if the encoding cannot be constructed.
|
||||
bool UnicharCompress::ComputeEncoding(const UNICHARSET& unicharset, int null_id,
|
||||
STRING* radical_stroke_table) {
|
||||
RSMap radical_map;
|
||||
if (!DecodeRadicalStrokeTable(radical_stroke_table, &radical_map))
|
||||
return false;
|
||||
encoder_.clear();
|
||||
UNICHARSET direct_set;
|
||||
UNICHARSET radicals;
|
||||
// To avoid unused codes, clear the special codes from the unicharsets.
|
||||
direct_set.clear();
|
||||
radicals.clear();
|
||||
// Always keep space as 0;
|
||||
direct_set.unichar_insert(" ");
|
||||
// Null char is next if we have one.
|
||||
if (null_id >= 0) {
|
||||
direct_set.unichar_insert(kNullChar);
|
||||
}
|
||||
RSCounts radical_counts;
|
||||
// In the initial map, codes [0, unicharset.size()) are
|
||||
// reserved for non-han/hangul sequences of 1 or more unicodes.
|
||||
int hangul_offset = unicharset.size();
|
||||
// Hangul takes the next range [hangul_offset, hangul_offset + kTotalJamos).
|
||||
const int kTotalJamos = kLCount + kVCount + kTCount;
|
||||
// Han takes the codes beyond hangul_offset + kTotalJamos. Since it is hard
|
||||
// to measure the number of radicals and strokes, initially we use the same
|
||||
// code range for all 3 Han code positions, and fix them after.
|
||||
int han_offset = hangul_offset + kTotalJamos;
|
||||
int max_num_strokes = -1;
|
||||
for (int u = 0; u <= unicharset.size(); ++u) {
|
||||
bool self_normalized = false;
|
||||
// We special-case allow null_id to be equal to unicharset.size() in case
|
||||
// there is no space in unicharset for it.
|
||||
if (u == unicharset.size()) {
|
||||
if (u == null_id) {
|
||||
self_normalized = true;
|
||||
} else {
|
||||
break; // Finished.
|
||||
}
|
||||
} else {
|
||||
self_normalized = strcmp(unicharset.id_to_unichar(u),
|
||||
unicharset.get_normed_unichar(u)) == 0;
|
||||
}
|
||||
RecodedCharID code;
|
||||
// Convert to unicodes.
|
||||
GenericVector<int> unicodes;
|
||||
if (u < unicharset.size() &&
|
||||
UNICHAR::UTF8ToUnicode(unicharset.get_normed_unichar(u), &unicodes) &&
|
||||
unicodes.size() == 1) {
|
||||
// Check single unicodes for Hangul/Han and encode if so.
|
||||
int unicode = unicodes[0];
|
||||
int leading, vowel, trailing;
|
||||
auto it = radical_map.find(unicode);
|
||||
if (it != radical_map.end()) {
|
||||
// This is Han. Convert to radical, stroke, index.
|
||||
if (!radicals.contains_unichar(it->second.radical.string())) {
|
||||
radicals.unichar_insert(it->second.radical.string());
|
||||
}
|
||||
int radical = radicals.unichar_to_id(it->second.radical.string());
|
||||
int num_strokes = it->second.num_strokes;
|
||||
int num_samples = radical_counts[it->second]++;
|
||||
if (num_strokes > max_num_strokes) max_num_strokes = num_strokes;
|
||||
code.Set3(radical + han_offset, num_strokes + han_offset,
|
||||
num_samples + han_offset);
|
||||
} else if (DecomposeHangul(unicode, &leading, &vowel, &trailing)) {
|
||||
// This is Hangul. Since we know the exact size of each part at compile
|
||||
// time, it gets the bottom set of codes.
|
||||
code.Set3(leading + hangul_offset, vowel + kLCount + hangul_offset,
|
||||
trailing + kLCount + kVCount + hangul_offset);
|
||||
}
|
||||
}
|
||||
// If the code is still empty, it wasn't Han or Hangul.
|
||||
if (code.length() == 0) {
|
||||
// Special cases.
|
||||
if (u == UNICHAR_SPACE) {
|
||||
code.Set(0, 0); // Space.
|
||||
} else if (u == null_id || (unicharset.has_special_codes() &&
|
||||
u < SPECIAL_UNICHAR_CODES_COUNT)) {
|
||||
code.Set(0, direct_set.unichar_to_id(kNullChar));
|
||||
} else {
|
||||
// Add the direct_set unichar-ids of the unicodes in sequence to the
|
||||
// code.
|
||||
for (int i = 0; i < unicodes.size(); ++i) {
|
||||
int position = code.length();
|
||||
if (position >= RecodedCharID::kMaxCodeLen) {
|
||||
tprintf("Unichar %d=%s->%s is too long to encode!!\n", u,
|
||||
unicharset.id_to_unichar(u),
|
||||
unicharset.get_normed_unichar(u));
|
||||
return false;
|
||||
}
|
||||
int uni = unicodes[i];
|
||||
UNICHAR unichar(uni);
|
||||
char* utf8 = unichar.utf8_str();
|
||||
if (!direct_set.contains_unichar(utf8))
|
||||
direct_set.unichar_insert(utf8);
|
||||
code.Set(position, direct_set.unichar_to_id(utf8));
|
||||
delete[] utf8;
|
||||
if (direct_set.size() > unicharset.size()) {
|
||||
// Code space got bigger!
|
||||
tprintf("Code space expanded from original unicharset!!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
code.set_self_normalized(self_normalized);
|
||||
encoder_.push_back(code);
|
||||
}
|
||||
// Now renumber Han to make all codes unique. We already added han_offset to
|
||||
// all Han. Now separate out the radical, stroke, and count codes for Han.
|
||||
// In the uniqued Han encoding, the 1st code uses the next radical_map.size()
|
||||
// values, the 2nd code uses the next max_num_strokes+1 values, and the 3rd
|
||||
// code uses the rest for the max number of duplicated radical/stroke combos.
|
||||
int num_radicals = radicals.size();
|
||||
for (int u = 0; u < unicharset.size(); ++u) {
|
||||
RecodedCharID* code = &encoder_[u];
|
||||
if ((*code)(0) >= han_offset) {
|
||||
code->Set(1, (*code)(1) + num_radicals);
|
||||
code->Set(2, (*code)(2) + num_radicals + max_num_strokes + 1);
|
||||
}
|
||||
}
|
||||
DefragmentCodeValues(null_id >= 0 ? 1 : -1);
|
||||
SetupDecoder();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Sets up an encoder that doesn't change the unichars at all, so it just
|
||||
// passes them through unchanged.
|
||||
void UnicharCompress::SetupPassThrough(const UNICHARSET& unicharset) {
|
||||
GenericVector<RecodedCharID> codes;
|
||||
for (int u = 0; u < unicharset.size(); ++u) {
|
||||
RecodedCharID code;
|
||||
code.Set(0, u);
|
||||
codes.push_back(code);
|
||||
}
|
||||
SetupDirect(codes);
|
||||
}
|
||||
|
||||
// Sets up an encoder directly using the given encoding vector, which maps
|
||||
// unichar_ids to the given codes.
|
||||
void UnicharCompress::SetupDirect(const GenericVector<RecodedCharID>& codes) {
|
||||
encoder_ = codes;
|
||||
ComputeCodeRange();
|
||||
SetupDecoder();
|
||||
}
|
||||
|
||||
// Renumbers codes to eliminate unused values.
|
||||
void UnicharCompress::DefragmentCodeValues(int encoded_null) {
|
||||
// There may not be any Hangul, but even if there is, it is possible that not
|
||||
// all codes are used. Likewise with the Han encoding, it is possible that not
|
||||
// all numbers of strokes are used.
|
||||
ComputeCodeRange();
|
||||
GenericVector<int> offsets;
|
||||
offsets.init_to_size(code_range_, 0);
|
||||
// Find which codes are used
|
||||
for (int c = 0; c < encoder_.size(); ++c) {
|
||||
const RecodedCharID& code = encoder_[c];
|
||||
for (int i = 0; i < code.length(); ++i) {
|
||||
offsets[code(i)] = 1;
|
||||
}
|
||||
}
|
||||
// Compute offsets based on code use.
|
||||
int offset = 0;
|
||||
for (int i = 0; i < offsets.size(); ++i) {
|
||||
// If not used, decrement everything above here.
|
||||
// We are moving encoded_null to the end, so it is not "used".
|
||||
if (offsets[i] == 0 || i == encoded_null) {
|
||||
--offset;
|
||||
} else {
|
||||
offsets[i] = offset;
|
||||
}
|
||||
}
|
||||
if (encoded_null >= 0) {
|
||||
// The encoded_null is moving to the end, for the benefit of TensorFlow,
|
||||
// which is offsets.size() + offsets.back().
|
||||
offsets[encoded_null] = offsets.size() + offsets.back() - encoded_null;
|
||||
}
|
||||
// Now apply the offsets.
|
||||
for (int c = 0; c < encoder_.size(); ++c) {
|
||||
RecodedCharID* code = &encoder_[c];
|
||||
for (int i = 0; i < code->length(); ++i) {
|
||||
int value = (*code)(i);
|
||||
code->Set(i, value + offsets[value]);
|
||||
}
|
||||
}
|
||||
ComputeCodeRange();
|
||||
}
|
||||
|
||||
// Encodes a single unichar_id. Returns the length of the code, or zero if
|
||||
// invalid input, and the encoding itself
|
||||
int UnicharCompress::EncodeUnichar(int unichar_id, RecodedCharID* code) const {
|
||||
if (unichar_id < 0 || unichar_id >= encoder_.size()) return 0;
|
||||
*code = encoder_[unichar_id];
|
||||
return code->length();
|
||||
}
|
||||
|
||||
// Decodes code, returning the original unichar-id, or
|
||||
// INVALID_UNICHAR_ID if the input is invalid.
|
||||
int UnicharCompress::DecodeUnichar(const RecodedCharID& code) const {
|
||||
int len = code.length();
|
||||
if (len <= 0 || len > RecodedCharID::kMaxCodeLen) return INVALID_UNICHAR_ID;
|
||||
auto it = decoder_.find(code);
|
||||
if (it == decoder_.end()) return INVALID_UNICHAR_ID;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool UnicharCompress::Serialize(TFile* fp) const {
|
||||
return encoder_.SerializeClasses(fp);
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool UnicharCompress::DeSerialize(bool swap, TFile* fp) {
|
||||
if (!encoder_.DeSerializeClasses(swap, fp)) return false;
|
||||
ComputeCodeRange();
|
||||
SetupDecoder();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns a STRING containing a text file that describes the encoding thus:
|
||||
// <index>[,<index>]*<tab><UTF8-str><newline>
|
||||
// In words, a comma-separated list of one or more indices, followed by a tab
|
||||
// and the UTF-8 string that the code represents per line. Most simple scripts
|
||||
// will encode a single index to a UTF8-string, but Chinese, Japanese, Korean
|
||||
// and the Indic scripts will contain a many-to-many mapping.
|
||||
// See the class comment above for details.
|
||||
STRING UnicharCompress::GetEncodingAsString(
|
||||
const UNICHARSET& unicharset) const {
|
||||
STRING encoding;
|
||||
for (int c = 0; c < encoder_.size(); ++c) {
|
||||
const RecodedCharID& code = encoder_[c];
|
||||
if (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT && code == encoder_[c - 1]) {
|
||||
// Don't show the duplicate entry.
|
||||
continue;
|
||||
}
|
||||
encoding.add_str_int("", code(0));
|
||||
for (int i = 1; i < code.length(); ++i) {
|
||||
encoding.add_str_int(",", code(i));
|
||||
}
|
||||
encoding += "\t";
|
||||
if (c >= unicharset.size() || (0 < c && c < SPECIAL_UNICHAR_CODES_COUNT &&
|
||||
unicharset.has_special_codes())) {
|
||||
encoding += kNullChar;
|
||||
} else {
|
||||
encoding += unicharset.id_to_unichar(c);
|
||||
}
|
||||
encoding += "\n";
|
||||
}
|
||||
return encoding;
|
||||
}
|
||||
|
||||
// Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing.
|
||||
// Note that the returned values are 0-based indices, NOT unicode Jamo.
|
||||
// Returns false if the input is not in the Hangul unicode range.
|
||||
/* static */
|
||||
bool UnicharCompress::DecomposeHangul(int unicode, int* leading, int* vowel,
|
||||
int* trailing) {
|
||||
if (unicode < kFirstHangul) return false;
|
||||
int offset = unicode - kFirstHangul;
|
||||
if (offset >= kNumHangul) return false;
|
||||
const int kNCount = kVCount * kTCount;
|
||||
*leading = offset / kNCount;
|
||||
*vowel = (offset % kNCount) / kTCount;
|
||||
*trailing = offset % kTCount;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Computes the value of code_range_ from the encoder_.
|
||||
void UnicharCompress::ComputeCodeRange() {
|
||||
code_range_ = -1;
|
||||
for (int c = 0; c < encoder_.size(); ++c) {
|
||||
const RecodedCharID& code = encoder_[c];
|
||||
for (int i = 0; i < code.length(); ++i) {
|
||||
if (code(i) > code_range_) code_range_ = code(i);
|
||||
}
|
||||
}
|
||||
++code_range_;
|
||||
}
|
||||
|
||||
// Initializes the decoding hash_map from the encoding array.
|
||||
void UnicharCompress::SetupDecoder() {
|
||||
Cleanup();
|
||||
is_valid_start_.init_to_size(code_range_, false);
|
||||
for (int c = 0; c < encoder_.size(); ++c) {
|
||||
const RecodedCharID& code = encoder_[c];
|
||||
if (code.self_normalized() || decoder_.find(code) == decoder_.end())
|
||||
decoder_[code] = c;
|
||||
is_valid_start_[code(0)] = true;
|
||||
RecodedCharID prefix = code;
|
||||
int len = code.length() - 1;
|
||||
prefix.Truncate(len);
|
||||
auto final_it = final_codes_.find(prefix);
|
||||
if (final_it == final_codes_.end()) {
|
||||
GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>;
|
||||
code_list->push_back(code(len));
|
||||
final_codes_[prefix] = code_list;
|
||||
while (--len >= 0) {
|
||||
prefix.Truncate(len);
|
||||
auto next_it = next_codes_.find(prefix);
|
||||
if (next_it == next_codes_.end()) {
|
||||
GenericVectorEqEq<int>* code_list = new GenericVectorEqEq<int>;
|
||||
code_list->push_back(code(len));
|
||||
next_codes_[prefix] = code_list;
|
||||
} else {
|
||||
// We still have to search the list as we may get here via multiple
|
||||
// lengths of code.
|
||||
if (!next_it->second->contains(code(len)))
|
||||
next_it->second->push_back(code(len));
|
||||
break; // This prefix has been processed.
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!final_it->second->contains(code(len)))
|
||||
final_it->second->push_back(code(len));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Frees allocated memory.
|
||||
void UnicharCompress::Cleanup() {
|
||||
decoder_.clear();
|
||||
is_valid_start_.clear();
|
||||
for (auto it = next_codes_.begin(); it != next_codes_.end(); ++it) {
|
||||
delete it->second;
|
||||
}
|
||||
for (auto it = final_codes_.begin(); it != final_codes_.end(); ++it) {
|
||||
delete it->second;
|
||||
}
|
||||
next_codes_.clear();
|
||||
final_codes_.clear();
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
258
ccutil/unicharcompress.h
Normal file
258
ccutil/unicharcompress.h
Normal file
@ -0,0 +1,258 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: unicharcompress.h
|
||||
// Description: Unicode re-encoding using a sequence of smaller numbers in
|
||||
// place of a single large code for CJK, similarly for Indic,
|
||||
// and dissection of ligatures for other scripts.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Mar 04 14:45:01 PST 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
|
||||
#define TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
|
||||
|
||||
#include "hashfn.h"
|
||||
#include "serialis.h"
|
||||
#include "strngs.h"
|
||||
#include "unicharset.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Trivial class to hold the code for a recoded unichar-id.
|
||||
class RecodedCharID {
|
||||
public:
|
||||
// The maximum length of a code.
|
||||
static const int kMaxCodeLen = 9;
|
||||
|
||||
RecodedCharID() : self_normalized_(0), length_(0) {
|
||||
memset(code_, 0, sizeof(code_));
|
||||
}
|
||||
void Truncate(int length) { length_ = length; }
|
||||
// Sets the code value at the given index in the code.
|
||||
void Set(int index, int value) {
|
||||
code_[index] = value;
|
||||
if (length_ <= index) length_ = index + 1;
|
||||
}
|
||||
// Shorthand for setting codes of length 3, as all Hangul and Han codes are
|
||||
// length 3.
|
||||
void Set3(int code0, int code1, int code2) {
|
||||
length_ = 3;
|
||||
code_[0] = code0;
|
||||
code_[1] = code1;
|
||||
code_[2] = code2;
|
||||
}
|
||||
// Accessors
|
||||
bool self_normalized() const { return self_normalized_ != 0; }
|
||||
void set_self_normalized(bool value) { self_normalized_ = value; }
|
||||
int length() const { return length_; }
|
||||
int operator()(int index) const { return code_[index]; }
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Serialize(TFile* fp) const {
|
||||
if (fp->FWrite(&self_normalized_, sizeof(self_normalized_), 1) != 1)
|
||||
return false;
|
||||
if (fp->FWrite(&length_, sizeof(length_), 1) != 1) return false;
|
||||
if (fp->FWrite(code_, sizeof(code_[0]), length_) != length_) return false;
|
||||
return true;
|
||||
}
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, TFile* fp) {
|
||||
if (fp->FRead(&self_normalized_, sizeof(self_normalized_), 1) != 1)
|
||||
return false;
|
||||
if (fp->FRead(&length_, sizeof(length_), 1) != 1) return false;
|
||||
if (swap) ReverseN(&length_, sizeof(length_));
|
||||
if (fp->FRead(code_, sizeof(code_[0]), length_) != length_) return false;
|
||||
if (swap) {
|
||||
for (int i = 0; i < length_; ++i) {
|
||||
ReverseN(&code_[i], sizeof(code_[i]));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool operator==(const RecodedCharID& other) const {
|
||||
if (length_ != other.length_) return false;
|
||||
for (int i = 0; i < length_; ++i) {
|
||||
if (code_[i] != other.code_[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Hash functor for RecodedCharID.
|
||||
struct RecodedCharIDHash {
|
||||
size_t operator()(const RecodedCharID& code) const {
|
||||
size_t result = 0;
|
||||
for (int i = 0; i < code.length_; ++i) {
|
||||
result ^= code(i) << (7 * i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
// True if this code is self-normalizing, ie is the master entry for indices
|
||||
// that map to the same code. Has boolean value, but inT8 for serialization.
|
||||
inT8 self_normalized_;
|
||||
// The number of elements in use in code_;
|
||||
inT32 length_;
|
||||
// The re-encoded form of the unichar-id to which this RecodedCharID relates.
|
||||
inT32 code_[kMaxCodeLen];
|
||||
};
|
||||
|
||||
// Class holds a "compression" of a unicharset to simplify the learning problem
|
||||
// for a neural-network-based classifier.
|
||||
// Objectives:
|
||||
// 1 (CJK): Ids of a unicharset with a large number of classes are expressed as
|
||||
// a sequence of 3 codes with much fewer values.
|
||||
// This is achieved using the Jamo coding for Hangul and the Unicode
|
||||
// Radical-Stroke-index for Han.
|
||||
// 2 (Indic): Instead of thousands of codes with one for each grapheme, re-code
|
||||
// as the unicode sequence (but coded in a more compact space).
|
||||
// 3 (the rest): Eliminate multi-path problems with ligatures and fold confusing
|
||||
// and not significantly distinct shapes (quotes) togther, ie
|
||||
// represent the fi ligature as the f-i pair, and fold u+2019 and
|
||||
// friends all onto ascii single '
|
||||
// 4 The null character and mapping to target activations:
|
||||
// To save horizontal coding space, the compressed codes are generally mapped
|
||||
// to target network activations without intervening null characters, BUT
|
||||
// in the case of ligatures, such as ff, null characters have to be included
|
||||
// so existence of repeated codes is detected at codebook-building time, and
|
||||
// null characters are embedded directly into the codes, so the rest of the
|
||||
// system doesn't need to worry about the problem (much). There is still an
|
||||
// effect on the range of ways in which the target activations can be
|
||||
// generated.
|
||||
//
|
||||
// The computed code values are compact (no unused values), and, for CJK,
|
||||
// unique (each code position uses a disjoint set of values from each other code
|
||||
// position). For non-CJK, the same code value CAN be used in multiple
|
||||
// positions, eg the ff ligature is converted to <f> <nullchar> <f>, where <f>
|
||||
// is the same code as is used for the single f.
|
||||
// NOTE that an intended consequence of using the normalized text from the
|
||||
// unicharset is that the fancy quotes all map to a single code, so round-trip
|
||||
// conversion doesn't work for all unichar-ids.
|
||||
class UnicharCompress {
|
||||
public:
|
||||
UnicharCompress();
|
||||
UnicharCompress(const UnicharCompress& src);
|
||||
~UnicharCompress();
|
||||
UnicharCompress& operator=(const UnicharCompress& src);
|
||||
|
||||
// The 1st Hangul unicode.
|
||||
static const int kFirstHangul = 0xac00;
|
||||
// The number of Hangul unicodes.
|
||||
static const int kNumHangul = 11172;
|
||||
// The number of Jamos for each of the 3 parts of a Hangul character, being
|
||||
// the Leading consonant, Vowel and Trailing consonant.
|
||||
static const int kLCount = 19;
|
||||
static const int kVCount = 21;
|
||||
static const int kTCount = 28;
|
||||
|
||||
// Computes the encoding for the given unicharset. It is a requirement that
|
||||
// the file training/langdata/radical-stroke.txt have been read into the
|
||||
// input string radical_stroke_table.
|
||||
// Returns false if the encoding cannot be constructed.
|
||||
bool ComputeEncoding(const UNICHARSET& unicharset, int null_id,
|
||||
STRING* radical_stroke_table);
|
||||
// Sets up an encoder that doesn't change the unichars at all, so it just
|
||||
// passes them through unchanged.
|
||||
void SetupPassThrough(const UNICHARSET& unicharset);
|
||||
// Sets up an encoder directly using the given encoding vector, which maps
|
||||
// unichar_ids to the given codes.
|
||||
void SetupDirect(const GenericVector<RecodedCharID>& codes);
|
||||
|
||||
// Returns the number of different values that can be used in a code, ie
|
||||
// 1 + the maximum value that will ever be used by an RecodedCharID code in
|
||||
// any position in its array.
|
||||
int code_range() const { return code_range_; }
|
||||
|
||||
// Encodes a single unichar_id. Returns the length of the code, (or zero if
|
||||
// invalid input), and the encoding itself in code.
|
||||
int EncodeUnichar(int unichar_id, RecodedCharID* code) const;
|
||||
// Decodes code, returning the original unichar-id, or
|
||||
// INVALID_UNICHAR_ID if the input is invalid. Note that this is not a perfect
|
||||
// inverse of EncodeUnichar, since the unichar-id of U+2019 (curly single
|
||||
// quote), for example, will have the same encoding as the unichar-id of
|
||||
// U+0027 (ascii '). The foldings are obtained from the input unicharset,
|
||||
// which in turn obtains them from NormalizeUTF8String in normstrngs.cpp,
|
||||
// and include NFKC normalization plus others like quote and dash folding.
|
||||
int DecodeUnichar(const RecodedCharID& code) const;
|
||||
// Returns true if the given code is a valid start or single code.
|
||||
bool IsValidFirstCode(int code) const { return is_valid_start_[code]; }
|
||||
// Returns a list of valid non-final next codes for a given prefix code,
|
||||
// which may be empty.
|
||||
const GenericVector<int>* GetNextCodes(const RecodedCharID& code) const {
|
||||
auto it = next_codes_.find(code);
|
||||
return it == next_codes_.end() ? NULL : it->second;
|
||||
}
|
||||
// Returns a list of valid final codes for a given prefix code, which may
|
||||
// be empty.
|
||||
const GenericVector<int>* GetFinalCodes(const RecodedCharID& code) const {
|
||||
auto it = final_codes_.find(code);
|
||||
return it == final_codes_.end() ? NULL : it->second;
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// Returns a STRING containing a text file that describes the encoding thus:
|
||||
// <index>[,<index>]*<tab><UTF8-str><newline>
|
||||
// In words, a comma-separated list of one or more indices, followed by a tab
|
||||
// and the UTF-8 string that the code represents per line. Most simple scripts
|
||||
// will encode a single index to a UTF8-string, but Chinese, Japanese, Korean
|
||||
// and the Indic scripts will contain a many-to-many mapping.
|
||||
// See the class comment above for details.
|
||||
STRING GetEncodingAsString(const UNICHARSET& unicharset) const;
|
||||
|
||||
// Helper decomposes a Hangul unicode to 3 parts, leading, vowel, trailing.
|
||||
// Note that the returned values are 0-based indices, NOT unicode Jamo.
|
||||
// Returns false if the input is not in the Hangul unicode range.
|
||||
static bool DecomposeHangul(int unicode, int* leading, int* vowel,
|
||||
int* trailing);
|
||||
|
||||
private:
|
||||
// Renumbers codes to eliminate unused values.
|
||||
void DefragmentCodeValues(int encoded_null);
|
||||
// Computes the value of code_range_ from the encoder_.
|
||||
void ComputeCodeRange();
|
||||
// Initializes the decoding hash_map from the encoder_ array.
|
||||
void SetupDecoder();
|
||||
// Frees allocated memory.
|
||||
void Cleanup();
|
||||
|
||||
// The encoder that maps a unichar-id to a sequence of small codes.
|
||||
// encoder_ is the only part that is serialized. The rest is computed on load.
|
||||
GenericVector<RecodedCharID> encoder_;
|
||||
// Decoder converts the output of encoder back to a unichar-id.
|
||||
TessHashMap<RecodedCharID, int, RecodedCharID::RecodedCharIDHash> decoder_;
|
||||
// True if the index is a valid single or start code.
|
||||
GenericVector<bool> is_valid_start_;
|
||||
// Maps a prefix code to a list of valid next codes.
|
||||
// The map owns the vectors.
|
||||
TessHashMap<RecodedCharID, GenericVectorEqEq<int>*,
|
||||
RecodedCharID::RecodedCharIDHash>
|
||||
next_codes_;
|
||||
// Maps a prefix code to a list of valid final codes.
|
||||
// The map owns the vectors.
|
||||
TessHashMap<RecodedCharID, GenericVectorEqEq<int>*,
|
||||
RecodedCharID::RecodedCharIDHash>
|
||||
final_codes_;
|
||||
// Max of any value in encoder_ + 1.
|
||||
int code_range_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_CCUTIL_UNICHARCOMPRESS_H_
|
@ -906,6 +906,8 @@ void UNICHARSET::post_load_setup() {
|
||||
han_sid_ = get_script_id_from_name("Han");
|
||||
hiragana_sid_ = get_script_id_from_name("Hiragana");
|
||||
katakana_sid_ = get_script_id_from_name("Katakana");
|
||||
thai_sid_ = get_script_id_from_name("Thai");
|
||||
hangul_sid_ = get_script_id_from_name("Hangul");
|
||||
|
||||
// Compute default script. Use the highest-counting alpha script, that is
|
||||
// not the common script, as that still contains some "alphas".
|
||||
|
@ -290,6 +290,8 @@ class UNICHARSET {
|
||||
han_sid_ = 0;
|
||||
hiragana_sid_ = 0;
|
||||
katakana_sid_ = 0;
|
||||
thai_sid_ = 0;
|
||||
hangul_sid_ = 0;
|
||||
}
|
||||
|
||||
// Return the size of the set (the number of different UNICHAR it holds).
|
||||
@ -604,6 +606,16 @@ class UNICHARSET {
|
||||
return unichars[unichar_id].properties.AnyRangeEmpty();
|
||||
}
|
||||
|
||||
// Returns true if the script of the given id is space delimited.
|
||||
// Returns false for Han and Thai scripts.
|
||||
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const {
|
||||
if (INVALID_UNICHAR_ID == unichar_id) return true;
|
||||
int script_id = get_script(unichar_id);
|
||||
return script_id != han_sid_ && script_id != thai_sid_ &&
|
||||
script_id != hangul_sid_ && script_id != hiragana_sid_ &&
|
||||
script_id != katakana_sid_;
|
||||
}
|
||||
|
||||
// Return the script name of the given unichar.
|
||||
// The returned pointer will always be the same for the same script, it's
|
||||
// managed by unicharset and thus MUST NOT be deleted
|
||||
@ -773,7 +785,7 @@ class UNICHARSET {
|
||||
|
||||
// Returns normalized version of unichar with the given unichar_id.
|
||||
const char *get_normed_unichar(UNICHAR_ID unichar_id) const {
|
||||
if (unichar_id == UNICHAR_SPACE && has_special_codes()) return " ";
|
||||
if (unichar_id == UNICHAR_SPACE) return " ";
|
||||
return unichars[unichar_id].properties.normed.string();
|
||||
}
|
||||
// Returns a vector of UNICHAR_IDs that represent the ids of the normalized
|
||||
@ -835,6 +847,8 @@ class UNICHARSET {
|
||||
int han_sid() const { return han_sid_; }
|
||||
int hiragana_sid() const { return hiragana_sid_; }
|
||||
int katakana_sid() const { return katakana_sid_; }
|
||||
int thai_sid() const { return thai_sid_; }
|
||||
int hangul_sid() const { return hangul_sid_; }
|
||||
int default_sid() const { return default_sid_; }
|
||||
|
||||
// Returns true if the unicharset has the concept of upper/lower case.
|
||||
@ -977,6 +991,8 @@ class UNICHARSET {
|
||||
int han_sid_;
|
||||
int hiragana_sid_;
|
||||
int katakana_sid_;
|
||||
int thai_sid_;
|
||||
int hangul_sid_;
|
||||
// The most frequently occurring script in the charset.
|
||||
int default_sid_;
|
||||
};
|
||||
|
12
configure.ac
12
configure.ac
@ -6,7 +6,7 @@
|
||||
# Initialization
|
||||
# ----------------------------------------
|
||||
AC_PREREQ([2.50])
|
||||
AC_INIT([tesseract], [3.05.00dev], [https://github.com/tesseract-ocr/tesseract/issues])
|
||||
AC_INIT([tesseract], [4.00.00dev], [https://github.com/tesseract-ocr/tesseract/issues])
|
||||
AC_PROG_CXX([g++ clang++])
|
||||
AC_LANG([C++])
|
||||
AC_LANG_COMPILER_REQUIRE
|
||||
@ -18,8 +18,8 @@ AC_PREFIX_DEFAULT([/usr/local])
|
||||
|
||||
# Define date of package, etc. Could be useful in auto-generated
|
||||
# documentation.
|
||||
PACKAGE_YEAR=2015
|
||||
PACKAGE_DATE="07/11"
|
||||
PACKAGE_YEAR=2016
|
||||
PACKAGE_DATE="11/11"
|
||||
|
||||
abs_top_srcdir=`AS_DIRNAME([$0])`
|
||||
gitrev="`git --git-dir=${abs_top_srcdir}/.git --work-tree=${abs_top_srcdir} describe --always --tags`"
|
||||
@ -42,8 +42,8 @@ AC_SUBST([PACKAGE_DATE])
|
||||
GENERIC_LIBRARY_NAME=tesseract
|
||||
|
||||
# Release versioning
|
||||
GENERIC_MAJOR_VERSION=3
|
||||
GENERIC_MINOR_VERSION=4
|
||||
GENERIC_MAJOR_VERSION=4
|
||||
GENERIC_MINOR_VERSION=0
|
||||
GENERIC_MICRO_VERSION=0
|
||||
|
||||
# API version (often = GENERIC_MAJOR_VERSION.GENERIC_MINOR_VERSION)
|
||||
@ -520,6 +520,7 @@ fi
|
||||
# Output files
|
||||
AC_CONFIG_FILES([Makefile tesseract.pc])
|
||||
AC_CONFIG_FILES([api/Makefile])
|
||||
AC_CONFIG_FILES([arch/Makefile])
|
||||
AC_CONFIG_FILES([ccmain/Makefile])
|
||||
AC_CONFIG_FILES([opencl/Makefile])
|
||||
AC_CONFIG_FILES([ccstruct/Makefile])
|
||||
@ -528,6 +529,7 @@ AC_CONFIG_FILES([classify/Makefile])
|
||||
AC_CONFIG_FILES([cube/Makefile])
|
||||
AC_CONFIG_FILES([cutil/Makefile])
|
||||
AC_CONFIG_FILES([dict/Makefile])
|
||||
AC_CONFIG_FILES([lstm/Makefile])
|
||||
AC_CONFIG_FILES([neural_networks/runtime/Makefile])
|
||||
AC_CONFIG_FILES([textord/Makefile])
|
||||
AC_CONFIG_FILES([viewer/Makefile])
|
||||
|
@ -401,7 +401,6 @@ LIST s_adjoin(LIST var_list, void *variable, int_compare compare) {
|
||||
return (push_last (var_list, variable));
|
||||
}
|
||||
|
||||
|
||||
/**********************************************************************
|
||||
* s e a r c h
|
||||
*
|
||||
|
@ -69,14 +69,17 @@ Dawg *DawgLoader::Load() {
|
||||
PermuterType perm_type;
|
||||
switch (tessdata_dawg_type_) {
|
||||
case TESSDATA_PUNC_DAWG:
|
||||
case TESSDATA_LSTM_PUNC_DAWG:
|
||||
dawg_type = DAWG_TYPE_PUNCTUATION;
|
||||
perm_type = PUNC_PERM;
|
||||
break;
|
||||
case TESSDATA_SYSTEM_DAWG:
|
||||
case TESSDATA_LSTM_SYSTEM_DAWG:
|
||||
dawg_type = DAWG_TYPE_WORD;
|
||||
perm_type = SYSTEM_DAWG_PERM;
|
||||
break;
|
||||
case TESSDATA_NUMBER_DAWG:
|
||||
case TESSDATA_LSTM_NUMBER_DAWG:
|
||||
dawg_type = DAWG_TYPE_NUMBER;
|
||||
perm_type = NUMBER_PERM;
|
||||
break;
|
||||
|
@ -202,10 +202,8 @@ DawgCache *Dict::GlobalDawgCache() {
|
||||
return &cache;
|
||||
}
|
||||
|
||||
void Dict::Load(DawgCache *dawg_cache) {
|
||||
STRING name;
|
||||
STRING &lang = getCCUtil()->lang;
|
||||
|
||||
// Sets up ready for a Load or LoadLSTM.
|
||||
void Dict::SetupForLoad(DawgCache *dawg_cache) {
|
||||
if (dawgs_.length() != 0) this->End();
|
||||
|
||||
apostrophe_unichar_id_ = getUnicharset().unichar_to_id(kApostropheSymbol);
|
||||
@ -220,10 +218,10 @@ void Dict::Load(DawgCache *dawg_cache) {
|
||||
dawg_cache_ = new DawgCache();
|
||||
dawg_cache_is_ours_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
TessdataManager &tessdata_manager = getCCUtil()->tessdata_manager;
|
||||
const char *data_file_name = tessdata_manager.GetDataFileName().string();
|
||||
|
||||
// Loads the dawgs needed by Tesseract. Call FinishLoad() after.
|
||||
void Dict::Load(const char *data_file_name, const STRING &lang) {
|
||||
// Load dawgs_.
|
||||
if (load_punc_dawg) {
|
||||
punc_dawg_ = dawg_cache_->GetSquishedDawg(
|
||||
@ -255,6 +253,7 @@ void Dict::Load(DawgCache *dawg_cache) {
|
||||
if (unambig_dawg_) dawgs_ += unambig_dawg_;
|
||||
}
|
||||
|
||||
STRING name;
|
||||
if (((STRING &)user_words_suffix).length() > 0 ||
|
||||
((STRING &)user_words_file).length() > 0) {
|
||||
Trie *trie_ptr = new Trie(DAWG_TYPE_WORD, lang, USER_DAWG_PERM,
|
||||
@ -300,8 +299,33 @@ void Dict::Load(DawgCache *dawg_cache) {
|
||||
// This dawg is temporary and should not be searched by letter_is_ok.
|
||||
pending_words_ = new Trie(DAWG_TYPE_WORD, lang, NO_PERM,
|
||||
getUnicharset().size(), dawg_debug_level);
|
||||
}
|
||||
|
||||
// Construct a list of corresponding successors for each dawg. Each entry i
|
||||
// Loads the dawgs needed by the LSTM model. Call FinishLoad() after.
|
||||
void Dict::LoadLSTM(const char *data_file_name, const STRING &lang) {
|
||||
// Load dawgs_.
|
||||
if (load_punc_dawg) {
|
||||
punc_dawg_ = dawg_cache_->GetSquishedDawg(
|
||||
lang, data_file_name, TESSDATA_LSTM_PUNC_DAWG, dawg_debug_level);
|
||||
if (punc_dawg_) dawgs_ += punc_dawg_;
|
||||
}
|
||||
if (load_system_dawg) {
|
||||
Dawg *system_dawg = dawg_cache_->GetSquishedDawg(
|
||||
lang, data_file_name, TESSDATA_LSTM_SYSTEM_DAWG, dawg_debug_level);
|
||||
if (system_dawg) dawgs_ += system_dawg;
|
||||
}
|
||||
if (load_number_dawg) {
|
||||
Dawg *number_dawg = dawg_cache_->GetSquishedDawg(
|
||||
lang, data_file_name, TESSDATA_LSTM_NUMBER_DAWG, dawg_debug_level);
|
||||
if (number_dawg) dawgs_ += number_dawg;
|
||||
}
|
||||
}
|
||||
|
||||
// Completes the loading process after Load() and/or LoadLSTM().
|
||||
// Returns false if no dictionaries were loaded.
|
||||
bool Dict::FinishLoad() {
|
||||
if (dawgs_.empty()) return false;
|
||||
// Construct a list of corresponding successors for each dawg. Each entry, i,
|
||||
// in the successors_ vector is a vector of integers that represent the
|
||||
// indices into the dawgs_ vector of the successors for dawg i.
|
||||
successors_.reserve(dawgs_.length());
|
||||
@ -316,6 +340,7 @@ void Dict::Load(DawgCache *dawg_cache) {
|
||||
}
|
||||
successors_ += lst;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Dict::End() {
|
||||
@ -368,6 +393,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
// Initialization.
|
||||
PermuterType curr_perm = NO_PERM;
|
||||
dawg_args->updated_dawgs->clear();
|
||||
dawg_args->valid_end = false;
|
||||
|
||||
// Go over the active_dawgs vector and insert DawgPosition records
|
||||
// with the updated ref (an edge with the corresponding unichar id) into
|
||||
@ -405,6 +431,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
dawg_debug_level > 0,
|
||||
"Append transition from punc dawg to current dawgs: ");
|
||||
if (sdawg->permuter() > curr_perm) curr_perm = sdawg->permuter();
|
||||
if (sdawg->end_of_word(dawg_edge) &&
|
||||
punc_dawg->end_of_word(punc_transition_edge))
|
||||
dawg_args->valid_end = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -419,6 +448,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
dawg_debug_level > 0,
|
||||
"Extend punctuation dawg: ");
|
||||
if (PUNC_PERM > curr_perm) curr_perm = PUNC_PERM;
|
||||
if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -436,6 +466,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
dawg_debug_level > 0,
|
||||
"Return to punctuation dawg: ");
|
||||
if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter();
|
||||
if (punc_dawg->end_of_word(punc_edge)) dawg_args->valid_end = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -445,8 +476,8 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
// possible edges, not only for the exact unichar_id, but also
|
||||
// for all its character classes (alpha, digit, etc).
|
||||
if (dawg->type() == DAWG_TYPE_PATTERN) {
|
||||
ProcessPatternEdges(dawg, pos, unichar_id, word_end,
|
||||
dawg_args->updated_dawgs, &curr_perm);
|
||||
ProcessPatternEdges(dawg, pos, unichar_id, word_end, dawg_args,
|
||||
&curr_perm);
|
||||
// There can't be any successors to dawg that is of type
|
||||
// DAWG_TYPE_PATTERN, so we are done examining this DawgPosition.
|
||||
continue;
|
||||
@ -473,6 +504,9 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
continue;
|
||||
}
|
||||
if (dawg->permuter() > curr_perm) curr_perm = dawg->permuter();
|
||||
if (dawg->end_of_word(edge) &&
|
||||
(punc_dawg == NULL || punc_dawg->end_of_word(pos.punc_ref)))
|
||||
dawg_args->valid_end = true;
|
||||
dawg_args->updated_dawgs->add_unique(
|
||||
DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref,
|
||||
false),
|
||||
@ -497,7 +531,7 @@ int Dict::def_letter_is_okay(void* void_dawg_args,
|
||||
|
||||
void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos,
|
||||
UNICHAR_ID unichar_id, bool word_end,
|
||||
DawgPositionVector *updated_dawgs,
|
||||
DawgArgs *dawg_args,
|
||||
PermuterType *curr_perm) const {
|
||||
NODE_REF node = GetStartingNode(dawg, pos.dawg_ref);
|
||||
// Try to find the edge corresponding to the exact unichar_id and to all the
|
||||
@ -520,7 +554,8 @@ void Dict::ProcessPatternEdges(const Dawg *dawg, const DawgPosition &pos,
|
||||
tprintf("Letter found in pattern dawg %d\n", pos.dawg_index);
|
||||
}
|
||||
if (dawg->permuter() > *curr_perm) *curr_perm = dawg->permuter();
|
||||
updated_dawgs->add_unique(
|
||||
if (dawg->end_of_word(edge)) dawg_args->valid_end = true;
|
||||
dawg_args->updated_dawgs->add_unique(
|
||||
DawgPosition(pos.dawg_index, edge, pos.punc_index, pos.punc_ref,
|
||||
pos.back_to_punc),
|
||||
dawg_debug_level > 0,
|
||||
@ -816,5 +851,13 @@ bool Dict::valid_punctuation(const WERD_CHOICE &word) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if the language is space-delimited (not CJ, or T).
|
||||
bool Dict::IsSpaceDelimitedLang() const {
|
||||
const UNICHARSET &u_set = getUnicharset();
|
||||
if (u_set.han_sid() > 0) return false;
|
||||
if (u_set.katakana_sid() > 0) return false;
|
||||
if (u_set.thai_sid() > 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tesseract
|
||||
|
23
dict/dict.h
23
dict/dict.h
@ -23,7 +23,6 @@
|
||||
#include "dawg.h"
|
||||
#include "dawg_cache.h"
|
||||
#include "host.h"
|
||||
#include "oldlist.h"
|
||||
#include "ratngs.h"
|
||||
#include "stopper.h"
|
||||
#include "trie.h"
|
||||
@ -76,11 +75,13 @@ enum XHeightConsistencyEnum {XH_GOOD, XH_SUBNORMAL, XH_INCONSISTENT};
|
||||
|
||||
struct DawgArgs {
|
||||
DawgArgs(DawgPositionVector *d, DawgPositionVector *up, PermuterType p)
|
||||
: active_dawgs(d), updated_dawgs(up), permuter(p) {}
|
||||
: active_dawgs(d), updated_dawgs(up), permuter(p), valid_end(false) {}
|
||||
|
||||
DawgPositionVector *active_dawgs;
|
||||
DawgPositionVector *updated_dawgs;
|
||||
PermuterType permuter;
|
||||
// True if the current position is a valid word end.
|
||||
bool valid_end;
|
||||
};
|
||||
|
||||
class Dict {
|
||||
@ -294,7 +295,15 @@ class Dict {
|
||||
/// Initialize Dict class - load dawgs from [lang].traineddata and
|
||||
/// user-specified wordlist and parttern list.
|
||||
static DawgCache *GlobalDawgCache();
|
||||
void Load(DawgCache *dawg_cache);
|
||||
// Sets up ready for a Load or LoadLSTM.
|
||||
void SetupForLoad(DawgCache *dawg_cache);
|
||||
// Loads the dawgs needed by Tesseract. Call FinishLoad() after.
|
||||
void Load(const char *data_file_name, const STRING &lang);
|
||||
// Loads the dawgs needed by the LSTM model. Call FinishLoad() after.
|
||||
void LoadLSTM(const char *data_file_name, const STRING &lang);
|
||||
// Completes the loading process after Load() and/or LoadLSTM().
|
||||
// Returns false if no dictionaries were loaded.
|
||||
bool FinishLoad();
|
||||
void End();
|
||||
|
||||
// Resets the document dictionary analogous to ResetAdaptiveClassifier.
|
||||
@ -397,9 +406,7 @@ class Dict {
|
||||
}
|
||||
|
||||
inline void SetWildcardID(UNICHAR_ID id) { wildcard_unichar_id_ = id; }
|
||||
inline UNICHAR_ID WildcardID() const {
|
||||
return wildcard_unichar_id_;
|
||||
}
|
||||
inline UNICHAR_ID WildcardID() const { return wildcard_unichar_id_; }
|
||||
/// Return the number of dawgs in the dawgs_ vector.
|
||||
inline int NumDawgs() const { return dawgs_.size(); }
|
||||
/// Return i-th dawg pointer recorded in the dawgs_ vector.
|
||||
@ -436,7 +443,7 @@ class Dict {
|
||||
/// edges were found.
|
||||
void ProcessPatternEdges(const Dawg *dawg, const DawgPosition &info,
|
||||
UNICHAR_ID unichar_id, bool word_end,
|
||||
DawgPositionVector *updated_dawgs,
|
||||
DawgArgs *dawg_args,
|
||||
PermuterType *current_permuter) const;
|
||||
|
||||
/// Read/Write/Access special purpose dawgs which contain words
|
||||
@ -483,6 +490,8 @@ class Dict {
|
||||
inline void SetWordsegRatingAdjustFactor(float f) {
|
||||
wordseg_rating_adjust_factor_ = f;
|
||||
}
|
||||
/// Returns true if the language is space-delimited (not CJ, or T).
|
||||
bool IsSpaceDelimitedLang() const;
|
||||
|
||||
private:
|
||||
/** Private member variables. */
|
||||
|
39
lstm/Makefile.am
Normal file
39
lstm/Makefile.am
Normal file
@ -0,0 +1,39 @@
|
||||
AM_CPPFLAGS += \
|
||||
-I$(top_srcdir)/ccutil -I$(top_srcdir)/cutil -I$(top_srcdir)/ccstruct \
|
||||
-I$(top_srcdir)/arch -I$(top_srcdir)/viewer -I$(top_srcdir)/classify \
|
||||
-I$(top_srcdir)/dict -I$(top_srcdir)/lstm
|
||||
AUTOMAKE_OPTIONS = subdir-objects
|
||||
SUBDIRS =
|
||||
AM_CXXFLAGS = -fopenmp
|
||||
|
||||
if !NO_TESSDATA_PREFIX
|
||||
AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@/
|
||||
endif
|
||||
|
||||
if VISIBILITY
|
||||
AM_CXXFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
|
||||
AM_CPPFLAGS += -DTESS_EXPORTS
|
||||
endif
|
||||
|
||||
include_HEADERS = \
|
||||
convolve.h ctc.h fullyconnected.h functions.h input.h \
|
||||
lstm.h lstmrecognizer.h lstmtrainer.h maxpool.h \
|
||||
networkbuilder.h network.h networkio.h networkscratch.h \
|
||||
parallel.h plumbing.h recodebeam.h reconfig.h reversed.h \
|
||||
series.h static_shape.h stridemap.h tfnetwork.h weightmatrix.h
|
||||
|
||||
noinst_HEADERS =
|
||||
|
||||
if !USING_MULTIPLELIBS
|
||||
noinst_LTLIBRARIES = libtesseract_lstm.la
|
||||
else
|
||||
lib_LTLIBRARIES = libtesseract_lstm.la
|
||||
libtesseract_lstm_la_LDFLAGS = -version-info $(GENERIC_LIBRARY_VERSION)
|
||||
endif
|
||||
|
||||
libtesseract_lstm_la_SOURCES = \
|
||||
convolve.cpp ctc.cpp fullyconnected.cpp functions.cpp input.cpp \
|
||||
lstm.cpp lstmrecognizer.cpp lstmtrainer.cpp maxpool.cpp \
|
||||
networkbuilder.cpp network.cpp networkio.cpp \
|
||||
parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \
|
||||
series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp
|
124
lstm/convolve.cpp
Normal file
124
lstm/convolve.cpp
Normal file
@ -0,0 +1,124 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: convolve.cpp
|
||||
// Description: Convolutional layer that stacks the inputs over its rectangle
|
||||
// and pulls in random data to fill out-of-input inputs.
|
||||
// Output is therefore same size as its input, but deeper.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Mar 18 16:56:06 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "convolve.h"
|
||||
|
||||
#include "networkscratch.h"
|
||||
#include "serialis.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
Convolve::Convolve(const STRING& name, int ni, int half_x, int half_y)
|
||||
: Network(NT_CONVOLVE, name, ni, ni * (2*half_x + 1) * (2*half_y + 1)),
|
||||
half_x_(half_x), half_y_(half_y) {
|
||||
}
|
||||
|
||||
Convolve::~Convolve() {
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Convolve::Serialize(TFile* fp) const {
|
||||
if (!Network::Serialize(fp)) return false;
|
||||
if (fp->FWrite(&half_x_, sizeof(half_x_), 1) != 1) return false;
|
||||
if (fp->FWrite(&half_y_, sizeof(half_y_), 1) != 1) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool Convolve::DeSerialize(bool swap, TFile* fp) {
|
||||
if (fp->FRead(&half_x_, sizeof(half_x_), 1) != 1) return false;
|
||||
if (fp->FRead(&half_y_, sizeof(half_y_), 1) != 1) return false;
|
||||
if (swap) {
|
||||
ReverseN(&half_x_, sizeof(half_x_));
|
||||
ReverseN(&half_y_, sizeof(half_y_));
|
||||
}
|
||||
no_ = ni_ * (2*half_x_ + 1) * (2*half_y_ + 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
void Convolve::Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output) {
|
||||
output->Resize(input, no_);
|
||||
int y_scale = 2 * half_y_ + 1;
|
||||
StrideMap::Index dest_index(output->stride_map());
|
||||
do {
|
||||
// Stack x_scale groups of y_scale * ni_ inputs together.
|
||||
int t = dest_index.t();
|
||||
int out_ix = 0;
|
||||
for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) {
|
||||
StrideMap::Index x_index(dest_index);
|
||||
if (!x_index.AddOffset(x, FD_WIDTH)) {
|
||||
// This x is outside the image.
|
||||
output->Randomize(t, out_ix, y_scale * ni_, randomizer_);
|
||||
} else {
|
||||
int out_iy = out_ix;
|
||||
for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) {
|
||||
StrideMap::Index y_index(x_index);
|
||||
if (!y_index.AddOffset(y, FD_HEIGHT)) {
|
||||
// This y is outside the image.
|
||||
output->Randomize(t, out_iy, ni_, randomizer_);
|
||||
} else {
|
||||
output->CopyTimeStepGeneral(t, out_iy, ni_, input, y_index.t(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} while (dest_index.Increment());
|
||||
if (debug) DisplayForward(*output);
|
||||
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
bool Convolve::Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas) {
|
||||
back_deltas->Resize(fwd_deltas, ni_);
|
||||
NetworkScratch::IO delta_sum;
|
||||
delta_sum.ResizeFloat(fwd_deltas, ni_, scratch);
|
||||
delta_sum->Zero();
|
||||
int y_scale = 2 * half_y_ + 1;
|
||||
StrideMap::Index src_index(fwd_deltas.stride_map());
|
||||
do {
|
||||
// Stack x_scale groups of y_scale * ni_ inputs together.
|
||||
int t = src_index.t();
|
||||
int out_ix = 0;
|
||||
for (int x = -half_x_; x <= half_x_; ++x, out_ix += y_scale * ni_) {
|
||||
StrideMap::Index x_index(src_index);
|
||||
if (x_index.AddOffset(x, FD_WIDTH)) {
|
||||
int out_iy = out_ix;
|
||||
for (int y = -half_y_; y <= half_y_; ++y, out_iy += ni_) {
|
||||
StrideMap::Index y_index(x_index);
|
||||
if (y_index.AddOffset(y, FD_HEIGHT)) {
|
||||
fwd_deltas.AddTimeStepPart(t, out_iy, ni_,
|
||||
delta_sum->f(y_index.t()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} while (src_index.Increment());
|
||||
back_deltas->CopyWithNormalization(*delta_sum, fwd_deltas);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
74
lstm/convolve.h
Normal file
74
lstm/convolve.h
Normal file
@ -0,0 +1,74 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: convolve.h
|
||||
// Description: Convolutional layer that stacks the inputs over its rectangle
|
||||
// and pulls in random data to fill out-of-input inputs.
|
||||
// Output is therefore same size as its input, but deeper.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Mar 18 16:45:34 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_CONVOLVE_H_
|
||||
#define TESSERACT_LSTM_CONVOLVE_H_
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "matrix.h"
|
||||
#include "network.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Makes each time-step deeper by stacking inputs over its rectangle. Does not
|
||||
// affect the size of its input. Achieves this by bringing in random values in
|
||||
// out-of-input areas.
|
||||
class Convolve : public Network {
|
||||
public:
|
||||
// The area of convolution is 2*half_x + 1 by 2*half_y + 1, forcing it to
|
||||
// always be odd, so the center is the current pixel.
|
||||
Convolve(const STRING& name, int ni, int half_x, int half_y);
|
||||
virtual ~Convolve();
|
||||
|
||||
virtual STRING spec() const {
|
||||
STRING spec;
|
||||
spec.add_str_int("C", half_x_ * 2 + 1);
|
||||
spec.add_str_int(",", half_y_ * 2 + 1);
|
||||
return spec;
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
virtual bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
virtual bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual void Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output);
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas);
|
||||
|
||||
protected:
|
||||
// Serialized data.
|
||||
inT32 half_x_;
|
||||
inT32 half_y_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
|
||||
#endif // TESSERACT_LSTM_SUBSAMPLE_H_
|
412
lstm/ctc.cpp
Normal file
412
lstm/ctc.cpp
Normal file
@ -0,0 +1,412 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: ctc.cpp
|
||||
// Description: Slightly improved standard CTC to compute the targets.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 13 15:50:06 PDT 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
#include "ctc.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "host.h"
|
||||
#include "matrix.h"
|
||||
#include "networkio.h"
|
||||
|
||||
#include "network.h"
|
||||
#include "scrollview.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Magic constants that keep CTC stable.
|
||||
// Minimum probability limit for softmax input to ctc_loss.
|
||||
const float CTC::kMinProb_ = 1e-12;
|
||||
// Maximum absolute argument to exp().
|
||||
const double CTC::kMaxExpArg_ = 80.0;
|
||||
// Minimum probability for total prob in time normalization.
|
||||
const double CTC::kMinTotalTimeProb_ = 1e-8;
|
||||
// Minimum probability for total prob in final normalization.
|
||||
const double CTC::kMinTotalFinalProb_ = 1e-6;
|
||||
|
||||
// Builds a target using CTC. Slightly improved as follows:
|
||||
// Includes normalizations and clipping for stability.
|
||||
// labels should be pre-padded with nulls everywhere.
|
||||
// labels can be longer than the time sequence, but the total number of
|
||||
// essential labels (non-null plus nulls between equal labels) must not exceed
|
||||
// the number of timesteps in outputs.
|
||||
// outputs is the output of the network, and should have already been
|
||||
// normalized with NormalizeProbs.
|
||||
// On return targets is filled with the computed targets.
|
||||
// Returns false if there is insufficient time for the labels.
|
||||
/* static */
|
||||
bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
|
||||
const GENERIC_2D_ARRAY<float>& outputs,
|
||||
NetworkIO* targets) {
|
||||
std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs));
|
||||
if (!ctc->ComputeLabelLimits()) {
|
||||
return false; // Not enough time.
|
||||
}
|
||||
// Generate simple targets purely from the truth labels by spreading them
|
||||
// evenly over time.
|
||||
GENERIC_2D_ARRAY<float> simple_targets;
|
||||
ctc->ComputeSimpleTargets(&simple_targets);
|
||||
// Add the simple targets as a starter bias to the network outputs.
|
||||
float bias_fraction = ctc->CalculateBiasFraction();
|
||||
simple_targets *= bias_fraction;
|
||||
ctc->outputs_ += simple_targets;
|
||||
NormalizeProbs(&ctc->outputs_);
|
||||
// Run regular CTC on the biased outputs.
|
||||
// Run forward and backward
|
||||
GENERIC_2D_ARRAY<double> log_alphas, log_betas;
|
||||
ctc->Forward(&log_alphas);
|
||||
ctc->Backward(&log_betas);
|
||||
// Normalize and come out of log space with a clipped softmax over time.
|
||||
log_alphas += log_betas;
|
||||
ctc->NormalizeSequence(&log_alphas);
|
||||
ctc->LabelsToClasses(log_alphas, targets);
|
||||
NormalizeProbs(targets);
|
||||
return true;
|
||||
}
|
||||
|
||||
CTC::CTC(const GenericVector<int>& labels, int null_char,
|
||||
const GENERIC_2D_ARRAY<float>& outputs)
|
||||
: labels_(labels), outputs_(outputs), null_char_(null_char) {
|
||||
num_timesteps_ = outputs.dim1();
|
||||
num_classes_ = outputs.dim2();
|
||||
num_labels_ = labels_.size();
|
||||
}
|
||||
|
||||
// Computes vectors of min and max label index for each timestep, based on
|
||||
// whether skippability of nulls makes it possible to complete a valid path.
|
||||
bool CTC::ComputeLabelLimits() {
|
||||
min_labels_.init_to_size(num_timesteps_, 0);
|
||||
max_labels_.init_to_size(num_timesteps_, 0);
|
||||
int min_u = num_labels_ - 1;
|
||||
if (labels_[min_u] == null_char_) --min_u;
|
||||
for (int t = num_timesteps_ - 1; t >= 0; --t) {
|
||||
min_labels_[t] = min_u;
|
||||
if (min_u > 0) {
|
||||
--min_u;
|
||||
if (labels_[min_u] == null_char_ && min_u > 0 &&
|
||||
labels_[min_u + 1] != labels_[min_u - 1]) {
|
||||
--min_u;
|
||||
}
|
||||
}
|
||||
}
|
||||
int max_u = labels_[0] == null_char_;
|
||||
for (int t = 0; t < num_timesteps_; ++t) {
|
||||
max_labels_[t] = max_u;
|
||||
if (max_labels_[t] < min_labels_[t]) return false; // Not enough room.
|
||||
if (max_u + 1 < num_labels_) {
|
||||
++max_u;
|
||||
if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
|
||||
labels_[max_u + 1] != labels_[max_u - 1]) {
|
||||
++max_u;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Computes targets based purely on the labels by spreading the labels evenly
|
||||
// over the available timesteps.
|
||||
void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
|
||||
// Initialize all targets to zero.
|
||||
targets->Resize(num_timesteps_, num_classes_, 0.0f);
|
||||
GenericVector<float> half_widths;
|
||||
GenericVector<int> means;
|
||||
ComputeWidthsAndMeans(&half_widths, &means);
|
||||
for (int l = 0; l < num_labels_; ++l) {
|
||||
int label = labels_[l];
|
||||
float left_half_width = half_widths[l];
|
||||
float right_half_width = left_half_width;
|
||||
int mean = means[l];
|
||||
if (label == null_char_) {
|
||||
if (!NeededNull(l)) {
|
||||
if ((l > 0 && mean == means[l - 1]) ||
|
||||
(l + 1 < num_labels_ && mean == means[l + 1])) {
|
||||
continue; // Drop overlapping null.
|
||||
}
|
||||
}
|
||||
// Make sure that no space is left unoccupied and that non-nulls always
|
||||
// peak at 1 by stretching nulls to meet their neighbors.
|
||||
if (l > 0) left_half_width = mean - means[l - 1];
|
||||
if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
|
||||
}
|
||||
if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f);
|
||||
for (int offset = 1; offset < left_half_width && mean >= offset; ++offset) {
|
||||
float prob = 1.0f - offset / left_half_width;
|
||||
if (mean - offset < num_timesteps_ &&
|
||||
prob > targets->get(mean - offset, label)) {
|
||||
targets->put(mean - offset, label, prob);
|
||||
}
|
||||
}
|
||||
for (int offset = 1;
|
||||
offset < right_half_width && mean + offset < num_timesteps_;
|
||||
++offset) {
|
||||
float prob = 1.0f - offset / right_half_width;
|
||||
if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) {
|
||||
targets->put(mean + offset, label, prob);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Computes mean positions and half widths of the simple targets by spreading
|
||||
// the labels evenly over the available timesteps.
|
||||
void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
|
||||
GenericVector<int>* means) const {
|
||||
// Count the number of labels of each type, in regexp terms, counts plus
|
||||
// (non-null or necessary null, which must occur at least once) and star
|
||||
// (optional null).
|
||||
int num_plus = 0, num_star = 0;
|
||||
for (int i = 0; i < num_labels_; ++i) {
|
||||
if (labels_[i] != null_char_ || NeededNull(i))
|
||||
++num_plus;
|
||||
else
|
||||
++num_star;
|
||||
}
|
||||
// Compute the size for each type. If there is enough space for everything
|
||||
// to have size>=1, then all are equal, otherwise plus_size=1 and star gets
|
||||
// whatever is left-over.
|
||||
float plus_size = 1.0f, star_size = 0.0f;
|
||||
float total_floating = num_plus + num_star;
|
||||
if (total_floating <= num_timesteps_) {
|
||||
plus_size = star_size = num_timesteps_ / total_floating;
|
||||
} else if (num_star > 0) {
|
||||
star_size = static_cast<float>(num_timesteps_ - num_plus) / num_star;
|
||||
}
|
||||
// Set the width and compute the mean of each.
|
||||
float mean_pos = 0.0f;
|
||||
for (int i = 0; i < num_labels_; ++i) {
|
||||
float half_width;
|
||||
if (labels_[i] != null_char_ || NeededNull(i)) {
|
||||
half_width = plus_size / 2.0f;
|
||||
} else {
|
||||
half_width = star_size / 2.0f;
|
||||
}
|
||||
mean_pos += half_width;
|
||||
means->push_back(static_cast<int>(mean_pos));
|
||||
mean_pos += half_width;
|
||||
half_widths->push_back(half_width);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper returns the index of the highest probability label at timestep t.
|
||||
static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) {
|
||||
int result = 0;
|
||||
int num_classes = outputs.dim2();
|
||||
const float* outputs_t = outputs[t];
|
||||
for (int c = 1; c < num_classes; ++c) {
|
||||
if (outputs_t[c] > outputs_t[result]) result = c;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Calculates and returns a suitable fraction of the simple targets to add
|
||||
// to the network outputs.
|
||||
float CTC::CalculateBiasFraction() {
|
||||
// Compute output labels via basic decoding.
|
||||
GenericVector<int> output_labels;
|
||||
for (int t = 0; t < num_timesteps_; ++t) {
|
||||
int label = BestLabel(outputs_, t);
|
||||
while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
|
||||
if (label != null_char_) output_labels.push_back(label);
|
||||
}
|
||||
// Simple bag of labels error calculation.
|
||||
GenericVector<int> truth_counts(num_classes_, 0);
|
||||
GenericVector<int> output_counts(num_classes_, 0);
|
||||
for (int l = 0; l < num_labels_; ++l) {
|
||||
++truth_counts[labels_[l]];
|
||||
}
|
||||
for (int l = 0; l < output_labels.size(); ++l) {
|
||||
++output_counts[output_labels[l]];
|
||||
}
|
||||
// Count the number of true and false positive non-nulls and truth labels.
|
||||
int true_pos = 0, false_pos = 0, total_labels = 0;
|
||||
for (int c = 0; c < num_classes_; ++c) {
|
||||
if (c == null_char_) continue;
|
||||
int truth_count = truth_counts[c];
|
||||
int ocr_count = output_counts[c];
|
||||
if (truth_count > 0) {
|
||||
total_labels += truth_count;
|
||||
if (ocr_count > truth_count) {
|
||||
true_pos += truth_count;
|
||||
false_pos += ocr_count - truth_count;
|
||||
} else {
|
||||
true_pos += ocr_count;
|
||||
}
|
||||
}
|
||||
// We don't need to count classes that don't exist in the truth as
|
||||
// false positives, because they don't affect CTC at all.
|
||||
}
|
||||
if (total_labels == 0) return 0.0f;
|
||||
return exp(MAX(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
|
||||
}
|
||||
|
||||
// Given ln(x) and ln(y), returns ln(x + y), using:
// ln(x + y) = ln(a) + ln(1 + exp(ln(b) - ln(a))), where ln(a) is the larger
// of the two inputs, to maximize precision (the exp argument is <= 0).
static double LogSumExp(double ln_x, double ln_y) {
  const double larger = ln_x >= ln_y ? ln_x : ln_y;
  const double smaller = ln_x >= ln_y ? ln_y : ln_x;
  return larger + log1p(exp(smaller - larger));
}
|
||||
|
||||
// Runs the forward CTC pass, filling in log_probs.
// log_probs(t, u) accumulates the log probability of all paths through the
// outputs that land on label u at timestep t.
void CTC::Forward(GENERIC_2D_ARRAY<double>* log_probs) const {
  // -MAX_FLOAT32 stands in for log(0): an impossible path.
  log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
  // A path may start at label 0, or at label 1 if label 0 is a null.
  log_probs->put(0, 0, log(outputs_(0, labels_[0])));
  if (labels_[0] == null_char_)
    log_probs->put(0, 1, log(outputs_(0, labels_[1])));
  for (int t = 1; t < num_timesteps_; ++t) {
    const float* outputs_t = outputs_[t];
    // Only labels within [min_labels_[t], max_labels_[t]] (from
    // ComputeLabelLimits) can lie on a complete valid path.
    for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
      // Continuing the same label.
      double log_sum = log_probs->get(t - 1, u);
      // Change from previous label.
      if (u > 0) {
        log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1));
      }
      // Skip the null if allowed (the labels on either side differ).
      if (u >= 2 && labels_[u - 1] == null_char_ &&
          labels_[u] != labels_[u - 2]) {
        log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2));
      }
      // Add in the log prob of the current label.
      double label_prob = outputs_t[labels_[u]];
      log_sum += log(label_prob);
      log_probs->put(t, u, log_sum);
    }
  }
}
|
||||
|
||||
// Runs the backward CTC pass, filling in log_probs.
// log_probs(t, u) accumulates the log probability of all paths from label u
// at timestep t to the end of the sequence, excluding the output at t itself
// (hence the initial values of 0.0 = log(1)).
void CTC::Backward(GENERIC_2D_ARRAY<double>* log_probs) const {
  // -MAX_FLOAT32 stands in for log(0): an impossible path.
  log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
  // A path may end at the last label, or the one before it if the last label
  // is a null.
  log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
  if (labels_[num_labels_ - 1] == null_char_)
    log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
  for (int t = num_timesteps_ - 2; t >= 0; --t) {
    const float* outputs_tp1 = outputs_[t + 1];
    for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
      // Continuing the same label.
      double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]);
      // Change from previous label.
      if (u + 1 < num_labels_) {
        double prev_prob = outputs_tp1[labels_[u + 1]];
        log_sum =
            LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob));
      }
      // Skip the null if allowed (the labels on either side differ).
      if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
          labels_[u] != labels_[u + 2]) {
        double skip_prob = outputs_tp1[labels_[u + 2]];
        log_sum =
            LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob));
      }
      log_probs->put(t, u, log_sum);
    }
  }
}
|
||||
|
||||
// Normalizes and brings probs out of log space with a softmax over time.
|
||||
void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const {
|
||||
double max_logprob = probs->Max();
|
||||
for (int u = 0; u < num_labels_; ++u) {
|
||||
double total = 0.0;
|
||||
for (int t = 0; t < num_timesteps_; ++t) {
|
||||
// Separate impossible path from unlikely probs.
|
||||
double prob = probs->get(t, u);
|
||||
if (prob > -MAX_FLOAT32)
|
||||
prob = ClippedExp(prob - max_logprob);
|
||||
else
|
||||
prob = 0.0;
|
||||
total += prob;
|
||||
probs->put(t, u, prob);
|
||||
}
|
||||
// Note that although this is a probability distribution over time and
|
||||
// therefore should sum to 1, it is important to allow some labels to be
|
||||
// all zero, (or at least tiny) as it is necessary to skip some blanks.
|
||||
if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
|
||||
for (int t = 0; t < num_timesteps_; ++t)
|
||||
probs->put(t, u, probs->get(t, u) / total);
|
||||
}
|
||||
}
|
||||
|
||||
// For each timestep computes the max prob for each class over all
|
||||
// instances of the class in the labels_, and sets the targets to
|
||||
// the max observed prob.
|
||||
void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
|
||||
NetworkIO* targets) const {
|
||||
// For each timestep compute the max prob for each class over all
|
||||
// instances of the class in the labels_.
|
||||
GenericVector<double> class_probs;
|
||||
for (int t = 0; t < num_timesteps_; ++t) {
|
||||
float* targets_t = targets->f(t);
|
||||
class_probs.init_to_size(num_classes_, 0.0);
|
||||
for (int u = 0; u < num_labels_; ++u) {
|
||||
double prob = probs(t, u);
|
||||
// Note that although Graves specifies sum over all labels of the same
|
||||
// class, we need to allow skipped blanks to go to zero, so they don't
|
||||
// interfere with the non-blanks, so max is better than sum.
|
||||
if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
|
||||
// class_probs[labels_[u]] += prob;
|
||||
}
|
||||
int best_class = 0;
|
||||
for (int c = 0; c < num_classes_; ++c) {
|
||||
targets_t[c] = class_probs[c];
|
||||
if (class_probs[c] > class_probs[best_class]) best_class = c;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalizes the probabilities such that no target has a prob below min_prob,
|
||||
// and, provided that the initial total is at least min_total_prob, then all
|
||||
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
|
||||
// probability is thus 1 - (num_classes-1)*min_prob.
|
||||
/* static */
|
||||
void CTC::NormalizeProbs(GENERIC_2D_ARRAY<float>* probs) {
|
||||
int num_timesteps = probs->dim1();
|
||||
int num_classes = probs->dim2();
|
||||
for (int t = 0; t < num_timesteps; ++t) {
|
||||
float* probs_t = (*probs)[t];
|
||||
// Compute the total and clip that to prevent amplification of noise.
|
||||
double total = 0.0;
|
||||
for (int c = 0; c < num_classes; ++c) total += probs_t[c];
|
||||
if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
|
||||
// Compute the increased total as a result of clipping.
|
||||
double increment = 0.0;
|
||||
for (int c = 0; c < num_classes; ++c) {
|
||||
double prob = probs_t[c] / total;
|
||||
if (prob < kMinProb_) increment += kMinProb_ - prob;
|
||||
}
|
||||
// Now normalize with clipping. Any additional clipping is negligible.
|
||||
total += increment;
|
||||
for (int c = 0; c < num_classes; ++c) {
|
||||
float prob = probs_t[c] / total;
|
||||
probs_t[c] = MAX(prob, kMinProb_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if the label at index is a needed null.
|
||||
bool CTC::NeededNull(int index) const {
|
||||
return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ &&
|
||||
labels_[index + 1] == labels_[index - 1];
|
||||
}
|
||||
|
||||
} // namespace tesseract
|
130
lstm/ctc.h
Normal file
130
lstm/ctc.h
Normal file
@ -0,0 +1,130 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: ctc.h
|
||||
// Description: Slightly improved standard CTC to compute the targets.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 13 15:17:06 PDT 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_CTC_H_
|
||||
#define TESSERACT_LSTM_CTC_H_
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "network.h"
|
||||
#include "networkio.h"
|
||||
#include "scrollview.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Class to encapsulate CTC and simple target generation.
class CTC {
 public:
  // Normalizes the probabilities such that no target has a prob below min_prob,
  // and, provided that the initial total is at least min_total_prob, then all
  // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
  // probability is thus 1 - (num_classes-1)*min_prob.
  static void NormalizeProbs(NetworkIO* probs) {
    NormalizeProbs(probs->mutable_float_array());
  }

  // Builds a target using CTC. Slightly improved as follows:
  // Includes normalizations and clipping for stability.
  // labels should be pre-padded with nulls wherever desired, but they don't
  // have to be between all labels. Allows for multi-label codes with no
  // nulls between.
  // labels can be longer than the time sequence, but the total number of
  // essential labels (non-null plus nulls between equal labels) must not exceed
  // the number of timesteps in outputs.
  // outputs is the output of the network, and should have already been
  // normalized with NormalizeProbs.
  // On return targets is filled with the computed targets.
  // Returns false if there is insufficient time for the labels.
  static bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
                                int null_char,
                                const GENERIC_2D_ARRAY<float>& outputs,
                                NetworkIO* targets);

 private:
  // Constructor is private as the instance only holds information specific to
  // the current labels, outputs etc, and is built by the static function.
  CTC(const GenericVector<int>& labels, int null_char,
      const GENERIC_2D_ARRAY<float>& outputs);

  // Computes vectors of min and max label index for each timestep, based on
  // whether skippability of nulls makes it possible to complete a valid path.
  bool ComputeLabelLimits();
  // Computes targets based purely on the labels by spreading the labels evenly
  // over the available timesteps.
  void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const;
  // Computes mean positions and half widths of the simple targets by spreading
  // the labels evenly over the available timesteps.
  void ComputeWidthsAndMeans(GenericVector<float>* half_widths,
                             GenericVector<int>* means) const;
  // Calculates and returns a suitable fraction of the simple targets to add
  // to the network outputs.
  float CalculateBiasFraction();
  // Runs the forward CTC pass, filling in log_probs.
  void Forward(GENERIC_2D_ARRAY<double>* log_probs) const;
  // Runs the backward CTC pass, filling in log_probs.
  void Backward(GENERIC_2D_ARRAY<double>* log_probs) const;
  // Normalizes and brings probs out of log space with a softmax over time.
  void NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const;
  // For each timestep computes the max prob for each class over all
  // instances of the class in the labels_, and sets the targets to
  // the max observed prob.
  void LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
                       NetworkIO* targets) const;
  // Normalizes the probabilities such that no target has a prob below min_prob,
  // and, provided that the initial total is at least min_total_prob, then all
  // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
  // probability is thus 1 - (num_classes-1)*min_prob.
  static void NormalizeProbs(GENERIC_2D_ARRAY<float>* probs);
  // Returns true if the label at index is a needed null (a null between two
  // identical labels).
  bool NeededNull(int index) const;
  // Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
  // underflow.
  static double ClippedExp(double x) {
    if (x < -kMaxExpArg_) return exp(-kMaxExpArg_);
    if (x > kMaxExpArg_) return exp(kMaxExpArg_);
    return exp(x);
  }

  // Minimum probability limit for softmax input to ctc_loss.
  static const float kMinProb_;
  // Maximum absolute argument to exp().
  static const double kMaxExpArg_;
  // Minimum probability for total prob in time normalization.
  static const double kMinTotalTimeProb_;
  // Minimum probability for total prob in final normalization.
  static const double kMinTotalFinalProb_;

  // The truth label indices that are to be matched to outputs_.
  const GenericVector<int>& labels_;
  // The network outputs.
  GENERIC_2D_ARRAY<float> outputs_;
  // The null or "blank" label.
  int null_char_;
  // Number of timesteps in outputs_.
  int num_timesteps_;
  // Number of classes in outputs_.
  int num_classes_;
  // Number of labels in labels_.
  int num_labels_;
  // Min and max valid label indices for each timestep.
  GenericVector<int> min_labels_;
  GenericVector<int> max_labels_;
};
|
||||
|
||||
} // namespace tesseract
|
||||
|
||||
#endif // TESSERACT_LSTM_CTC_H_
|
285
lstm/fullyconnected.cpp
Normal file
285
lstm/fullyconnected.cpp
Normal file
@ -0,0 +1,285 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: fullyconnected.cpp
|
||||
// Description: Simple feed-forward layer with various non-linearities.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Feb 26 14:49:15 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "fullyconnected.h"
|
||||
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "functions.h"
|
||||
#include "networkscratch.h"
|
||||
|
||||
// Number of threads to use for parallel calculation of Forward and Backward.
// Used as the num_threads() value of the OpenMP loops below; with OpenMP
// disabled the loops run single-threaded.
const int kNumThreads = 4;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Constructs a fully-connected layer named `name` of non-linearity `type`
// with `ni` inputs and `no` outputs. Weights are not set up here; they come
// from InitWeights (training) or DeSerialize (loading).
FullyConnected::FullyConnected(const STRING& name, int ni, int no,
                               NetworkType type)
  : Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
}
|
||||
|
||||
// Destructor. Nothing to free explicitly: all members are value types that
// release their own storage.
FullyConnected::~FullyConnected() {
}
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
|
||||
LossType loss_type = LT_NONE;
|
||||
if (type_ == NT_SOFTMAX)
|
||||
loss_type = LT_CTC;
|
||||
else if (type_ == NT_SOFTMAX_NO_CTC)
|
||||
loss_type = LT_SOFTMAX;
|
||||
else if (type_ == NT_LOGISTIC)
|
||||
loss_type = LT_LOGISTIC;
|
||||
StaticShape result(input_shape);
|
||||
result.set_depth(no_);
|
||||
result.set_loss_type(loss_type);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Returns the number of weights allocated.
int FullyConnected::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  // ni_ + 1: one extra input column for the bias on each output.
  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADA_GRAD),
                                           range, randomizer);
  return num_weights_;
}
|
||||
|
||||
// Converts a float network to an int network. All state lives in the weight
// matrix, so the conversion is delegated to it.
void FullyConnected::ConvertToInt() {
  weights_.ConvertToInt();
}
|
||||
|
||||
// Provides debug output on the weights, labeled with this layer's name.
void FullyConnected::DebugWeights() {
  weights_.Debug2D(name_.string());
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool FullyConnected::Serialize(TFile* fp) const {
|
||||
if (!Network::Serialize(fp)) return false;
|
||||
if (!weights_.Serialize(training_, fp)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool FullyConnected::DeSerialize(bool swap, TFile* fp) {
|
||||
if (!weights_.DeSerialize(training_, swap, fp)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void FullyConnected::Forward(bool debug, const NetworkIO& input,
                             const TransposedArray* input_transpose,
                             NetworkScratch* scratch, NetworkIO* output) {
  int width = input.Width();
  // Softmax output is always float; other types follow the input mode.
  if (type_ == NT_SOFTMAX)
    output->ResizeFloat(input, no_);
  else
    output->Resize(input, no_);
  SetupForward(input, input_transpose);
  // One scratch buffer per thread so the OpenMP workers don't share storage.
  GenericVector<NetworkScratch::FloatVec> temp_lines;
  temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  GenericVector<NetworkScratch::FloatVec> curr_input;
  curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  for (int i = 0; i < temp_lines.size(); ++i) {
    temp_lines[i].Init(no_, scratch);
    curr_input[i].Init(ni_, scratch);
  }
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = 0;
#endif
    double* temp_line = temp_lines[thread_id];
    const double* d_input = NULL;
    const inT8* i_input = NULL;
    // Feed the timestep in whichever representation (int or float) the
    // input arrived in.
    if (input.int_mode()) {
      i_input = input.i(t);
    } else {
      input.ReadTimeStep(t, curr_input[thread_id]);
      d_input = curr_input[thread_id];
    }
    ForwardTimeStep(d_input, i_input, t, temp_line);
    output->WriteTimeStep(t, temp_line);
    // Save activations for the derivative computation in BackwardTimeStep.
    if (training() && type_ != NT_SOFTMAX) {
      acts_.CopyTimeStepFrom(t, *output, t);
    }
  }
  // Zero all the elements that are in the padding around images that allows
  // multiple different-sized images to exist in a single array.
  // acts_ is only used if this is not a softmax op.
  if (training() && type_ != NT_SOFTMAX) {
    acts_.ZeroInvalidElements();
  }
  output->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
  tprintf("F Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}
|
||||
|
||||
// Components of Forward so FullyConnected can be reused inside LSTM.
// Records the input mode and, when training, prepares the activation and
// transposed-input buffers that Backward/FinishBackward will read.
void FullyConnected::SetupForward(const NetworkIO& input,
                                  const TransposedArray* input_transpose) {
  // Softmax output is always float, so save the input type.
  int_mode_ = input.int_mode();
  if (training()) {
    acts_.Resize(input, no_);
    // Source_ is a transposed copy of input. It isn't needed if provided.
    external_source_ = input_transpose;
    if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width());
  }
}
|
||||
|
||||
void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
|
||||
int t, double* output_line) {
|
||||
// input is copied to source_ line-by-line for cache coherency.
|
||||
if (training() && external_source_ == NULL && d_input != NULL)
|
||||
source_t_.WriteStrided(t, d_input);
|
||||
if (d_input != NULL)
|
||||
weights_.MatrixDotVector(d_input, output_line);
|
||||
else
|
||||
weights_.MatrixDotVector(i_input, output_line);
|
||||
if (type_ == NT_TANH) {
|
||||
FuncInplace<GFunc>(no_, output_line);
|
||||
} else if (type_ == NT_LOGISTIC) {
|
||||
FuncInplace<FFunc>(no_, output_line);
|
||||
} else if (type_ == NT_POSCLIP) {
|
||||
FuncInplace<ClipFFunc>(no_, output_line);
|
||||
} else if (type_ == NT_SYMCLIP) {
|
||||
FuncInplace<ClipGFunc>(no_, output_line);
|
||||
} else if (type_ == NT_RELU) {
|
||||
FuncInplace<Relu>(no_, output_line);
|
||||
} else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
|
||||
SoftmaxInPlace(no_, output_line);
|
||||
} else if (type_ != NT_LINEAR) {
|
||||
ASSERT_HOST("Invalid fully-connected type!" == NULL);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// Returns true if deltas were propagated to back_deltas for a lower layer.
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
                              NetworkScratch* scratch,
                              NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->Resize(fwd_deltas, ni_);
  // One error buffer per thread for the OpenMP loop below.
  GenericVector<NetworkScratch::FloatVec> errors;
  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  for (int i = 0; i < errors.size(); ++i) errors[i].Init(no_, scratch);
  // Backprop buffers are only needed if a lower layer consumes the deltas.
  GenericVector<NetworkScratch::FloatVec> temp_backprops;
  if (needs_to_backprop_) {
    temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
    for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
  }
  int width = fwd_deltas.Width();
  // Transposed errors accumulate here for the weight-gradient computation
  // in FinishBackward.
  NetworkScratch::GradientStore errors_t;
  errors_t.Init(no_, width, scratch);
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    int thread_id = 0;
#endif
    double* backprop = NULL;
    if (needs_to_backprop_) backprop = temp_backprops[thread_id];
    double* curr_errors = errors[thread_id];
    BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
    if (backprop != NULL) {
      back_deltas->WriteTimeStep(t, backprop);
    }
  }
  FinishBackward(*errors_t.get());
  if (needs_to_backprop_) {
    back_deltas->ZeroInvalidElements();
    back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
#if DEBUG_DETAIL > 0
    tprintf("F Backprop:%s\n", name_.string());
    back_deltas->Print(10);
#endif
    return true;
  }
  return false;  // No point going further back.
}
|
||||
|
||||
void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t,
|
||||
double* curr_errors,
|
||||
TransposedArray* errors_t,
|
||||
double* backprop) {
|
||||
if (type_ == NT_TANH)
|
||||
acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
|
||||
else if (type_ == NT_LOGISTIC)
|
||||
acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
|
||||
else if (type_ == NT_POSCLIP)
|
||||
acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
|
||||
else if (type_ == NT_SYMCLIP)
|
||||
acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
|
||||
else if (type_ == NT_RELU)
|
||||
acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
|
||||
else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
|
||||
type_ == NT_LINEAR)
|
||||
fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
|
||||
else
|
||||
ASSERT_HOST("Invalid fully-connected type!" == NULL);
|
||||
// Generate backprop only if needed by the lower layer.
|
||||
if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop);
|
||||
errors_t->WriteStrided(t, curr_errors);
|
||||
}
|
||||
|
||||
void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
|
||||
if (external_source_ == NULL)
|
||||
weights_.SumOuterTransposed(errors_t, source_t_, true);
|
||||
else
|
||||
weights_.SumOuterTransposed(errors_t, *external_source_, true);
|
||||
}
|
||||
|
||||
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
// All trainable state lives in the weight matrix, so just delegate.
void FullyConnected::Update(float learning_rate, float momentum,
                            int num_samples) {
  weights_.Update(learning_rate, momentum, num_samples);
}
|
||||
|
||||
// Sums the products of weight updates in *this and other, splitting into
|
||||
// positive (same direction) in *same and negative (different direction) in
|
||||
// *changed.
|
||||
void FullyConnected::CountAlternators(const Network& other, double* same,
|
||||
double* changed) const {
|
||||
ASSERT_HOST(other.type() == type_);
|
||||
const FullyConnected* fc = reinterpret_cast<const FullyConnected*>(&other);
|
||||
weights_.CountAlternators(fc->weights_, same, changed);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
130
lstm/fullyconnected.h
Normal file
130
lstm/fullyconnected.h
Normal file
@ -0,0 +1,130 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: fullyconnected.h
|
||||
// Description: Simple feed-forward layer with various non-linearities.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Feb 26 14:46:06 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_FULLYCONNECTED_H_
|
||||
#define TESSERACT_LSTM_FULLYCONNECTED_H_
|
||||
|
||||
#include "network.h"
|
||||
#include "networkscratch.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// C++ Implementation of the Softmax (output) class from lstm.py.
// Despite the name in the comment above, this is a general fully-connected
// layer: the output non-linearity is selected by the NetworkType passed to
// the constructor (tanh, logistic, relu, linear, clipped variants, softmax).
class FullyConnected : public Network {
 public:
  // name: debug/display name; ni: number of inputs; no: number of outputs;
  // type: selects the output non-linearity.
  FullyConnected(const STRING& name, int ni, int no, NetworkType type);
  virtual ~FullyConnected();

  // Returns the shape output from the network given an input shape (which may
  // be partially unknown ie zero).
  virtual StaticShape OutputShape(const StaticShape& input_shape) const;

  // Returns a textual network spec: a 2-letter code for the non-linearity
  // followed by the number of outputs.
  // NOTE(review): NT_LOGISTIC and NT_SYMCLIP both emit "Fs" here, so the
  // spec string is ambiguous between those two types — confirm intended.
  virtual STRING spec() const {
    STRING spec;
    if (type_ == NT_TANH)
      spec.add_str_int("Ft", no_);
    else if (type_ == NT_LOGISTIC)
      spec.add_str_int("Fs", no_);
    else if (type_ == NT_RELU)
      spec.add_str_int("Fr", no_);
    else if (type_ == NT_LINEAR)
      spec.add_str_int("Fl", no_);
    else if (type_ == NT_POSCLIP)
      spec.add_str_int("Fp", no_);
    else if (type_ == NT_SYMCLIP)
      spec.add_str_int("Fs", no_);
    else if (type_ == NT_SOFTMAX)
      spec.add_str_int("Fc", no_);
    else
      spec.add_str_int("Fm", no_);
    return spec;
  }

  // Changes the type to the given type. Used to commute a softmax to a
  // non-output type for adding on other networks.
  void ChangeType(NetworkType type) {
    type_ = type;
  }

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand* randomizer);

  // Converts a float network to an int network.
  virtual void ConvertToInt();

  // Provides debug output on the weights.
  virtual void DebugWeights();

  // Writes to the given file. Returns false in case of error.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  virtual bool DeSerialize(bool swap, TFile* fp);

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output);
  // Components of Forward so FullyConnected can be reused inside LSTM.
  void SetupForward(const NetworkIO& input,
                    const TransposedArray* input_transpose);
  // Forward-propagates a single timestep t. Exactly one of d_input (float
  // mode) or i_input (int mode) is expected to be non-null.
  void ForwardTimeStep(const double* d_input, const inT8* i_input, int t,
                       double* output_line);

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas);
  // Components of Backward so FullyConnected can be reused inside LSTM.
  void BackwardTimeStep(const NetworkIO& fwd_deltas, int t, double* curr_errors,
                        TransposedArray* errors_t, double* backprop);
  void FinishBackward(const TransposedArray& errors_t);

  // Updates the weights using the given learning rate and momentum.
  // num_samples is the quotient to be used in the adagrad computation iff
  // use_ada_grad_ is true.
  virtual void Update(float learning_rate, float momentum, int num_samples);
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators(const Network& other, double* same,
                                double* changed) const;

 protected:
  // Weight arrays of size [no, ni + 1].
  WeightMatrix weights_;
  // Transposed copy of input used during training of size [ni, width].
  TransposedArray source_t_;
  // Pointer to transposed input stored elsewhere. If not null, this is used
  // in preference to calculating the transpose and storing it in source_t_.
  const TransposedArray* external_source_;
  // Activations from forward pass of size [width, no].
  NetworkIO acts_;
  // Memory of the integer mode input to forward as softmax always outputs
  // float, so the information is otherwise lost.
  bool int_mode_;
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
|
||||
|
||||
#endif // TESSERACT_LSTM_FULLYCONNECTED_H_
|
26
lstm/functions.cpp
Normal file
26
lstm/functions.cpp
Normal file
@ -0,0 +1,26 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: functions.cpp
|
||||
// Description: Static initialize-on-first-use non-linearity functions.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Jul 17 14:02:59 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "functions.h"

namespace tesseract {

// Lookup tables for the non-linearities declared in functions.h.
// As file-scope arrays they are zero-initialized; entries are filled in
// lazily on first use by Tanh() and Logistic() in functions.h.
double TanhTable[kTableSize];
double LogisticTable[kTableSize];

} // namespace tesseract.
|
249
lstm/functions.h
Normal file
249
lstm/functions.h
Normal file
@ -0,0 +1,249 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: functions.h
|
||||
// Description: Collection of function-objects used by the network layers.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Jun 20 10:45:37 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_FUNCTIONS_H_
|
||||
#define TESSERACT_LSTM_FUNCTIONS_H_
|
||||
|
||||
#include <cmath>
#include <cstring>

#include "helpers.h"
#include "tprintf.h"
|
||||
|
||||
// Setting this to 1 or more causes massive dumps of debug data: weights,
|
||||
// updates, internal calculations etc, and reduces the number of test iterations
|
||||
// to a small number, so outputs can be diffed.
|
||||
#define DEBUG_DETAIL 0
|
||||
#if DEBUG_DETAIL > 0
|
||||
#undef _OPENMP // Disable open mp to get the outputs in sync.
|
||||
#endif
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Size of static tables.
|
||||
const int kTableSize = 4096;
|
||||
// Scale factor for float arg to int index.
|
||||
const double kScaleFactor = 256.0;
|
||||
|
||||
extern double TanhTable[];
|
||||
extern double LogisticTable[];
|
||||
|
||||
// Non-linearity (sigmoid) functions with cache tables and clipping.
|
||||
// tanh(x) via a lazily-populated lookup table with linear interpolation
// between adjacent entries. Negative inputs use the odd symmetry of tanh;
// inputs beyond the table range saturate to 1.0.
// NOTE(review): the lazy table fill below is not guarded by any lock, so
// this assumes single-threaded first use (or a benign race on identical
// writes) — confirm the callers' threading model.
inline double Tanh(double x) {
  if (x < 0.0) return -Tanh(-x);
  // Saturate above the largest tabulated argument.
  if (x >= (kTableSize - 1) / kScaleFactor) return 1.0;
  x *= kScaleFactor;
  int index = static_cast<int>(floor(x));
  // 0.0 marks an unfilled entry (index 0 is legitimately tanh(0) == 0).
  if (TanhTable[index] == 0.0 && index > 0) {
    // Generate the entry.
    TanhTable[index] = tanh(index / kScaleFactor);
  }
  if (index == kTableSize - 1) return TanhTable[kTableSize - 1];
  if (TanhTable[index + 1] == 0.0) {
    // Generate the entry.
    TanhTable[index + 1] = tanh((index + 1) / kScaleFactor);
  }
  // Linear interpolation between the two neighbouring table entries.
  double offset = x - index;
  return TanhTable[index] * (1.0 - offset) + TanhTable[index + 1] * offset;
}
|
||||
|
||||
// Logistic sigmoid 1/(1+exp(-x)) via a lazily-populated lookup table with
// linear interpolation. Negative inputs use the symmetry
// logistic(-x) == 1 - logistic(x); large inputs saturate to 1.0.
// NOTE(review): like Tanh above, the lazy fill is unsynchronized — confirm
// the callers' threading model. (Logistic values are never 0.0, so the
// 0.0-means-unfilled sentinel is safe for every index here.)
inline double Logistic(double x) {
  if (x < 0.0) return 1.0 - Logistic(-x);
  // Saturate above the largest tabulated argument.
  if (x >= (kTableSize - 1) / kScaleFactor) return 1.0;
  x *= kScaleFactor;
  int index = static_cast<int>(floor(x));
  if (LogisticTable[index] == 0.0) {
    // Generate the entry.
    LogisticTable[index] = 1.0 / (1.0 + exp(-index / kScaleFactor));
  }
  if (index == kTableSize - 1) return LogisticTable[kTableSize - 1];
  if (LogisticTable[index + 1] == 0.0) {
    // Generate the entry.
    LogisticTable[index + 1] = 1.0 / (1.0 + exp(-(index + 1) / kScaleFactor));
  }
  // Linear interpolation between the two neighbouring table entries.
  double offset = x - index;
  return LogisticTable[index] * (1.0 - offset) +
         LogisticTable[index + 1] * offset;
}
|
||||
|
||||
// Non-linearity (sigmoid) functions and their derivatives.
// Convention: *Func functors map a raw input x to an activation; *Prime
// functors compute the derivative. FPrime/GPrime/ClipFPrime/ClipGPrime/
// ReluPrime take the already-computed OUTPUT y of the matching *Func;
// HPrime is the exception — it re-applies Tanh to its argument, i.e. it
// computes d/dy tanh(y) from the raw (pre-squashed) value.

// Logistic sigmoid.
struct FFunc {
  inline double operator()(double x) const { return Logistic(x); }
};
// Derivative of the logistic, expressed in terms of its output y.
struct FPrime {
  inline double operator()(double y) const { return y * (1.0 - y); }
};
// Hard (piecewise-linear) approximation of the logistic: clip to [0, 1].
struct ClipFFunc {
  inline double operator()(double x) const {
    if (x <= 0.0) return 0.0;
    if (x >= 1.0) return 1.0;
    return x;
  }
};
// Derivative of ClipFFunc: 1 inside the linear region, 0 at saturation.
struct ClipFPrime {
  inline double operator()(double y) const {
    return 0.0 < y && y < 1.0 ? 1.0 : 0.0;
  }
};
// Rectified linear unit: max(0, x).
struct Relu {
  inline double operator()(double x) const {
    if (x <= 0.0) return 0.0;
    return x;
  }
};
// Derivative of Relu in terms of its output y.
struct ReluPrime {
  inline double operator()(double y) const { return 0.0 < y ? 1.0 : 0.0; }
};
// tanh non-linearity.
struct GFunc {
  inline double operator()(double x) const { return Tanh(x); }
};
// Derivative of tanh, expressed in terms of its output y.
struct GPrime {
  inline double operator()(double y) const { return 1.0 - y * y; }
};
// Hard (piecewise-linear) approximation of tanh: clip to [-1, 1].
struct ClipGFunc {
  inline double operator()(double x) const {
    if (x <= -1.0) return -1.0;
    if (x >= 1.0) return 1.0;
    return x;
  }
};
// Derivative of ClipGFunc: 1 inside the linear region, 0 at saturation.
struct ClipGPrime {
  inline double operator()(double y) const {
    return -1.0 < y && y < 1.0 ? 1.0 : 0.0;
  }
};
// tanh applied to the (cell) state.
struct HFunc {
  inline double operator()(double x) const { return Tanh(x); }
};
// Derivative of tanh taken at the RAW argument y (re-applies Tanh first),
// unlike GPrime which expects the tanh output.
struct HPrime {
  inline double operator()(double y) const {
    double u = Tanh(y);
    return 1.0 - u * u;
  }
};
// Constant 1 — identity derivative / no-op gate.
struct UnityFunc {
  inline double operator()(double x) const { return 1.0; }
};
// Identity pass-through (linear activation).
struct IdentityFunc {
  inline double operator()(double x) const { return x; }
};
|
||||
|
||||
// Applies the default-constructed functor Func element-wise to the n values
// of inout, overwriting each element with its transformed value.
template <class Func>
inline void FuncInplace(int n, double* inout) {
  const Func apply;
  for (int idx = 0; idx < n; ++idx) {
    inout[idx] = apply(inout[idx]);
  }
}
|
||||
// For each of the n elements, applies the functor Func to u[i], scales the
// result by v[i], and stores the product in out[i].
template <class Func>
inline void FuncMultiply(const double* u, const double* v, int n, double* out) {
  const Func apply;
  for (int idx = 0; idx < n; ++idx) {
    out[idx] = apply(u[idx]) * v[idx];
  }
}
|
||||
// Applies the Softmax function in-place to inout, of size n.
// Subtracts the maximum element before exponentiating for numerical
// stability, then normalizes so the outputs sum to 1 (unless the total is
// not positive, in which case the clipped exp values are left as-is).
// ClipToRange comes from helpers.h.
template <typename T>
inline void SoftmaxInPlace(int n, T* inout) {
  if (n <= 0) return;
  // A limit on the negative range input to exp to guarantee non-zero output.
  const T kMaxSoftmaxActivation = 86.0f;

  // Find the maximum element so all exp arguments are <= 0.
  T max_output = inout[0];
  for (int i = 1; i < n; i++) {
    T output = inout[i];
    if (output > max_output) max_output = output;
  }
  // Exponentiate the (clipped) shifted values and accumulate the total.
  T prob_total = 0.0;
  for (int i = 0; i < n; i++) {
    T prob = inout[i] - max_output;
    prob = exp(ClipToRange(prob, -kMaxSoftmaxActivation, static_cast<T>(0)));
    prob_total += prob;
    inout[i] = prob;
  }
  // Normalize to a probability distribution.
  if (prob_total > 0.0) {
    for (int i = 0; i < n; i++) inout[i] /= prob_total;
  }
}
|
||||
|
||||
// Copies the first n elements of src into dest.
// The two ranges are expected not to overlap.
inline void CopyVector(int n, const double* src, double* dest) {
  for (int idx = 0; idx < n; ++idx) {
    dest[idx] = src[idx];
  }
}
|
||||
|
||||
// Accumulates the first n elements of src into dest (dest[i] += src[i]).
inline void AccumulateVector(int n, const double* src, double* dest) {
  for (int idx = 0; idx < n; ++idx) {
    dest[idx] += src[idx];
  }
}
|
||||
|
||||
// Scales inout element-wise by src over the first n elements
// (inout[i] *= src[i]).
inline void MultiplyVectorsInPlace(int n, const double* src, double* inout) {
  for (int idx = 0; idx < n; ++idx) {
    inout[idx] *= src[idx];
  }
}
|
||||
|
||||
// Adds the element-wise product of u and v into out over the first n
// elements (out[i] += u[i] * v[i]).
inline void MultiplyAccumulate(int n, const double* u, const double* v,
                               double* out) {
  for (int idx = 0; idx < n; ++idx) {
    out[idx] += u[idx] * v[idx];
  }
}
|
||||
|
||||
// Writes the element-wise sum of the five n-vectors v1..v5 into sum.
// Addition is performed left-to-right (v1 + v2 + v3 + v4 + v5).
inline void SumVectors(int n, const double* v1, const double* v2,
                       const double* v3, const double* v4, const double* v5,
                       double* sum) {
  for (int idx = 0; idx < n; ++idx) {
    double total = v1[idx] + v2[idx];
    total += v3[idx];
    total += v4[idx];
    total += v5[idx];
    sum[idx] = total;
  }
}
|
||||
|
||||
// Sets the given n-vector vec to 0.
// NOTE: zeroes the raw bytes with memset, so T should be a trivial
// numeric/pointer type whose all-zero-bytes pattern is its zero value
// (true for the double/int vectors used in this library).
template <typename T>
inline void ZeroVector(int n, T* vec) {
  memset(vec, 0, n * sizeof(*vec));
}
|
||||
|
||||
// Clips the given vector vec, of size n to [lower, upper].
// ClipToRange is provided by helpers.h.
template <typename T>
inline void ClipVector(int n, T lower, T upper, T* vec) {
  for (int i = 0; i < n; ++i) vec[i] = ClipToRange(vec[i], lower, upper);
}
|
||||
|
||||
// Replaces the first nf elements of vec with the binary encoding
// (least-significant bit first) of the index of the maximum element of the
// full n-vector. Does nothing if nf is non-positive or larger than n.
inline void CodeInBinary(int n, int nf, double* vec) {
  if (nf <= 0 || n < nf) return;
  // Locate the argmax of the whole vector.
  int best_index = 0;
  double best_value = vec[0];
  for (int i = 1; i < n; ++i) {
    if (vec[i] > best_value) {
      best_value = vec[i];
      best_index = i;
    }
  }
  // Emit the bits of best_index, LSB first, as 0.0/1.0 values.
  for (int bit = 0; bit < nf; ++bit) {
    vec[bit] = ((best_index >> bit) & 1) ? 1.0 : 0.0;
  }
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_FUNCTIONS_H_
|
154
lstm/input.cpp
Normal file
154
lstm/input.cpp
Normal file
@ -0,0 +1,154 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: input.cpp
|
||||
// Description: Input layer class for neural network implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu Mar 13 09:10:34 PDT 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "input.h"
|
||||
|
||||
#include "allheaders.h"
|
||||
#include "imagedata.h"
|
||||
#include "pageres.h"
|
||||
#include "scrollview.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Constructs an Input layer with explicit input/output counts and a default
// (empty) shape.
Input::Input(const STRING& name, int ni, int no)
  : Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}
// Constructs an Input layer from a StaticShape; NumInputs is the shape
// height (or the depth for 1-high inputs, where depth carries the feature
// count).
Input::Input(const STRING& name, const StaticShape& shape)
  : Network(NT_INPUT, name, shape.height(), shape.depth()),
    shape_(shape),
    cached_x_scale_(1) {
  if (shape.height() == 1) ni_ = shape.depth();
}

Input::~Input() {
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
// Writes the base Network state then the raw bytes of shape_.
bool Input::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (fp->FWrite(&shape_, sizeof(shape_), 1) != 1) return false;
  return true;
}

// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// NOTE: the swap is not actually performed on shape_ yet (see TODO below),
// so cross-endian files would be read incorrectly here.
bool Input::DeSerialize(bool swap, TFile* fp) {
  if (fp->FRead(&shape_, sizeof(shape_), 1) != 1) return false;
  // TODO(rays) swaps!
  return true;
}
|
||||
|
||||
// Returns an integer reduction factor that the network applies to the
// time sequence. Assumes that any 2-d is already eliminated. Used for
// scaling bounding boxes of truth data.
// An Input layer never reduces the sequence, so this is always 1.
int Input::XScaleFactor() const {
  return 1;
}

// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
void Input::CacheXScaleFactor(int factor) {
  cached_x_scale_ = factor;
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
// The input layer is a pass-through: output is a copy of input; debug,
// input_transpose and scratch are unused here.
void Input::Forward(bool debug, const NetworkIO& input,
                    const TransposedArray* input_transpose,
                    NetworkScratch* scratch, NetworkIO* output) {
  *output = input;
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// There is nothing before the input layer to propagate deltas to, so this
// is an error if reached: logs and returns false.
bool Input::Backward(bool debug, const NetworkIO& fwd_deltas,
                     NetworkScratch* scratch,
                     NetworkIO* back_deltas) {
  tprintf("Input::Backward should not be called!!\n");
  return false;
}
|
||||
|
||||
// Creates and returns a Pix of appropriate size for the network from the
// image_data. If non-null, *image_scale returns the image scale factor used.
// Returns nullptr on error.
// The randomizer argument is currently unused here (nullptr is passed to
// PreScale for the box list).
/* static */
Pix* Input::PrepareLSTMInputs(const ImageData& image_data,
                              const Network* network, int min_width,
                              TRand* randomizer, float* image_scale) {
  // Note that NumInputs() is defined as input image height.
  int target_height = network->NumInputs();
  int width, height;
  // Scale the image to the network's input height, reporting the result size.
  Pix* pix =
      image_data.PreScale(target_height, image_scale, &width, &height, nullptr);
  if (pix == nullptr) {
    tprintf("Bad pix from ImageData!\n");
    return nullptr;
  }
  // Reject images too narrow to be useful after scaling.
  if (width <= min_width) {
    tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width,
            height, min_width);
    pixDestroy(&pix);
    return nullptr;
  }
  return pix;
}
|
||||
|
||||
// Converts the given pix to a NetworkIO of height and depth appropriate to the
|
||||
// given StaticShape:
|
||||
// If depth == 3, convert to 24 bit color, otherwise normalized grey.
|
||||
// Scale to target height, if the shape's height is > 1, or its depth if the
|
||||
// height == 1. If height == 0 then no scaling.
|
||||
// NOTE: It isn't safe for multiple threads to call this on the same pix.
|
||||
/* static */
|
||||
void Input::PreparePixInput(const StaticShape& shape, const Pix* pix,
|
||||
TRand* randomizer, NetworkIO* input) {
|
||||
bool color = shape.depth() == 3;
|
||||
Pix* var_pix = const_cast<Pix*>(pix);
|
||||
int depth = pixGetDepth(var_pix);
|
||||
Pix* normed_pix = nullptr;
|
||||
// On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
|
||||
// colormap, so we just have to deal with depth conversion here.
|
||||
if (color) {
|
||||
// Force RGB.
|
||||
if (depth == 32)
|
||||
normed_pix = pixClone(var_pix);
|
||||
else
|
||||
normed_pix = pixConvertTo32(var_pix);
|
||||
} else {
|
||||
// Convert non-8-bit images to 8 bit.
|
||||
if (depth == 8)
|
||||
normed_pix = pixClone(var_pix);
|
||||
else
|
||||
normed_pix = pixConvertTo8(var_pix, false);
|
||||
}
|
||||
int width = pixGetWidth(normed_pix);
|
||||
int height = pixGetHeight(normed_pix);
|
||||
int target_height = shape.height();
|
||||
if (target_height == 1) target_height = shape.depth();
|
||||
if (target_height == 0) target_height = height;
|
||||
float im_factor = static_cast<float>(target_height) / height;
|
||||
if (im_factor != 1.0f) {
|
||||
// Get the scaled image.
|
||||
Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
|
||||
pixDestroy(&normed_pix);
|
||||
normed_pix = scaled_pix;
|
||||
}
|
||||
input->FromPix(shape, normed_pix, randomizer);
|
||||
pixDestroy(&normed_pix);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
107
lstm/input.h
Normal file
107
lstm/input.h
Normal file
@ -0,0 +1,107 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: input.h
|
||||
// Description: Input layer class for neural network implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu Mar 13 08:56:26 PDT 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_INPUT_H_
|
||||
#define TESSERACT_LSTM_INPUT_H_
|
||||
|
||||
#include "network.h"
|
||||
|
||||
class ScrollView;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Input layer: presents image data to the rest of the network. It performs
// no computation of its own (Forward is a pass-through); its main job is
// holding the expected input StaticShape and providing the static helpers
// that convert a Pix into a NetworkIO.
class Input : public Network {
 public:
  Input(const STRING& name, int ni, int no);
  Input(const STRING& name, const StaticShape& shape);
  virtual ~Input();

  // Returns the spec string "batch,height,width,depth" for this input shape.
  virtual STRING spec() const {
    STRING spec;
    spec.add_str_int("", shape_.batch());
    spec.add_str_int(",", shape_.height());
    spec.add_str_int(",", shape_.width());
    spec.add_str_int(",", shape_.depth());
    return spec;
  }

  // Returns the required shape input to the network.
  virtual StaticShape InputShape() const { return shape_; }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown ie zero). The input layer always reports its own
  // configured shape.
  virtual StaticShape OutputShape(const StaticShape& input_shape) const {
    return shape_;
  }
  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(bool swap, TFile* fp);

  // Returns an integer reduction factor that the network applies to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const;

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor(int factor);

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output);

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas);
  // Creates and returns a Pix of appropriate size for the network from the
  // image_data. If non-null, *image_scale returns the image scale factor used.
  // Returns nullptr on error.
  /* static */
  static Pix* PrepareLSTMInputs(const ImageData& image_data,
                                const Network* network, int min_width,
                                TRand* randomizer, float* image_scale);
  // Converts the given pix to a NetworkIO of height and depth appropriate to
  // the given StaticShape:
  // If depth == 3, convert to 24 bit color, otherwise normalized grey.
  // Scale to target height, if the shape's height is > 1, or its depth if the
  // height == 1. If height == 0 then no scaling.
  // NOTE: It isn't safe for multiple threads to call this on the same pix.
  static void PreparePixInput(const StaticShape& shape, const Pix* pix,
                              TRand* randomizer, NetworkIO* input);

 private:
  // Input shape determines how images are dealt with.
  StaticShape shape_;
  // Cached total network x scale factor for scaling bounding boxes.
  int cached_x_scale_;
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_INPUT_H_
|
||||
|
710
lstm/lstm.cpp
Normal file
710
lstm/lstm.cpp
Normal file
@ -0,0 +1,710 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: lstm.cpp
|
||||
// Description: Long-term-short-term-memory Recurrent neural network.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed May 01 17:43:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "lstm.h"
|
||||
|
||||
#ifndef ANDROID_BUILD
|
||||
#include <omp.h>
|
||||
#endif
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "fullyconnected.h"
|
||||
#include "functions.h"
|
||||
#include "networkscratch.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
// Macros for openmp code if it is available, otherwise empty macros.
|
||||
#ifdef _OPENMP
|
||||
#define PARALLEL_IF_OPENMP(__num_threads) \
|
||||
PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
|
||||
PRAGMA(omp sections nowait) { \
|
||||
PRAGMA(omp section) {
|
||||
#define SECTION_IF_OPENMP \
|
||||
} \
|
||||
PRAGMA(omp section) \
|
||||
{
|
||||
|
||||
#define END_PARALLEL_IF_OPENMP \
|
||||
} \
|
||||
} /* end of sections */ \
|
||||
} /* end of parallel section */
|
||||
|
||||
// Define the portable PRAGMA macro.
|
||||
#ifdef _MSC_VER // Different _Pragma
|
||||
#define PRAGMA(x) __pragma(x)
|
||||
#else
|
||||
#define PRAGMA(x) _Pragma(#x)
|
||||
#endif // _MSC_VER
|
||||
|
||||
#else // _OPENMP
|
||||
#define PARALLEL_IF_OPENMP(__num_threads)
|
||||
#define SECTION_IF_OPENMP
|
||||
#define END_PARALLEL_IF_OPENMP
|
||||
#endif // _OPENMP
|
||||
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
// Use a double literal (was 1.0f) for consistency with kStateClip; the
// value is unchanged.
const double kErrClip = 1.0;
|
||||
|
||||
// Constructs an LSTM layer.
// ni: inputs per timestep; ns: state size; no: outputs per timestep;
// two_dimensional: enables the extra vertical recurrence; type: one of the
// NT_LSTM* variants (anything else is fatal).
// na_ is the full per-gate input width: ni + recurrent ns (+ ns again for
// 2-D) + nf_ softmax-feedback inputs.
LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
           NetworkType type)
  : Network(type, name, ni, no),
    na_(ni + ns),
    ns_(ns),
    nf_(0),
    is_2d_(two_dimensional),
    softmax_(NULL) {
  if (two_dimensional) na_ += ns_;
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    // Feedback width: full output count, or its binary-encoded width.
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_)));
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is invalid type of LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}
|
||||
|
||||
// Deletes the optional owned softmax output layer (may be NULL).
LSTM::~LSTM() { delete softmax_; }
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
// Depth becomes no_; a summary LSTM collapses the width to 1; any attached
// softmax further transforms the shape.
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
  if (softmax_ != NULL) return softmax_->OutputShape(result);
  return result;
}
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Returns the total number of weights, including the softmax layer's if
// present. The GFS (vertical forget gate) weights are skipped in 1-D mode.
int LSTM::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    // Each gate has a [ns_, na_ + 1] weight matrix (+1 for the bias input).
    num_weights_ += gate_weights_[w].InitWeightsFloat(
        ns_, na_ + 1, TestFlag(NF_ADA_GRAD), range, randomizer);
  }
  if (softmax_ != NULL) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}
|
||||
|
||||
// Converts a float network to an int network.
// Converts every gate weight matrix (skipping GFS in 1-D mode) and the
// softmax layer if present.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != NULL) {
    softmax_->ConvertToInt();
  }
}
|
||||
|
||||
// Provides debug output on the weights: dumps each gate's weight matrix
// (skipping GFS in 1-D mode) and the softmax layer's weights if present.
// (The previous comment here, "Sets up the network for training using the
// given weight_range", was a copy-paste from InitWeights and did not
// describe this function.)
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    STRING msg = name_;
    msg.add_str_int(" Gate weights ", w);
    gate_weights_[w].Debug2D(msg.string());
  }
  if (softmax_ != NULL) {
    softmax_->DebugWeights();
  }
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
// Writes the base Network state, na_, each gate's weights (skipping GFS in
// 1-D mode), and finally the softmax layer if present. DeSerialize must
// read the same sequence.
bool LSTM::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].Serialize(training_, fp)) return false;
  }
  if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
  return true;
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Mirrors Serialize: reads na_, then the gate weights, then the optional
// softmax. nf_ is recomputed from type_, and ns_/is_2d_ are recovered from
// the dimensions of the first (CI) weight matrix.
bool LSTM::DeSerialize(bool swap, TFile* fp) {
  if (fp->FRead(&na_, sizeof(na_), 1) != 1) return false;
  if (swap) ReverseN(&na_, sizeof(na_));
  // Feedback width depends only on the (already-read) network type.
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = IntCastRounded(ceil(log2(no_)));
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    // Note: Is2D() is false until the CI matrix below establishes is_2d_,
    // so GFS is only read once the dimensions prove the network is 2-D.
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false;
    if (w == CI) {
      // Recover ns_ from the CI matrix, and infer 2-D mode from whether
      // na_ accounts for a second recurrent ns_ term.
      ns_ = gate_weights_[CI].NumOutputs();
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ =
        reinterpret_cast<FullyConnected*>(Network::CreateFromFile(swap, fp));
    if (softmax_ == NULL) return false;
  } else {
    softmax_ = NULL;
  }
  return true;
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
// Walks the input one timestep at a time in StrideMap order, assembling a
// padded source vector [input | softmax feedback | recurrent output |
// 2-D recurrent output], applying the 4 (or 5 in 2-D) gates, and writing
// either the raw cell output, the softmax output, or (NT_LSTM_SUMMARY)
// only the last output of each row.
void LSTM::Forward(bool debug, const NetworkIO& input,
                   const TransposedArray* input_transpose,
                   NetworkScratch* scratch, NetworkIO* output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != NULL)
    output->ResizeFloat(input, no_);
  else if (type_ == NT_LSTM_SUMMARY)
    output->ResizeXTo1(input, no_);
  else
    output->Resize(input, no_);
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.init_to_size(buf_width, NetworkScratch::FloatVec());
    outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != NULL) {
    softmax_output.Init(no_, scratch);
    ZeroVector<double>(no_, softmax_output);
    if (input.int_mode()) int_output.Resize2d(true, 1, ns_, scratch);
    softmax_->SetupForward(input, NULL);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != NULL) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D())
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel outside the t loop, a single around
    // the t-loop and then tasks in place of the sections is a *lot* slower.
    // Cell inputs.
    if (source_.int_mode())
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    else
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode())
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    else
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode())
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    else
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode())
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      else
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode())
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    else
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      // which_fg_ records per-cell which gate won (1 = 1-D, 2 = 2-D), so
      // Backward can route the state error to the same gate.
      inT8* which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const double* stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
    if (training_) {
      // Save the gate node values (needed by Backward).
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (training_) state_.WriteTimeStep(t, curr_state);
    if (softmax_ != NULL) {
      if (input.int_mode()) {
        int_output->WriteTimeStep(0, curr_output);
        softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        // Re-encode the softmax output to binary for next-step feedback.
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<double>(ns_, curr_state);
      ZeroVector<double>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.string());
  source_.Print(10);
  tprintf("State:%s\n", name_.string());
  state_.Print(10);
  tprintf("Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// Walks the timesteps in reverse (dest_index.Decrement()), mirroring
// Forward: it reconstructs the error at the cell output, pushes it back
// through the gates, accumulates transposed gate errors for the weight
// update (SumOuterTransposed at the end), and writes the error w.r.t. the
// original input into back_deltas. Returns true iff this layer needs to
// backprop further (back_deltas is then valid).
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
                    NetworkScratch* scratch,
                    NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<double>(ns_, curr_stateerr);
  ZeroVector<double>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<double>(ns_, stateerr[t]);
      ZeroVector<double>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w)
    sourceerr_temps[w].Init(na_, scratch);
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w) {
    gate_errors_t[w].Init(ns_, width, scratch);
  }
  // Used only if softmax_ != NULL.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != NULL) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  // Error clip is looser in 2-D, where two recurrence paths feed each cell.
  double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.string());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<double>(na_, curr_sourceerr);
      ZeroVector<double>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      // Summary mode produced output only at the last x of each row.
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<double>(ns_, outputerr);
      }
    } else if (softmax_ == NULL) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
                                 softmax_errors_t.get(), outputerr);
    }
    // Add recurrent error contributions from the next 1-D and 2-D steps.
    if (!at_last_x)
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    if (down_pos >= 0)
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float* next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      // Route the state error to whichever forget gate won the max-pool
      // during Forward (which_fg_: 1 = 1-D gate, 2 = 2-D gate).
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
      }
      if (down_pos >= 0) {
        const float* right_node_gfs = node_values_[GFS].f(down_pos);
        const double* right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
                                    curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
                curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
                                           curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
                                           curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates. No previous state exists at t == 0.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
                                         sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
                                         sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
                                         gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
               sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
               curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.string(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != NULL) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  if (needs_to_backprop_) {
    // Normalize the inputerr in back_deltas.
    back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
    return true;
  }
  return false;
}
|
||||
|
||||
// Updates the weights using the given learning rate and momentum.
|
||||
// num_samples is the quotient to be used in the adagrad computation iff
|
||||
// use_ada_grad_ is true.
|
||||
void LSTM::Update(float learning_rate, float momentum, int num_samples) {
|
||||
#if DEBUG_DETAIL > 3
|
||||
PrintW();
|
||||
#endif
|
||||
for (int w = 0; w < WT_COUNT; ++w) {
|
||||
if (w == GFS && !Is2D()) continue;
|
||||
gate_weights_[w].Update(learning_rate, momentum, num_samples);
|
||||
}
|
||||
if (softmax_ != NULL) {
|
||||
softmax_->Update(learning_rate, momentum, num_samples);
|
||||
}
|
||||
#if DEBUG_DETAIL > 3
|
||||
PrintDW();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Sums the products of weight updates in *this and other, splitting into
|
||||
// positive (same direction) in *same and negative (different direction) in
|
||||
// *changed.
|
||||
void LSTM::CountAlternators(const Network& other, double* same,
|
||||
double* changed) const {
|
||||
ASSERT_HOST(other.type() == type_);
|
||||
const LSTM* lstm = reinterpret_cast<const LSTM*>(&other);
|
||||
for (int w = 0; w < WT_COUNT; ++w) {
|
||||
if (w == GFS && !Is2D()) continue;
|
||||
gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
|
||||
}
|
||||
if (softmax_ != NULL) {
|
||||
softmax_->CountAlternators(*lstm->softmax_, same, changed);
|
||||
}
|
||||
}
|
||||
|
||||
// Prints the weights for debug purposes.
|
||||
void LSTM::PrintW() {
|
||||
tprintf("Weight state:%s\n", name_.string());
|
||||
for (int w = 0; w < WT_COUNT; ++w) {
|
||||
if (w == GFS && !Is2D()) continue;
|
||||
tprintf("Gate %d, inputs\n", w);
|
||||
for (int i = 0; i < ni_; ++i) {
|
||||
tprintf("Row %d:", i);
|
||||
for (int s = 0; s < ns_; ++s)
|
||||
tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
|
||||
tprintf("\n");
|
||||
}
|
||||
tprintf("Gate %d, outputs\n", w);
|
||||
for (int i = ni_; i < ni_ + ns_; ++i) {
|
||||
tprintf("Row %d:", i - ni_);
|
||||
for (int s = 0; s < ns_; ++s)
|
||||
tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
|
||||
tprintf("\n");
|
||||
}
|
||||
tprintf("Gate %d, bias\n", w);
|
||||
for (int s = 0; s < ns_; ++s)
|
||||
tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
|
||||
tprintf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Prints the weight deltas for debug purposes.
|
||||
void LSTM::PrintDW() {
|
||||
tprintf("Delta state:%s\n", name_.string());
|
||||
for (int w = 0; w < WT_COUNT; ++w) {
|
||||
if (w == GFS && !Is2D()) continue;
|
||||
tprintf("Gate %d, inputs\n", w);
|
||||
for (int i = 0; i < ni_; ++i) {
|
||||
tprintf("Row %d:", i);
|
||||
for (int s = 0; s < ns_; ++s)
|
||||
tprintf(" %g", gate_weights_[w].GetDW(s, i));
|
||||
tprintf("\n");
|
||||
}
|
||||
tprintf("Gate %d, outputs\n", w);
|
||||
for (int i = ni_; i < ni_ + ns_; ++i) {
|
||||
tprintf("Row %d:", i - ni_);
|
||||
for (int s = 0; s < ns_; ++s)
|
||||
tprintf(" %g", gate_weights_[w].GetDW(s, i));
|
||||
tprintf("\n");
|
||||
}
|
||||
tprintf("Gate %d, bias\n", w);
|
||||
for (int s = 0; s < ns_; ++s)
|
||||
tprintf(" %g", gate_weights_[w].GetDW(s, na_));
|
||||
tprintf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Resizes forward data to cope with an input image of the given width.
|
||||
void LSTM::ResizeForward(const NetworkIO& input) {
|
||||
source_.Resize(input, na_);
|
||||
which_fg_.ResizeNoInit(input.Width(), ns_);
|
||||
if (training_) {
|
||||
state_.ResizeFloat(input, ns_);
|
||||
for (int w = 0; w < WT_COUNT; ++w) {
|
||||
if (w == GFS && !Is2D()) continue;
|
||||
node_values_[w].ResizeFloat(input, ns_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace tesseract.
|
157
lstm/lstm.h
Normal file
157
lstm/lstm.h
Normal file
@ -0,0 +1,157 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: lstm.h
|
||||
// Description: Long-term-short-term-memory Recurrent neural network.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed May 01 17:33:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_LSTM_H_
|
||||
#define TESSERACT_LSTM_LSTM_H_
|
||||
|
||||
#include "network.h"
|
||||
#include "fullyconnected.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// C++ Implementation of the LSTM class from lstm.py.
class LSTM : public Network {
 public:
  // Enum for the different weights in LSTM, to reduce some of the I/O and
  // setup code to loops. The elements of the enum correspond to elements of an
  // array of WeightMatrix or a corresponding array of NetworkIO.
  enum WeightType {
    CI,   // Cell Inputs.
    GI,   // Gate at the input.
    GF1,  // Forget gate at the memory (1-d or looking back 1 timestep).
    GO,   // Gate at the output.
    GFS,  // Forget gate at the memory, looking back in the other dimension.

    WT_COUNT  // Number of WeightTypes.
  };

  // Constructor for NT_LSTM (regular 1 or 2-d LSTM), NT_LSTM_SOFTMAX (LSTM with
  // additional softmax layer included and fed back into the input at the next
  // timestep), or NT_LSTM_SOFTMAX_ENCODED (as LSTM_SOFTMAX, but the feedback
  // is binary encoded instead of categorical) only.
  // 2-d and bidi softmax LSTMs are not rejected, but are impossible to build
  // in the conventional way because the output feedback both forwards and
  // backwards in time does become impossible.
  LSTM(const STRING& name, int num_inputs, int num_states, int num_outputs,
       bool two_dimensional, NetworkType type);
  virtual ~LSTM();

  // Returns the shape output from the network given an input shape (which may
  // be partially unknown ie zero).
  virtual StaticShape OutputShape(const StaticShape& input_shape) const;

  // Returns a short specification string encoding the layer type and the
  // number of internal states (plus the softmax spec when present).
  virtual STRING spec() const {
    STRING spec;
    if (type_ == NT_LSTM)
      spec.add_str_int("Lfx", ns_);
    else if (type_ == NT_LSTM_SUMMARY)
      spec.add_str_int("Lfxs", ns_);
    else if (type_ == NT_LSTM_SOFTMAX)
      spec.add_str_int("LS", ns_);
    else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
      spec.add_str_int("LE", ns_);
    if (softmax_ != NULL) spec += softmax_->spec();
    return spec;
  }

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  virtual int InitWeights(float range, TRand* randomizer);

  // Converts a float network to an int network.
  virtual void ConvertToInt();

  // Provides debug output on the weights.
  virtual void DebugWeights();

  // Writes to the given file. Returns false in case of error.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  virtual bool DeSerialize(bool swap, TFile* fp);

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output);

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas);
  // Updates the weights using the given learning rate and momentum.
  // num_samples is the quotient to be used in the adagrad computation iff
  // use_ada_grad_ is true.
  virtual void Update(float learning_rate, float momentum, int num_samples);
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators(const Network& other, double* same,
                                double* changed) const;
  // Prints the weights for debug purposes.
  void PrintW();
  // Prints the weight deltas for debug purposes.
  void PrintDW();

  // Returns true if this is a 2-d lstm.
  bool Is2D() const {
    return is_2d_;
  }

 private:
  // Resizes forward data to cope with an input image of the given width.
  void ResizeForward(const NetworkIO& input);

 private:
  // Size of padded input to weight matrices = ni_ + no_ for 1-D operation
  // and ni_ + 2 * no_ for 2-D operation. Note that there is a phantom 1 input
  // for the bias that makes the weight matrices of size [na + 1][no].
  inT32 na_;
  // Number of internal states. Equal to no_ except for a softmax LSTM.
  // ns_ is NOT serialized, but is calculated from gate_weights_.
  inT32 ns_;
  // Number of additional feedback states. The softmax types feed back
  // additional output information on top of the ns_ internal states.
  // In the case of a binary-coded (EMBEDDED) softmax, nf_ < no_.
  inT32 nf_;
  // Flag indicating 2-D operation.
  bool is_2d_;

  // Gate weight arrays of size [na + 1, no].
  WeightMatrix gate_weights_[WT_COUNT];
  // Used only if this is a softmax LSTM.
  FullyConnected* softmax_;
  // Input padded with previous output of size [width, na].
  NetworkIO source_;
  // Internal state used during forward operation, of size [width, ns].
  NetworkIO state_;
  // State of the 2-d maxpool, generated during forward, used during backward.
  GENERIC_2D_ARRAY<inT8> which_fg_;
  // Internal state saved from forward, but used only during backward.
  NetworkIO node_values_[WT_COUNT];
  // Preserved input stride_map used for Backward when NT_LSTM_SQUASHED.
  // NOTE(review): Backward appears to use input_map_ unconditionally --
  // confirm whether this comment's restriction is accurate.
  StrideMap input_map_;
  // Width (total number of timesteps) of the last Forward input.
  int input_width_;
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
|
||||
#endif // TESSERACT_LSTM_LSTM_H_
|
816
lstm/lstmrecognizer.cpp
Normal file
816
lstm/lstmrecognizer.cpp
Normal file
@ -0,0 +1,816 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: lstmrecognizer.cpp
|
||||
// Description: Top-level line recognizer class for LSTM-based networks.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 10:59:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "lstmrecognizer.h"
|
||||
|
||||
#include "allheaders.h"
|
||||
#include "callcpp.h"
|
||||
#include "dict.h"
|
||||
#include "genericheap.h"
|
||||
#include "helpers.h"
|
||||
#include "imagedata.h"
|
||||
#include "input.h"
|
||||
#include "lstm.h"
|
||||
#include "normalis.h"
|
||||
#include "pageres.h"
|
||||
#include "ratngs.h"
|
||||
#include "recodebeam.h"
|
||||
#include "scrollview.h"
|
||||
#include "shapetable.h"
|
||||
#include "statistc.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Tunable constants for beam search and dictionary weighting.
// NOTE(review): values look empirically tuned -- confirm before changing.
// Max number of blob choices to return in any given position.
const int kMaxChoices = 4;
// Default ratio between dict and non-dict words.
const double kDictRatio = 2.25;
// Default certainty offset to give the dictionary a chance.
const double kCertOffset = -0.085;
|
||||
|
||||
// Constructs an empty recognizer: no network, no dictionary, no beam search.
// The real content is supplied later by DeSerialize() or by training code.
LSTMRecognizer::LSTMRecognizer()
    : network_(NULL),
      training_flags_(0),
      training_iteration_(0),
      sample_iteration_(0),
      null_char_(UNICHAR_BROKEN),
      weight_range_(0.0f),
      learning_rate_(0.0f),
      momentum_(0.0f),
      dict_(NULL),
      search_(NULL),
      debug_win_(NULL) {}
|
||||
|
||||
// Releases the owned network, dictionary and beam search (any may be NULL;
// delete on NULL is a no-op).
LSTMRecognizer::~LSTMRecognizer() {
  delete network_;
  delete dict_;
  delete search_;
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
// Field order must stay in sync with DeSerialize: network, unicharset,
// network spec string, the scalar training state, and finally the recoder
// (written only when recoding is enabled).
bool LSTMRecognizer::Serialize(TFile* fp) const {
  if (!network_->Serialize(fp)) return false;
  if (!GetUnicharset().save_to_file(fp)) return false;
  if (!network_str_.Serialize(fp)) return false;
  if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1)
    return false;
  if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1)
    return false;
  if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
    return false;
  if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false;
  if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false;
  if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
  if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false;
  if (IsRecoding() && !recoder_.Serialize(fp)) return false;
  return true;
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
// Reads fields in the exact order written by Serialize, replacing any
// previously loaded network.
bool LSTMRecognizer::DeSerialize(bool swap, TFile* fp) {
  delete network_;
  network_ = Network::CreateFromFile(swap, fp);
  if (network_ == NULL) return false;
  if (!ccutil_.unicharset.load_from_file(fp, false)) return false;
  if (!network_str_.DeSerialize(swap, fp)) return false;
  if (fp->FRead(&training_flags_, sizeof(training_flags_), 1) != 1)
    return false;
  if (fp->FRead(&training_iteration_, sizeof(training_iteration_), 1) != 1)
    return false;
  if (fp->FRead(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
    return false;
  if (fp->FRead(&null_char_, sizeof(null_char_), 1) != 1) return false;
  if (fp->FRead(&weight_range_, sizeof(weight_range_), 1) != 1) return false;
  if (fp->FRead(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
  if (fp->FRead(&momentum_, sizeof(momentum_), 1) != 1) return false;
  if (IsRecoding()) {
    if (!recoder_.DeSerialize(swap, fp)) return false;
    // Sanity-check the recoder: the space character must round-trip.
    RecodedCharID code;
    recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
    if (code(0) != UNICHAR_SPACE) {
      tprintf("Space was garbled in recoding!!\n");
      return false;
    }
  }
  // TODO(rays) swaps!
  network_->SetRandomizer(&randomizer_);
  network_->CacheXScaleFactor(network_->XScaleFactor());
  return true;
}
|
||||
|
||||
// Loads the dictionary if possible from the traineddata file.
|
||||
// Prints a warning message, and returns false but otherwise fails silently
|
||||
// and continues to work without it if loading fails.
|
||||
// Note that dictionary load is independent from DeSerialize, but dependent
|
||||
// on the unicharset matching. This enables training to deserialize a model
|
||||
// from checkpoint or restore without having to go back and reload the
|
||||
// dictionary.
|
||||
bool LSTMRecognizer::LoadDictionary(const char* data_file_name,
|
||||
const char* lang) {
|
||||
delete dict_;
|
||||
dict_ = new Dict(&ccutil_);
|
||||
dict_->SetupForLoad(Dict::GlobalDawgCache());
|
||||
dict_->LoadLSTM(data_file_name, lang);
|
||||
if (dict_->FinishLoad()) return true; // Success.
|
||||
tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n",
|
||||
lang);
|
||||
delete dict_;
|
||||
dict_ = NULL;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recognizes the line image, contained within image_data, returning the
// ratings matrix and matching box_word for each WERD_RES in the output.
// invert: also try the inverted image if the normal polarity scores poorly.
// worst_dict_cert: worst certainty at which the beam search still accepts a
//     dictionary word (recoding path only).
// use_alternates: generate alternative segmentations/choices (non-recoding
//     path); also selects a nonzero label threshold below.
// target_unicharset: if non-NULL, attempt to translate output unichar-ids.
// line_box: position of this line in the page image, used to place words.
// score_ratio, one_word: forwarded to WordsFromOutputs.
void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                   bool debug, double worst_dict_cert,
                                   bool use_alternates,
                                   const UNICHARSET* target_unicharset,
                                   const TBOX& line_box, float score_ratio,
                                   bool one_word,
                                   PointerVector<WERD_RES>* words) {
  NetworkIO outputs;
  // A positive threshold enables threshold-based labeling; 0 selects CTC.
  float label_threshold = use_alternates ? 0.75f : 0.0f;
  float scale_factor;
  NetworkIO inputs;
  // Run the network forward pass (the overload below); bail out silently if
  // the image cannot be prepared.
  if (!RecognizeLine(image_data, invert, debug, false, label_threshold,
                     &scale_factor, &inputs, &outputs))
    return;
  if (IsRecoding()) {
    // Compressed-unicharset path: decode with the beam search.
    if (search_ == NULL) {
      search_ =
          new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
    }
    search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
    search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
                                    &GetUnicharset(), words);
  } else {
    // Plain path: convert outputs to labels, then labels to words.
    GenericVector<int> label_coords;
    GenericVector<int> labels;
    LabelsFromOutputs(outputs, label_threshold, &labels, &label_coords);
    WordsFromOutputs(outputs, labels, label_coords, line_box, debug,
                     use_alternates, one_word, score_ratio, scale_factor,
                     target_unicharset, words);
  }
}
|
||||
|
||||
// Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
// corresponding to the network output in outputs, labels, label_coords.
// one_word generates a single word output, that may include spaces inside.
// use_alternates generates alternative BLOB_CHOICEs and segmentation paths.
// If not NULL, we attempt to translate the output to target_unicharset, but do
// not guarantee success, due to mismatches. In that case the output words are
// marked with our UNICHARSET, not the caller's.
// NOTE(review): label_coords is taken by value, which copies the vector on
// every call; making it a const reference requires changing the matching
// declaration in the header in step -- TODO confirm and fix together.
void LSTMRecognizer::WordsFromOutputs(
    const NetworkIO& outputs, const GenericVector<int>& labels,
    const GenericVector<int> label_coords, const TBOX& line_box, bool debug,
    bool use_alternates, bool one_word, float score_ratio, float scale_factor,
    const UNICHARSET* target_unicharset, PointerVector<WERD_RES>* words) {
  // Convert labels to unichar-ids.
  int word_end = 0;
  float prev_space_cert = 0.0f;
  // Each iteration consumes one word: i is the first label of the word and
  // word_end is one past its last label.
  for (int i = 0; i < labels.size(); i = word_end) {
    word_end = i + 1;
    // Skip leading nulls and spaces between words.
    if (labels[i] == null_char_ || labels[i] == UNICHAR_SPACE) {
      continue;
    }
    float space_cert = 0.0f;
    if (one_word) {
      word_end = labels.size();
    } else {
      // Find the end of the word at the first null_char_ that leads to the
      // first UNICHAR_SPACE.
      while (word_end < labels.size() && labels[word_end] != UNICHAR_SPACE)
        ++word_end;
      if (word_end < labels.size()) {
        // Measure how certain the network is of the following space.
        float rating;
        outputs.ScoresOverRange(label_coords[word_end],
                                label_coords[word_end] + 1, UNICHAR_SPACE,
                                null_char_, &rating, &space_cert);
      }
      // Trim trailing nulls off the word.
      while (word_end > i && labels[word_end - 1] == null_char_) --word_end;
    }
    ASSERT_HOST(word_end > i);
    // Create a WERD_RES for the output word.
    if (debug)
      tprintf("Creating word from outputs over [%d,%d)\n", i, word_end);
    WERD_RES* word =
        WordFromOutput(line_box, outputs, i, word_end, score_ratio,
                       MIN(prev_space_cert, space_cert), debug,
                       use_alternates && !SimpleTextOutput(), target_unicharset,
                       labels, label_coords, scale_factor);
    if (word == NULL && target_unicharset != NULL) {
      // Unicharset translation failed - use decoder_ instead, and disable
      // the segmentation search on output, as it won't understand the encoding.
      word = WordFromOutput(line_box, outputs, i, word_end, score_ratio,
                            MIN(prev_space_cert, space_cert), debug, false,
                            NULL, labels, label_coords, scale_factor);
    }
    prev_space_cert = space_cert;
    words->push_back(word);
  }
}
|
||||
|
||||
// Helper computes min and mean best results in the output.
|
||||
void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
|
||||
float* mean_output, float* sd) {
|
||||
const int kOutputScale = MAX_INT8;
|
||||
STATS stats(0, kOutputScale + 1);
|
||||
for (int t = 0; t < outputs.Width(); ++t) {
|
||||
int best_label = outputs.BestLabel(t, NULL);
|
||||
if (best_label != null_char_ || t == 0) {
|
||||
float best_output = outputs.f(t)[best_label];
|
||||
stats.add(static_cast<int>(kOutputScale * best_output), 1);
|
||||
}
|
||||
}
|
||||
*min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
|
||||
*mean_output = stats.mean() / kOutputScale;
|
||||
*sd = stats.sd() / kOutputScale;
|
||||
}
|
||||
|
||||
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
// If label_threshold is positive, uses it for making the labels, otherwise
// uses standard ctc.
// re_invert: if the inverted trial is NOT better, re-run the forward pass on
//     the restored polarity so internal state matches the returned outputs.
// Returns false if the line image could not be prepared or is too large to
// train on.
bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                   bool debug, bool re_invert,
                                   float label_threshold, float* scale_factor,
                                   NetworkIO* inputs, NetworkIO* outputs) {
  // Maximum width of image to train on.
  const int kMaxImageWidth = 2048;
  // This ensures consistent recognition results.
  SetRandomSeed();
  int min_width = network_->XScaleFactor();
  Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width,
                                      &randomizer_, scale_factor);
  if (pix == NULL) {
    tprintf("Line cannot be recognized!!\n");
    return false;
  }
  // Width limit applies only in training mode.
  if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) {
    tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
            pixGetHeight(pix));
    pixDestroy(&pix);
    return false;
  }
  // Reduction factor from image to coords.
  *scale_factor = min_width / *scale_factor;
  inputs->set_int_mode(IsIntMode());
  // Re-seed before each forward pass so any randomized input preparation is
  // reproducible across the normal and inverted runs.
  SetRandomSeed();
  Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs);
  network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
  // Check for auto inversion.
  float pos_min, pos_mean, pos_sd;
  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
  if (invert && pos_min < 0.5) {
    // Run again inverted and see if it is any better.
    float inv_scale;
    NetworkIO inv_inputs, inv_outputs;
    inv_inputs.set_int_mode(IsIntMode());
    SetRandomSeed();
    pixInvert(pix, pix);
    Input::PreparePixInput(network_->InputShape(), pix, &randomizer_,
                           &inv_inputs);
    network_->Forward(debug, inv_inputs, NULL, &scratch_space_, &inv_outputs);
    float inv_min, inv_mean, inv_sd;
    OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
    // Inverted wins only if it improves on all three statistics.
    if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) {
      // Inverted did better. Use inverted data.
      if (debug) {
        tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n",
                pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd);
      }
      *outputs = inv_outputs;
      *inputs = inv_inputs;
    } else if (re_invert) {
      // Inverting was not an improvement, so undo and run again, so the
      // outputs match the best forward result.
      SetRandomSeed();
      network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
    }
  }
  pixDestroy(&pix);
  if (debug) {
    GenericVector<int> labels, coords;
    LabelsFromOutputs(*outputs, label_threshold, &labels, &coords);
    DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
    DebugActivationPath(*outputs, labels, coords);
  }
  return true;
}
|
||||
|
||||
// Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
// line_box should be the bounding box of the line image in the main image,
// outputs the output of the network,
// [word_start, word_end) the interval over which to convert,
// score_ratio for choosing alternate classifier choices,
// use_alternates to control generation of alternative segmentations,
// labels, label_coords, scale_factor from RecognizeLine above.
// If target_unicharset is not NULL, attempts to translate the internal
// unichar_ids to the target_unicharset, but falls back to untranslated ids
// if the translation should fail.
// Returns NULL if translation to target_unicharset fails for any choice.
WERD_RES* LSTMRecognizer::WordFromOutput(
    const TBOX& line_box, const NetworkIO& outputs, int word_start,
    int word_end, float score_ratio, float space_certainty, bool debug,
    bool use_alternates, const UNICHARSET* target_unicharset,
    const GenericVector<int>& labels, const GenericVector<int>& label_coords,
    float scale_factor) {
  // Build the empty word with fake blobs and an empty ratings matrix.
  WERD_RES* word_res = InitializeWord(
      line_box, word_start, word_end, space_certainty, use_alternates,
      target_unicharset, labels, label_coords, scale_factor);
  int max_blob_run = word_res->ratings->bandwidth();
  // Fill the ratings matrix: for each candidate blob-run width, form every
  // run of `width` consecutive non-null labels bounded by null_char_.
  for (int width = 1; width <= max_blob_run; ++width) {
    int col = 0;  // Matrix column = index of the starting blob.
    for (int i = word_start; i + width <= word_end; ++i) {
      if (labels[i] != null_char_) {
        // Starting at i, use width labels, but stop at the next null_char_.
        // This forms all combinations of blobs between regions of null_char_.
        int j = i + 1;
        while (j - i < width && labels[j] != null_char_) ++j;
        if (j - i == width) {
          // Make the blob choices.
          int end_coord = label_coords[j];
          if (j < word_end && labels[j] == null_char_)
            end_coord = label_coords[j + 1];
          BLOB_CHOICE_LIST* choices = GetBlobChoices(
              col, col + width - 1, debug, outputs, target_unicharset,
              label_coords[i], end_coord, score_ratio);
          if (choices == NULL) {
            // Translation to target_unicharset failed; caller may retry
            // without one.
            delete word_res;
            return NULL;
          }
          word_res->ratings->put(col, col + width - 1, choices);
        }
        ++col;
      }
    }
  }
  if (use_alternates) {
    // Merge adjacent single results over null_char boundaries.
    int col = 0;
    for (int i = word_start; i + 2 < word_end; ++i) {
      // Pattern: X null Y with null (or word edge) on both sides.
      if (labels[i] != null_char_ && labels[i + 1] == null_char_ &&
          labels[i + 2] != null_char_ &&
          (i == word_start || labels[i - 1] == null_char_) &&
          (i + 3 == word_end || labels[i + 3] == null_char_)) {
        int end_coord = label_coords[i + 3];
        if (i + 3 < word_end && labels[i + 3] == null_char_)
          end_coord = label_coords[i + 4];
        BLOB_CHOICE_LIST* choices =
            GetBlobChoices(col, col + 1, debug, outputs, target_unicharset,
                           label_coords[i], end_coord, score_ratio);
        if (choices == NULL) {
          delete word_res;
          return NULL;
        }
        word_res->ratings->put(col, col + 1, choices);
      }
      if (labels[i] != null_char_) ++col;
    }
  } else {
    // No alternates: commit the top choice directly.
    word_res->FakeWordFromRatings(TOP_CHOICE_PERM);
  }
  return word_res;
}
|
||||
|
||||
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
// The fake blobs are scaled from network coords back to image coords via
// scale_factor and positioned relative to line_box.
WERD_RES* LSTMRecognizer::InitializeWord(const TBOX& line_box, int word_start,
                                         int word_end, float space_certainty,
                                         bool use_alternates,
                                         const UNICHARSET* target_unicharset,
                                         const GenericVector<int>& labels,
                                         const GenericVector<int>& label_coords,
                                         float scale_factor) {
  // Make a fake blob for each non-zero label.
  C_BLOB_LIST blobs;
  C_BLOB_IT b_it(&blobs);
  // num_blobs is the length of the diagonal of the ratings matrix.
  int num_blobs = 0;
  // max_blob_run is the diagonal width of the ratings matrix
  int max_blob_run = 0;
  int blob_run = 0;  // Current run of consecutive non-null labels.
  for (int i = word_start; i < word_end; ++i) {
    // With a compressed unicharset, only labels that can start a code
    // sequence count as blob starts.
    if (IsRecoding() && !recoder_.IsValidFirstCode(labels[i])) continue;
    if (labels[i] != null_char_) {
      // Make a fake blob.
      TBOX box(label_coords[i], 0, label_coords[i + 1], line_box.height());
      box.scale(scale_factor);
      box.move(ICOORD(line_box.left(), line_box.bottom()));
      box.set_top(line_box.top());
      b_it.add_after_then_move(C_BLOB::FakeBlob(box));
      ++num_blobs;
      ++blob_run;
    }
    // A null label or the end of the word terminates the current run.
    // NOTE(review): blob_run is never reset to 0 here -- confirm whether
    // runs are intended to accumulate across null boundaries.
    if (labels[i] == null_char_ || i + 1 == word_end) {
      if (blob_run > max_blob_run)
        max_blob_run = blob_run;
    }
  }
  if (!use_alternates) max_blob_run = 1;
  ASSERT_HOST(label_coords.size() >= word_end);
  // Make a fake word from the blobs.
  // Second WERD arg is presumably the count of blanks before the word --
  // TODO confirm against the WERD constructor.
  WERD* word = new WERD(&blobs, word_start > 1 ? 1 : 0, NULL);
  // Make a WERD_RES from the word.
  WERD_RES* word_res = new WERD_RES(word);
  word_res->uch_set =
      target_unicharset != NULL ? target_unicharset : &GetUnicharset();
  word_res->combination = true;  // Give it ownership of the word.
  word_res->space_certainty = space_certainty;
  word_res->ratings = new MATRIX(num_blobs, max_blob_run);
  return word_res;
}
|
||||
|
||||
// Converts an array of labels to utf-8, whether or not the labels are
|
||||
// augmented with character boundaries.
|
||||
STRING LSTMRecognizer::DecodeLabels(const GenericVector<int>& labels) {
|
||||
STRING result;
|
||||
int end = 1;
|
||||
for (int start = 0; start < labels.size(); start = end) {
|
||||
if (labels[start] == null_char_) {
|
||||
end = start + 1;
|
||||
} else {
|
||||
result += DecodeLabel(labels, start, &end, NULL);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Displays the forward results in a window with the characters and
|
||||
// boundaries as determined by the labels and label_coords.
|
||||
void LSTMRecognizer::DisplayForward(const NetworkIO& inputs,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int>& label_coords,
|
||||
const char* window_name,
|
||||
ScrollView** window) {
|
||||
#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
|
||||
int x_scale = network_->XScaleFactor();
|
||||
Pix* input_pix = inputs.ToPix();
|
||||
Network::ClearWindow(false, window_name, pixGetWidth(input_pix),
|
||||
pixGetHeight(input_pix), window);
|
||||
int line_height = Network::DisplayImage(input_pix, *window);
|
||||
DisplayLSTMOutput(labels, label_coords, line_height, *window);
|
||||
#endif // GRAPHICS_DISABLED
|
||||
}
|
||||
|
||||
// Displays the labels and cuts at the corresponding xcoords.
// Size of labels should match xcoords.
// Draws a red cut line at each null position and a green line plus the
// decoded text at each character.
void LSTMRecognizer::DisplayLSTMOutput(const GenericVector<int>& labels,
                                       const GenericVector<int>& xcoords,
                                       int height, ScrollView* window) {
#ifndef GRAPHICS_DISABLED  // do nothing if there's no graphics
  // Network x-coords are scaled up to image coords for display.
  int x_scale = network_->XScaleFactor();
  window->TextAttributes("Arial", height / 4, false, false, false);
  int end = 1;
  for (int start = 0; start < labels.size(); start = end) {
    int xpos = xcoords[start] * x_scale;
    if (labels[start] == null_char_) {
      end = start + 1;
      window->Pen(ScrollView::RED);
    } else {
      window->Pen(ScrollView::GREEN);
      const char* str = DecodeLabel(labels, start, &end, NULL);
      // Escape a lone backslash so ScrollView renders it literally.
      if (*str == '\\') str = "\\\\";
      // Center the text over the span of labels it decodes from.
      xpos = xcoords[(start + end) / 2] * x_scale;
      window->Text(xpos, height, str);
    }
    window->Line(xpos, 0, xpos, height * 3 / 2);
  }
  window->Update();
#endif  // GRAPHICS_DISABLED
}
|
||||
|
||||
// Prints debug output detailing the activation path that is implied by the
|
||||
// label_coords.
|
||||
void LSTMRecognizer::DebugActivationPath(const NetworkIO& outputs,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int>& xcoords) {
|
||||
if (xcoords[0] > 0)
|
||||
DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
|
||||
int end = 1;
|
||||
for (int start = 0; start < labels.size(); start = end) {
|
||||
if (labels[start] == null_char_) {
|
||||
end = start + 1;
|
||||
DebugActivationRange(outputs, "<null>", null_char_, xcoords[start],
|
||||
xcoords[end]);
|
||||
continue;
|
||||
} else {
|
||||
int decoded;
|
||||
const char* label = DecodeLabel(labels, start, &end, &decoded);
|
||||
DebugActivationRange(outputs, label, labels[start], xcoords[start],
|
||||
xcoords[start + 1]);
|
||||
for (int i = start + 1; i < end; ++i) {
|
||||
DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i],
|
||||
xcoords[i], xcoords[i + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prints debug output detailing activations and 2nd choice over a range
|
||||
// of positions.
|
||||
void LSTMRecognizer::DebugActivationRange(const NetworkIO& outputs,
|
||||
const char* label, int best_choice,
|
||||
int x_start, int x_end) {
|
||||
tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
|
||||
double max_score = 0.0;
|
||||
double mean_score = 0.0;
|
||||
int width = x_end - x_start;
|
||||
for (int x = x_start; x < x_end; ++x) {
|
||||
const float* line = outputs.f(x);
|
||||
double score = line[best_choice] * 100.0;
|
||||
if (score > max_score) max_score = score;
|
||||
mean_score += score / width;
|
||||
int best_c = 0;
|
||||
double best_score = 0.0;
|
||||
for (int c = 0; c < outputs.NumFeatures(); ++c) {
|
||||
if (c != best_choice && line[c] > best_score) {
|
||||
best_c = c;
|
||||
best_score = line[c];
|
||||
}
|
||||
}
|
||||
tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c,
|
||||
best_score * 100.0);
|
||||
}
|
||||
tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
|
||||
}
|
||||
|
||||
// Helper returns true if the null_char is the winner at t, and it beats the
|
||||
// null_threshold, or the next choice is space, in which case we will use the
|
||||
// null anyway.
|
||||
static bool NullIsBest(const NetworkIO& output, float null_thr,
|
||||
int null_char, int t) {
|
||||
if (output.f(t)[null_char] >= null_thr) return true;
|
||||
if (output.BestLabel(t, null_char, null_char, NULL) != UNICHAR_SPACE)
|
||||
return false;
|
||||
return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
|
||||
}
|
||||
|
||||
// Converts the network output to a sequence of labels. Outputs labels, scores
// and start xcoords of each char, and each null_char_, with an additional
// final xcoord for the end of the output.
// The conversion method is determined by internal state.
// Dispatch priority: simple-text > recoded > CTC (non-positive threshold)
// > threshold-based.
void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
                                       GenericVector<int>* labels,
                                       GenericVector<int>* xcoords) {
  if (SimpleTextOutput()) {
    LabelsViaSimpleText(outputs, labels, xcoords);
  } else if (IsRecoding()) {
    LabelsViaReEncode(outputs, labels, xcoords);
  } else if (null_thr <= 0.0) {
    LabelsViaCTC(outputs, labels, xcoords);
  } else {
    LabelsViaThreshold(outputs, null_thr, labels, xcoords);
  }
}
|
||||
|
||||
// Converts the network output to a sequence of labels, using a threshold
|
||||
// on the null_char_ to determine character boundaries. Outputs labels, scores
|
||||
// and start xcoords of each char, and each null_char_, with an additional
|
||||
// final xcoord for the end of the output.
|
||||
// The label output is the one with the highest score in the interval between
|
||||
// null_chars_.
|
||||
void LSTMRecognizer::LabelsViaThreshold(const NetworkIO& output,
|
||||
float null_thr,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords) {
|
||||
labels->truncate(0);
|
||||
xcoords->truncate(0);
|
||||
int width = output.Width();
|
||||
int t = 0;
|
||||
// Skip any initial non-char.
|
||||
int label = null_char_;
|
||||
while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
|
||||
++t;
|
||||
}
|
||||
while (t < width) {
|
||||
ASSERT_HOST(!isnan(output.f(t)[null_char_]));
|
||||
int label = output.BestLabel(t, null_char_, null_char_, NULL);
|
||||
int char_start = t++;
|
||||
while (t < width && !NullIsBest(output, null_thr, null_char_, t) &&
|
||||
label == output.BestLabel(t, null_char_, null_char_, NULL)) {
|
||||
++t;
|
||||
}
|
||||
int char_end = t;
|
||||
labels->push_back(label);
|
||||
xcoords->push_back(char_start);
|
||||
// Find the end of the non-char, and compute its score.
|
||||
while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
|
||||
++t;
|
||||
}
|
||||
if (t > char_end) {
|
||||
labels->push_back(null_char_);
|
||||
xcoords->push_back(char_end);
|
||||
}
|
||||
}
|
||||
xcoords->push_back(width);
|
||||
}
|
||||
|
||||
// Converts the network output to a sequence of labels, with scores and
|
||||
// start x-coords of the character labels. Retains the null_char_ as the
|
||||
// end x-coord, where already present, otherwise the start of the next
|
||||
// character is the end.
|
||||
// The number of labels, scores, and xcoords is always matched, except that
|
||||
// there is always an additional xcoord for the last end position.
|
||||
void LSTMRecognizer::LabelsViaCTC(const NetworkIO& output,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords) {
|
||||
labels->truncate(0);
|
||||
xcoords->truncate(0);
|
||||
int width = output.Width();
|
||||
int t = 0;
|
||||
while (t < width) {
|
||||
float score = 0.0f;
|
||||
int label = output.BestLabel(t, &score);
|
||||
labels->push_back(label);
|
||||
xcoords->push_back(t);
|
||||
while (++t < width && output.BestLabel(t, NULL) == label) {
|
||||
}
|
||||
}
|
||||
xcoords->push_back(width);
|
||||
}
|
||||
|
||||
// As LabelsViaCTC except that this function constructs the best path that
// contains only legal sequences of subcodes for CJK.
// Lazily constructs the shared beam search on first use.
void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output,
                                       GenericVector<int>* labels,
                                       GenericVector<int>* xcoords) {
  if (search_ == NULL) {
    search_ =
        new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
  }
  // Neutral dict ratio/cert offset: labels only, no dictionary bias.
  search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, NULL);
  search_->ExtractBestPathAsLabels(labels, xcoords);
}
|
||||
|
||||
// Converts the network output to a sequence of labels, with scores, using
|
||||
// the simple character model (each position is a char, and the null_char_ is
|
||||
// mainly intended for tail padding.)
|
||||
void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords) {
|
||||
labels->truncate(0);
|
||||
xcoords->truncate(0);
|
||||
int width = output.Width();
|
||||
for (int t = 0; t < width; ++t) {
|
||||
float score = 0.0f;
|
||||
int label = output.BestLabel(t, &score);
|
||||
if (label != null_char_) {
|
||||
labels->push_back(label);
|
||||
xcoords->push_back(t);
|
||||
}
|
||||
}
|
||||
xcoords->push_back(width);
|
||||
}
|
||||
|
||||
// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
|
||||
// Handles either LSTM labels or direct unichar-ids.
|
||||
// Score ratio determines the worst ratio between top choice and remainder.
|
||||
// If target_unicharset is not NULL, attempts to translate to the target
|
||||
// unicharset, returning NULL on failure.
|
||||
BLOB_CHOICE_LIST* LSTMRecognizer::GetBlobChoices(
|
||||
int col, int row, bool debug, const NetworkIO& output,
|
||||
const UNICHARSET* target_unicharset, int x_start, int x_end,
|
||||
float score_ratio) {
|
||||
int width = x_end - x_start;
|
||||
float rating = 0.0f, certainty = 0.0f;
|
||||
int label = output.BestChoiceOverRange(x_start, x_end, UNICHAR_SPACE,
|
||||
null_char_, &rating, &certainty);
|
||||
int unichar_id = label == null_char_ ? UNICHAR_SPACE : label;
|
||||
if (debug) {
|
||||
tprintf("Best choice over range %d,%d=unichar%d=%s r = %g, cert=%g\n",
|
||||
x_start, x_end, unichar_id, DecodeSingleLabel(label), rating,
|
||||
certainty);
|
||||
}
|
||||
BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
|
||||
BLOB_CHOICE_IT bc_it(choices);
|
||||
if (!AddBlobChoices(unichar_id, rating, certainty, col, row,
|
||||
target_unicharset, &bc_it)) {
|
||||
delete choices;
|
||||
return NULL;
|
||||
}
|
||||
// Get the other choices.
|
||||
double best_cert = certainty;
|
||||
for (int c = 0; c < output.NumFeatures(); ++c) {
|
||||
if (c == label || c == UNICHAR_SPACE || c == null_char_) continue;
|
||||
// Compute the score over the range.
|
||||
output.ScoresOverRange(x_start, x_end, c, null_char_, &rating, &certainty);
|
||||
int unichar_id = c == null_char_ ? UNICHAR_SPACE : c;
|
||||
if (certainty >= best_cert - score_ratio &&
|
||||
!AddBlobChoices(unichar_id, rating, certainty, col, row,
|
||||
target_unicharset, &bc_it)) {
|
||||
delete choices;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
choices->sort(&BLOB_CHOICE::SortByRating);
|
||||
if (bc_it.length() > kMaxChoices) {
|
||||
bc_it.move_to_first();
|
||||
for (int i = 0; i < kMaxChoices; ++i)
|
||||
bc_it.forward();
|
||||
while (!bc_it.at_first()) {
|
||||
delete bc_it.extract();
|
||||
bc_it.forward();
|
||||
}
|
||||
}
|
||||
return choices;
|
||||
}
|
||||
|
||||
// Adds to the given iterator, the blob choices for the target_unicharset
|
||||
// that correspond to the given LSTM unichar_id.
|
||||
// Returns false if unicharset translation failed.
|
||||
bool LSTMRecognizer::AddBlobChoices(int unichar_id, float rating,
|
||||
float certainty, int col, int row,
|
||||
const UNICHARSET* target_unicharset,
|
||||
BLOB_CHOICE_IT* bc_it) {
|
||||
int target_id = unichar_id;
|
||||
if (target_unicharset != NULL) {
|
||||
const char* utf8 = GetUnicharset().id_to_unichar(unichar_id);
|
||||
if (target_unicharset->contains_unichar(utf8)) {
|
||||
target_id = target_unicharset->unichar_to_id(utf8);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
BLOB_CHOICE* choice = new BLOB_CHOICE(target_id, rating, certainty, -1, 1.0f,
|
||||
static_cast<float>(MAX_INT16), 0.0f,
|
||||
BCC_STATIC_CLASSIFIER);
|
||||
choice->set_matrix_cell(col, row);
|
||||
bc_it->add_after_then_move(choice);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
// In recoding mode a single character may span several labels (a code
// sequence); in that case *end advances past all of them.
const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels,
                                        int start, int* end, int* decoded) {
  *end = start + 1;
  if (IsRecoding()) {
    // Decode labels via recoder_.
    RecodedCharID code;
    if (labels[start] == null_char_) {
      if (decoded != NULL) {
        code.Set(0, null_char_);
        *decoded = recoder_.DecodeUnichar(code);
      }
      return "<null>";
    }
    // Extend the code one label at a time until it decodes to a unichar.
    int index = start;
    while (index < labels.size() &&
           code.length() < RecodedCharID::kMaxCodeLen) {
      code.Set(code.length(), labels[index++]);
      // Interior nulls within a code sequence are skipped.
      while (index < labels.size() && labels[index] == null_char_) ++index;
      int uni_id = recoder_.DecodeUnichar(code);
      // If the next label isn't a valid first code, then we need to continue
      // extending even if we have a valid uni_id from this prefix.
      if (uni_id != INVALID_UNICHAR_ID &&
          (index == labels.size() ||
           code.length() == RecodedCharID::kMaxCodeLen ||
           recoder_.IsValidFirstCode(labels[index]))) {
        *end = index;
        if (decoded != NULL) *decoded = uni_id;
        if (uni_id == UNICHAR_SPACE) return " ";
        return GetUnicharset().get_normed_unichar(uni_id);
      }
    }
    // Ran out of labels or code length without a valid decode.
    return "<Undecodable>";
  } else {
    // Plain mode: each label is directly a unichar-id.
    if (decoded != NULL) *decoded = labels[start];
    if (labels[start] == null_char_) return "<null>";
    if (labels[start] == UNICHAR_SPACE) return " ";
    return GetUnicharset().get_normed_unichar(labels[start]);
  }
}
|
||||
|
||||
// Returns a string corresponding to a given single label id, falling back to
|
||||
// a default of ".." for part of a multi-label unichar-id.
|
||||
const char* LSTMRecognizer::DecodeSingleLabel(int label) {
|
||||
if (label == null_char_) return "<null>";
|
||||
if (IsRecoding()) {
|
||||
// Decode label via recoder_.
|
||||
RecodedCharID code;
|
||||
code.Set(0, label);
|
||||
label = recoder_.DecodeUnichar(code);
|
||||
if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code.
|
||||
}
|
||||
if (label == UNICHAR_SPACE) return " ";
|
||||
return GetUnicharset().get_normed_unichar(label);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
392
lstm/lstmrecognizer.h
Normal file
392
lstm/lstmrecognizer.h
Normal file
@ -0,0 +1,392 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: lstmrecognizer.h
|
||||
// Description: Top-level line recognizer class for LSTM-based networks.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:57:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
|
||||
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_
|
||||
|
||||
#include "ccutil.h"
|
||||
#include "helpers.h"
|
||||
#include "imagedata.h"
|
||||
#include "matrix.h"
|
||||
#include "network.h"
|
||||
#include "networkscratch.h"
|
||||
#include "recodebeam.h"
|
||||
#include "series.h"
|
||||
#include "strngs.h"
|
||||
#include "unicharcompress.h"
|
||||
|
||||
class BLOB_CHOICE_IT;
|
||||
struct Pix;
|
||||
class ROW_RES;
|
||||
class ScrollView;
|
||||
class TBOX;
|
||||
class WERD_RES;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
class Dict;
|
||||
class ImageData;
|
||||
|
||||
// Enum indicating training mode control flags.
|
||||
enum TrainingFlags {
|
||||
TF_INT_MODE = 1,
|
||||
TF_AUTO_HARDEN = 2,
|
||||
TF_ROUND_ROBIN_TRAINING = 16,
|
||||
TF_COMPRESS_UNICHARSET = 64,
|
||||
};
|
||||
|
||||
// Top-level line recognizer class for LSTM-based networks.
|
||||
// Note that a sub-class, LSTMTrainer is used for training.
|
||||
class LSTMRecognizer {
|
||||
public:
|
||||
LSTMRecognizer();
|
||||
~LSTMRecognizer();
|
||||
|
||||
int NumOutputs() const {
|
||||
return network_->NumOutputs();
|
||||
}
|
||||
int training_iteration() const {
|
||||
return training_iteration_;
|
||||
}
|
||||
int sample_iteration() const {
|
||||
return sample_iteration_;
|
||||
}
|
||||
double learning_rate() const {
|
||||
return learning_rate_;
|
||||
}
|
||||
bool IsHardening() const {
|
||||
return (training_flags_ & TF_AUTO_HARDEN) != 0;
|
||||
}
|
||||
LossType OutputLossType() const {
|
||||
if (network_ == nullptr) return LT_NONE;
|
||||
StaticShape shape;
|
||||
shape = network_->OutputShape(shape);
|
||||
return shape.loss_type();
|
||||
}
|
||||
bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
|
||||
bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
|
||||
// True if recoder_ is active to re-encode text to a smaller space.
|
||||
bool IsRecoding() const {
|
||||
return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
|
||||
}
|
||||
// Returns the cache strategy for the DocumentCache.
|
||||
CachingStrategy CacheStrategy() const {
|
||||
return training_flags_ & TF_ROUND_ROBIN_TRAINING ? CS_ROUND_ROBIN
|
||||
: CS_SEQUENTIAL;
|
||||
}
|
||||
// Returns true if the network is a TensorFlow network.
|
||||
bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
|
||||
// Returns a vector of layer ids that can be passed to other layer functions
|
||||
// to access a specific layer.
|
||||
GenericVector<STRING> EnumerateLayers() const {
|
||||
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
|
||||
Series* series = reinterpret_cast<Series*>(network_);
|
||||
GenericVector<STRING> layers;
|
||||
series->EnumerateLayers(NULL, &layers);
|
||||
return layers;
|
||||
}
|
||||
// Returns a specific layer from its id (from EnumerateLayers).
|
||||
Network* GetLayer(const STRING& id) const {
|
||||
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
|
||||
ASSERT_HOST(id.length() > 1 && id[0] == ':');
|
||||
Series* series = reinterpret_cast<Series*>(network_);
|
||||
return series->GetLayer(&id[1]);
|
||||
}
|
||||
// Returns the learning rate of the layer from its id.
|
||||
float GetLayerLearningRate(const STRING& id) const {
|
||||
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
|
||||
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
|
||||
ASSERT_HOST(id.length() > 1 && id[0] == ':');
|
||||
Series* series = reinterpret_cast<Series*>(network_);
|
||||
return series->LayerLearningRate(&id[1]);
|
||||
} else {
|
||||
return learning_rate_;
|
||||
}
|
||||
}
|
||||
// Multiplies the all the learning rate(s) by the given factor.
|
||||
void ScaleLearningRate(double factor) {
|
||||
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
|
||||
learning_rate_ *= factor;
|
||||
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
|
||||
GenericVector<STRING> layers = EnumerateLayers();
|
||||
for (int i = 0; i < layers.size(); ++i) {
|
||||
ScaleLayerLearningRate(layers[i], factor);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Multiplies the learning rate of the layer with id, by the given factor.
|
||||
void ScaleLayerLearningRate(const STRING& id, double factor) {
|
||||
ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
|
||||
ASSERT_HOST(id.length() > 1 && id[0] == ':');
|
||||
Series* series = reinterpret_cast<Series*>(network_);
|
||||
series->ScaleLayerLearningRate(&id[1], factor);
|
||||
}
|
||||
|
||||
// True if the network is using adagrad to train.
|
||||
bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); }
|
||||
// Provides access to the UNICHARSET that this classifier works with.
|
||||
const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
|
||||
// Sets the sample iteration to the given value. The sample_iteration_
|
||||
// determines the seed for the random number generator. The training
|
||||
// iteration is incremented only by a successful training iteration.
|
||||
void SetIteration(int iteration) {
|
||||
sample_iteration_ = iteration;
|
||||
}
|
||||
// Accessors for textline image normalization.
|
||||
int NumInputs() const {
|
||||
return network_->NumInputs();
|
||||
}
|
||||
int null_char() const { return null_char_; }
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool swap, TFile* fp);
|
||||
// Loads the dictionary if possible from the traineddata file.
|
||||
// Prints a warning message, and returns false but otherwise fails silently
|
||||
// and continues to work without it if loading fails.
|
||||
// Note that dictionary load is independent from DeSerialize, but dependent
|
||||
// on the unicharset matching. This enables training to deserialize a model
|
||||
// from checkpoint or restore without having to go back and reload the
|
||||
// dictionary.
|
||||
bool LoadDictionary(const char* data_file_name, const char* lang);
|
||||
|
||||
// Recognizes the line image, contained within image_data, returning the
|
||||
// ratings matrix and matching box_word for each WERD_RES in the output.
|
||||
// If invert, tries inverted as well if the normal interpretation doesn't
|
||||
// produce a good enough result. If use_alternates, the ratings matrix is
|
||||
// filled with segmentation and classifier alternatives that may be searched
|
||||
// using the standard beam search, otherwise, just a diagonal and prebuilt
|
||||
// best_choice. The line_box is used for computing the box_word in the
|
||||
// output words. Score_ratio is used to determine the classifier alternates.
|
||||
// If one_word, then a single WERD_RES is formed, regardless of the spaces
|
||||
// found during recognition.
|
||||
// If not NULL, we attempt to translate the output to target_unicharset, but
|
||||
// do not guarantee success, due to mismatches. In that case the output words
|
||||
// are marked with our UNICHARSET, not the caller's.
|
||||
void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
|
||||
double worst_dict_cert, bool use_alternates,
|
||||
const UNICHARSET* target_unicharset, const TBOX& line_box,
|
||||
float score_ratio, bool one_word,
|
||||
PointerVector<WERD_RES>* words);
|
||||
// Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
|
||||
// corresponding to the network output in outputs, labels, label_coords.
|
||||
// one_word generates a single word output, that may include spaces inside.
|
||||
// use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
|
||||
// with cut-offs determined by scale_factor.
|
||||
// If not NULL, we attempt to translate the output to target_unicharset, but
|
||||
// do not guarantee success, due to mismatches. In that case the output words
|
||||
// are marked with our UNICHARSET, not the caller's.
|
||||
void WordsFromOutputs(const NetworkIO& outputs,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int> label_coords,
|
||||
const TBOX& line_box, bool debug, bool use_alternates,
|
||||
bool one_word, float score_ratio, float scale_factor,
|
||||
const UNICHARSET* target_unicharset,
|
||||
PointerVector<WERD_RES>* words);
|
||||
|
||||
// Helper computes min and mean best results in the output.
|
||||
void OutputStats(const NetworkIO& outputs,
|
||||
float* min_output, float* mean_output, float* sd);
|
||||
// Recognizes the image_data, returning the labels,
|
||||
// scores, and corresponding pairs of start, end x-coords in coords.
|
||||
// If label_threshold is positive, uses it for making the labels, otherwise
|
||||
// uses standard ctc. Returned in scale_factor is the reduction factor
|
||||
// between the image and the output coords, for computing bounding boxes.
|
||||
// If re_invert is true, the input is inverted back to its orginal
|
||||
// photometric interpretation if inversion is attempted but fails to
|
||||
// improve the results. This ensures that outputs contains the correct
|
||||
// forward outputs for the best photometric interpretation.
|
||||
// inputs is filled with the used inputs to the network, and if not null,
|
||||
// target boxes is filled with scaled truth boxes if present in image_data.
|
||||
bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
|
||||
bool re_invert, float label_threshold, float* scale_factor,
|
||||
NetworkIO* inputs, NetworkIO* outputs);
|
||||
// Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
|
||||
// line_box should be the bounding box of the line image in the main image,
|
||||
// outputs the output of the network,
|
||||
// [word_start, word_end) the interval over which to convert,
|
||||
// score_ratio for choosing alternate classifier choices,
|
||||
// use_alternates to control generation of alternative segmentations,
|
||||
// labels, label_coords, scale_factor from RecognizeLine above.
|
||||
// If target_unicharset is not NULL, attempts to translate the internal
|
||||
// unichar_ids to the target_unicharset, but falls back to untranslated ids
|
||||
// if the translation should fail.
|
||||
WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
|
||||
int word_start, int word_end, float score_ratio,
|
||||
float space_certainty, bool debug,
|
||||
bool use_alternates,
|
||||
const UNICHARSET* target_unicharset,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int>& label_coords,
|
||||
float scale_factor);
|
||||
// Sets up a word with the ratings matrix and fake blobs with boxes in the
|
||||
// right places.
|
||||
WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
|
||||
float space_certainty, bool use_alternates,
|
||||
const UNICHARSET* target_unicharset,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int>& label_coords,
|
||||
float scale_factor);
|
||||
|
||||
// Converts an array of labels to utf-8, whether or not the labels are
|
||||
// augmented with character boundaries.
|
||||
STRING DecodeLabels(const GenericVector<int>& labels);
|
||||
|
||||
// Displays the forward results in a window with the characters and
|
||||
// boundaries as determined by the labels and label_coords.
|
||||
void DisplayForward(const NetworkIO& inputs,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int>& label_coords,
|
||||
const char* window_name,
|
||||
ScrollView** window);
|
||||
|
||||
protected:
|
||||
// Sets the random seed from the sample_iteration_;
|
||||
void SetRandomSeed() {
|
||||
inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
|
||||
randomizer_.set_seed(seed);
|
||||
randomizer_.IntRand();
|
||||
}
|
||||
|
||||
// Displays the labels and cuts at the corresponding xcoords.
|
||||
// Size of labels should match xcoords.
|
||||
void DisplayLSTMOutput(const GenericVector<int>& labels,
|
||||
const GenericVector<int>& xcoords,
|
||||
int height, ScrollView* window);
|
||||
|
||||
// Prints debug output detailing the activation path that is implied by the
|
||||
// xcoords.
|
||||
void DebugActivationPath(const NetworkIO& outputs,
|
||||
const GenericVector<int>& labels,
|
||||
const GenericVector<int>& xcoords);
|
||||
|
||||
// Prints debug output detailing activations and 2nd choice over a range
|
||||
// of positions.
|
||||
void DebugActivationRange(const NetworkIO& outputs, const char* label,
|
||||
int best_choice, int x_start, int x_end);
|
||||
|
||||
// Converts the network output to a sequence of labels. Outputs labels, scores
|
||||
// and start xcoords of each char, and each null_char_, with an additional
|
||||
// final xcoord for the end of the output.
|
||||
// The conversion method is determined by internal state.
|
||||
void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords);
|
||||
// Converts the network output to a sequence of labels, using a threshold
|
||||
// on the null_char_ to determine character boundaries. Outputs labels, scores
|
||||
// and start xcoords of each char, and each null_char_, with an additional
|
||||
// final xcoord for the end of the output.
|
||||
// The label output is the one with the highest score in the interval between
|
||||
// null_chars_.
|
||||
void LabelsViaThreshold(const NetworkIO& output,
|
||||
float null_threshold,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords);
|
||||
// Converts the network output to a sequence of labels, with scores and
|
||||
// start x-coords of the character labels. Retains the null_char_ character as
|
||||
// the end x-coord, where already present, otherwise the start of the next
|
||||
// character is the end.
|
||||
// The number of labels, scores, and xcoords is always matched, except that
|
||||
// there is always an additional xcoord for the last end position.
|
||||
void LabelsViaCTC(const NetworkIO& output,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords);
|
||||
// As LabelsViaCTC except that this function constructs the best path that
|
||||
// contains only legal sequences of subcodes for recoder_.
|
||||
void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords);
|
||||
// Converts the network output to a sequence of labels, with scores, using
|
||||
// the simple character model (each position is a char, and the null_char_ is
|
||||
// mainly intended for tail padding.)
|
||||
void LabelsViaSimpleText(const NetworkIO& output,
|
||||
GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords);
|
||||
|
||||
// Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
|
||||
// Handles either LSTM labels or direct unichar-ids.
|
||||
// Score ratio determines the worst ratio between top choice and remainder.
|
||||
// If target_unicharset is not NULL, attempts to translate to the target
|
||||
// unicharset, returning NULL on failure.
|
||||
BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
|
||||
const NetworkIO& output,
|
||||
const UNICHARSET* target_unicharset,
|
||||
int x_start, int x_end, float score_ratio);
|
||||
|
||||
// Adds to the given iterator, the blob choices for the target_unicharset
|
||||
// that correspond to the given LSTM unichar_id.
|
||||
// Returns false if unicharset translation failed.
|
||||
bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
|
||||
int row, const UNICHARSET* target_unicharset,
|
||||
BLOB_CHOICE_IT* bc_it);
|
||||
|
||||
// Returns a string corresponding to the label starting at start. Sets *end
|
||||
// to the next start and if non-null, *decoded to the unichar id.
|
||||
const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
|
||||
int* decoded);
|
||||
|
||||
// Returns a string corresponding to a given single label id, falling back to
|
||||
// a default of ".." for part of a multi-label unichar-id.
|
||||
const char* DecodeSingleLabel(int label);
|
||||
|
||||
protected:
|
||||
// The network hierarchy.
|
||||
Network* network_;
|
||||
// The unicharset. Only the unicharset element is serialized.
|
||||
// Has to be a CCUtil, so Dict can point to it.
|
||||
CCUtil ccutil_;
|
||||
// For backward compatability, recoder_ is serialized iff
|
||||
// training_flags_ & TF_COMPRESS_UNICHARSET.
|
||||
// Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
|
||||
UnicharCompress recoder_;
|
||||
|
||||
// ==Training parameters that are serialized to provide a record of them.==
|
||||
STRING network_str_;
|
||||
// Flags used to determine the training method of the network.
|
||||
// See enum TrainingFlags above.
|
||||
inT32 training_flags_;
|
||||
// Number of actual backward training steps used.
|
||||
inT32 training_iteration_;
|
||||
// Index into training sample set. sample_iteration >= training_iteration_.
|
||||
inT32 sample_iteration_;
|
||||
// Index in softmax of null character. May take the value UNICHAR_BROKEN or
|
||||
// ccutil_.unicharset.size().
|
||||
inT32 null_char_;
|
||||
// Range used for the initial random numbers in the weights.
|
||||
float weight_range_;
|
||||
// Learning rate and momentum multipliers of deltas in backprop.
|
||||
float learning_rate_;
|
||||
float momentum_;
|
||||
|
||||
// === NOT SERIALIZED.
|
||||
TRand randomizer_;
|
||||
NetworkScratch scratch_space_;
|
||||
// Language model (optional) to use with the beam search.
|
||||
Dict* dict_;
|
||||
// Beam search held between uses to optimize memory allocation/use.
|
||||
RecodeBeamSearch* search_;
|
||||
|
||||
// == Debugging parameters.==
|
||||
// Recognition debug display window.
|
||||
ScrollView* debug_win_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
|
1331
lstm/lstmtrainer.cpp
Normal file
1331
lstm/lstmtrainer.cpp
Normal file
File diff suppressed because it is too large
Load Diff
477
lstm/lstmtrainer.h
Normal file
477
lstm/lstmtrainer.h
Normal file
@ -0,0 +1,477 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: lstmtrainer.h
|
||||
// Description: Top-level line trainer class for LSTM-based networks.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri May 03 09:07:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
|
||||
#define TESSERACT_LSTM_LSTMTRAINER_H_
|
||||
|
||||
#include "imagedata.h"
|
||||
#include "lstmrecognizer.h"
|
||||
#include "rect.h"
|
||||
#include "tesscallback.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
class LSTM;
|
||||
class LSTMTrainer;
|
||||
class Parallel;
|
||||
class Reversed;
|
||||
class Softmax;
|
||||
class Series;
|
||||
|
||||
// Enum for the types of errors that are counted.
|
||||
enum ErrorTypes {
|
||||
ET_RMS, // RMS activation error.
|
||||
ET_DELTA, // Number of big errors in deltas.
|
||||
ET_WORD_RECERR, // Output text string word recall error.
|
||||
ET_CHAR_ERROR, // Output text string total char error.
|
||||
ET_SKIP_RATIO, // Fraction of samples skipped.
|
||||
ET_COUNT // For array sizing.
|
||||
};
|
||||
|
||||
// Enum for the trainability_ flags.
|
||||
enum Trainability {
|
||||
TRAINABLE, // Non-zero delta error.
|
||||
PERFECT, // Zero delta error.
|
||||
UNENCODABLE, // Not trainable due to coding/alignment trouble.
|
||||
HI_PRECISION_ERR, // Hi confidence disagreement.
|
||||
NOT_BOXED, // Early in training and has no character boxes.
|
||||
};
|
||||
|
||||
// Enum to define the amount of data to get serialized.
|
||||
enum SerializeAmount {
|
||||
LIGHT, // Minimal data for remote training.
|
||||
NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
|
||||
FULL, // All data including best_trainer_.
|
||||
};
|
||||
|
||||
// Enum to indicate how the sub_trainer_ training went.
|
||||
enum SubTrainerResult {
|
||||
STR_NONE, // Did nothing as not good enough.
|
||||
STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
|
||||
STR_REPLACED // Subtrainer replaced *this.
|
||||
};
|
||||
|
||||
class LSTMTrainer;
|
||||
// Function to restore the trainer state from a given checkpoint.
|
||||
// Returns false on failure.
|
||||
typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>*
|
||||
CheckPointReader;
|
||||
// Function to save a checkpoint of the current trainer state.
|
||||
// Returns false on failure. SerializeAmount determines the amount of the
|
||||
// trainer to serialize, typically used for saving the best state.
|
||||
typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
|
||||
GenericVector<char>*>* CheckPointWriter;
|
||||
// Function to compute and record error rates on some external test set(s).
|
||||
// Args are: iteration, mean errors, model, training stage.
|
||||
// Returns a STRING containing logging information about the tests.
|
||||
typedef TessResultCallback4<STRING, int, const double*,
|
||||
const GenericVector<char>&, int>* TestCallback;
|
||||
|
||||
// Trainer class for LSTM networks. Most of the effort is in creating the
|
||||
// ideal target outputs from the transcription. A box file is used if it is
|
||||
// available, otherwise estimates of the char widths from the unicharset are
|
||||
// used to guide a DP search for the best fit to the transcription.
|
||||
class LSTMTrainer : public LSTMRecognizer {
|
||||
public:
|
||||
LSTMTrainer();
|
||||
// Callbacks may be null, in which case defaults are used.
|
||||
LSTMTrainer(FileReader file_reader, FileWriter file_writer,
|
||||
CheckPointReader checkpoint_reader,
|
||||
CheckPointWriter checkpoint_writer,
|
||||
const char* model_base, const char* checkpoint_name,
|
||||
int debug_interval, inT64 max_memory);
|
||||
virtual ~LSTMTrainer();
|
||||
|
||||
// Tries to deserialize a trainer from the given file and silently returns
|
||||
// false in case of failure.
|
||||
bool TryLoadingCheckpoint(const char* filename);
|
||||
|
||||
// Initializes the character set encode/decode mechanism.
|
||||
// train_flags control training behavior according to the TrainingFlags
|
||||
// enum, including character set encoding.
|
||||
// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
|
||||
// fully initializes the unicharset from the universal unicharsets.
|
||||
// Note: Call before InitNetwork!
|
||||
void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir,
|
||||
int train_flags);
|
||||
// Initializes the character set encode/decode mechanism directly from a
|
||||
// previously setup UNICHARSET and UnicharCompress.
|
||||
// ctc_mode controls how the truth text is mapped to the network targets.
|
||||
// Note: Call before InitNetwork!
|
||||
void InitCharSet(const UNICHARSET& unicharset, const UnicharCompress recoder);
|
||||
|
||||
// Initializes the trainer with a network_spec in the network description
|
||||
// net_flags control network behavior according to the NetworkFlags enum.
|
||||
// There isn't really much difference between them - only where the effects
|
||||
// are implemented.
|
||||
// For other args see NetworkBuilder::InitNetwork.
|
||||
// Note: Be sure to call InitCharSet before InitNetwork!
|
||||
bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
|
||||
float weight_range, float learning_rate, float momentum);
|
||||
// Initializes a trainer from a serialized TFNetworkModel proto.
|
||||
// Returns the global step of TensorFlow graph or 0 if failed.
|
||||
// Building a compatible TF graph: See tfnetwork.proto.
|
||||
int InitTensorFlowNetwork(const std::string& tf_proto);
|
||||
|
||||
// Accessors.
|
||||
double ActivationError() const {
|
||||
return error_rates_[ET_DELTA];
|
||||
}
|
||||
double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
|
||||
const double* error_rates() const {
|
||||
return error_rates_;
|
||||
}
|
||||
double best_error_rate() const {
|
||||
return best_error_rate_;
|
||||
}
|
||||
int best_iteration() const {
|
||||
return best_iteration_;
|
||||
}
|
||||
int learning_iteration() const { return learning_iteration_; }
|
||||
int improvement_steps() const { return improvement_steps_; }
|
||||
void set_perfect_delay(int delay) { perfect_delay_ = delay; }
|
||||
const GenericVector<char>& best_trainer() const { return best_trainer_; }
|
||||
// Returns the error that was just calculated by PrepareForBackward.
|
||||
double NewSingleError(ErrorTypes type) const {
|
||||
return error_buffers_[type][training_iteration() % kRollingBufferSize_];
|
||||
}
|
||||
// Returns the error that was just calculated by TrainOnLine. Since
|
||||
// TrainOnLine rolls the error buffers, this is one further back than
|
||||
// NewSingleError.
|
||||
double LastSingleError(ErrorTypes type) const {
|
||||
return error_buffers_[type]
|
||||
[(training_iteration() + kRollingBufferSize_ - 1) %
|
||||
kRollingBufferSize_];
|
||||
}
|
||||
const DocumentCache& training_data() const {
|
||||
return training_data_;
|
||||
}
|
||||
DocumentCache* mutable_training_data() { return &training_data_; }
|
||||
|
||||
// If the training sample is usable, grid searches for the optimal
|
||||
// dict_ratio/cert_offset, and returns the results in a string of space-
|
||||
// separated triplets of ratio,offset=worderr.
|
||||
Trainability GridSearchDictParams(
|
||||
const ImageData* trainingdata, int iteration, double min_dict_ratio,
|
||||
double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
|
||||
double cert_offset_step, double max_cert_offset, STRING* results);
|
||||
|
||||
void SetSerializeMode(SerializeAmount serialize_amount) const {
|
||||
serialize_amount_ = serialize_amount;
|
||||
}
|
||||
|
||||
// Provides output on the distribution of weight values.
|
||||
void DebugNetwork();
|
||||
|
||||
// Loads a set of lstmf files that were created using the lstm.train config to
|
||||
// tesseract into memory ready for training. Returns false if nothing was
|
||||
// loaded.
|
||||
bool LoadAllTrainingData(const GenericVector<STRING>& filenames);
|
||||
|
||||
// Keeps track of best and locally worst error rate, using internally computed
|
||||
// values. See MaintainCheckpointsSpecific for more detail.
|
||||
bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
|
||||
// Keeps track of best and locally worst error_rate (whatever it is) and
|
||||
// launches tests using rec_model, when a new min or max is reached.
|
||||
// Writes checkpoints using train_model at appropriate times and builds and
|
||||
// returns a log message to indicate progress. Returns false if nothing
|
||||
// interesting happened.
|
||||
bool MaintainCheckpointsSpecific(int iteration,
|
||||
const GenericVector<char>* train_model,
|
||||
const GenericVector<char>* rec_model,
|
||||
TestCallback tester, STRING* log_msg);
|
||||
// Builds a string containing a progress message with current error rates.
|
||||
void PrepareLogMsg(STRING* log_msg) const;
|
||||
// Appends <intro_str> iteration learning_iteration()/training_iteration()/
|
||||
// sample_iteration() to the log_msg.
|
||||
void LogIterations(const char* intro_str, STRING* log_msg) const;
|
||||
|
||||
// TODO(rays) Add curriculum learning.
|
||||
// Returns true and increments the training_stage_ if the error rate has just
|
||||
// passed through the given threshold for the first time.
|
||||
bool TransitionTrainingStage(float error_threshold);
|
||||
// Returns the current training stage.
|
||||
int CurrentTrainingStage() const { return training_stage_; }
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
virtual bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
virtual bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
|
||||
// learning rates (by scaling reduction, or layer specific, according to
|
||||
// NF_LAYER_SPECIFIC_LR).
|
||||
void StartSubtrainer(STRING* log_msg);
|
||||
// While the sub_trainer_ is behind the current training iteration and its
|
||||
// training error is at least kSubTrainerMarginFraction better than the
|
||||
// current training error, trains the sub_trainer_, and returns STR_UPDATED if
|
||||
// it did anything. If it catches up, and has a better error rate than the
|
||||
// current best, as well as a margin over the current error rate, then the
|
||||
// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
|
||||
// returned. STR_NONE is returned if the subtrainer wasn't good enough to
|
||||
// receive any training iterations.
|
||||
SubTrainerResult UpdateSubtrainer(STRING* log_msg);
|
||||
// Reduces network learning rates, either for everything, or for layers
|
||||
// independently, according to NF_LAYER_SPECIFIC_LR.
|
||||
void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
|
||||
// Considers reducing the learning rate independently for each layer down by
|
||||
// factor(<1), or leaving it the same, by double-training the given number of
|
||||
// samples and minimizing the amount of changing of sign of weight updates.
|
||||
// Even if it looks like all weights should remain the same, an adjustment
|
||||
// will be made to guarantee a different result when reverting to an old best.
|
||||
// Returns the number of layer learning rates that were reduced.
|
||||
int ReduceLayerLearningRates(double factor, int num_samples,
|
||||
LSTMTrainer* samples_trainer);
|
||||
|
||||
// Converts the string to integer class labels, with appropriate null_char_s
|
||||
// in between if not in SimpleTextOutput mode. Returns false on failure.
|
||||
bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
|
||||
return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : NULL,
|
||||
SimpleTextOutput(), null_char_, labels);
|
||||
}
|
||||
// Static version operates on supplied unicharset, encoder, simple_text.
|
||||
static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
|
||||
const UnicharCompress* recoder, bool simple_text,
|
||||
int null_char, GenericVector<int>* labels);
|
||||
|
||||
// Converts the network to int if not already.
|
||||
void ConvertToInt() {
|
||||
if ((training_flags_ & TF_INT_MODE) == 0) {
|
||||
network_->ConvertToInt();
|
||||
training_flags_ |= TF_INT_MODE;
|
||||
}
|
||||
}
|
||||
|
||||
// Performs forward-backward on the given trainingdata.
|
||||
// Returns the sample that was used or NULL if the next sample was deemed
|
||||
// unusable. samples_trainer could be this or an alternative trainer that
|
||||
// holds the training samples.
|
||||
const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
|
||||
int sample_index = sample_iteration();
|
||||
const ImageData* image =
|
||||
samples_trainer->training_data_.GetPageBySerial(sample_index);
|
||||
if (image != NULL) {
|
||||
Trainability trainable = TrainOnLine(image, batch);
|
||||
if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
|
||||
return NULL; // Sample was unusable.
|
||||
}
|
||||
} else {
|
||||
++sample_iteration_;
|
||||
}
|
||||
return image;
|
||||
}
|
||||
Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
|
||||
|
||||
// Prepares the ground truth, runs forward, and prepares the targets.
|
||||
// Returns a Trainability enum to indicate the suitability of the sample.
|
||||
Trainability PrepareForBackward(const ImageData* trainingdata,
|
||||
NetworkIO* fwd_outputs, NetworkIO* targets);
|
||||
|
||||
// Writes the trainer to memory, so that the current training state can be
|
||||
// restored.
|
||||
bool SaveTrainingDump(SerializeAmount serialize_amount,
|
||||
const LSTMTrainer* trainer,
|
||||
GenericVector<char>* data) const;
|
||||
|
||||
// Reads previously saved trainer from memory.
|
||||
bool ReadTrainingDump(const GenericVector<char>& data, LSTMTrainer* trainer);
|
||||
bool ReadSizedTrainingDump(const char* data, int size);
|
||||
|
||||
// Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
|
||||
void SetupCheckpointInfo();
|
||||
|
||||
// Writes the recognizer to memory, so that it can be used for testing later.
|
||||
void SaveRecognitionDump(GenericVector<char>* data) const;
|
||||
|
||||
// Reads and returns a previously saved recognizer from memory.
|
||||
static LSTMRecognizer* ReadRecognitionDump(const GenericVector<char>& data);
|
||||
|
||||
// Writes current best model to a file, unless it has already been written.
|
||||
bool SaveBestModel(FileWriter writer) const;
|
||||
|
||||
// Returns a suitable filename for a training dump, based on the model_base_,
|
||||
// the iteration and the error rates.
|
||||
STRING DumpFilename() const;
|
||||
|
||||
// Fills the whole error buffer of the given type with the given value.
|
||||
void FillErrorBuffer(double new_error, ErrorTypes type);
|
||||
|
||||
protected:
|
||||
// Factored sub-constructor sets up reasonable default values.
|
||||
void EmptyConstructor();
|
||||
|
||||
// Sets the unicharset properties using the given script_dir as a source of
|
||||
// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
|
||||
// up the recoder_ to simplify the unicharset.
|
||||
void SetUnicharsetProperties(const STRING& script_dir);
|
||||
|
||||
// Outputs the string and periodically displays the given network inputs
|
||||
// as an image in the given window, and the corresponding labels at the
|
||||
// corresponding x_starts.
|
||||
// Returns false if the truth string is empty.
|
||||
bool DebugLSTMTraining(const NetworkIO& inputs,
|
||||
const ImageData& trainingdata,
|
||||
const NetworkIO& fwd_outputs,
|
||||
const GenericVector<int>& truth_labels,
|
||||
const NetworkIO& outputs);
|
||||
// Displays the network targets as line a line graph.
|
||||
void DisplayTargets(const NetworkIO& targets, const char* window_name,
|
||||
ScrollView** window);
|
||||
|
||||
// Builds a no-compromises target where the first positions should be the
|
||||
// truth labels and the rest is padded with the null_char_.
|
||||
bool ComputeTextTargets(const NetworkIO& outputs,
|
||||
const GenericVector<int>& truth_labels,
|
||||
NetworkIO* targets);
|
||||
|
||||
// Builds a target using standard CTC. truth_labels should be pre-padded with
|
||||
// nulls wherever desired. They don't have to be between all labels.
|
||||
// outputs is input-output, as it gets clipped to minimum probability.
|
||||
bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
|
||||
NetworkIO* outputs, NetworkIO* targets);
|
||||
|
||||
// Computes network errors, and stores the results in the rolling buffers,
|
||||
// along with the supplied text_error.
|
||||
// Returns the delta error of the current sample (not running average.)
|
||||
double ComputeErrorRates(const NetworkIO& deltas, double char_error,
|
||||
double word_error);
|
||||
|
||||
// Computes the network activation RMS error rate.
|
||||
double ComputeRMSError(const NetworkIO& deltas);
|
||||
|
||||
// Computes network activation winner error rate. (Number of values that are
|
||||
// in error by >= 0.5 divided by number of time-steps.) More closely related
|
||||
// to final character error than RMS, but still directly calculable from
|
||||
// just the deltas. Because of the binary nature of the targets, zero winner
|
||||
// error is a sufficient but not necessary condition for zero char error.
|
||||
double ComputeWinnerError(const NetworkIO& deltas);
|
||||
|
||||
// Computes a very simple bag of chars char error rate.
|
||||
double ComputeCharError(const GenericVector<int>& truth_str,
|
||||
const GenericVector<int>& ocr_str);
|
||||
// Computes a very simple bag of words word recall error rate.
|
||||
// NOTE that this is destructive on both input strings.
|
||||
double ComputeWordError(STRING* truth_str, STRING* ocr_str);
|
||||
|
||||
// Updates the error buffer and corresponding mean of the given type with
|
||||
// the new_error.
|
||||
void UpdateErrorBuffer(double new_error, ErrorTypes type);
|
||||
|
||||
// Rolls error buffers and reports the current means.
|
||||
void RollErrorBuffers();
|
||||
|
||||
// Given that error_rate is either a new min or max, updates the best/worst
|
||||
// error rates, and record of progress.
|
||||
STRING UpdateErrorGraph(int iteration, double error_rate,
|
||||
const GenericVector<char>& model_data,
|
||||
TestCallback tester);
|
||||
|
||||
protected:
|
||||
// Alignment display window.
|
||||
ScrollView* align_win_;
|
||||
// CTC target display window.
|
||||
ScrollView* target_win_;
|
||||
// CTC output display window.
|
||||
ScrollView* ctc_win_;
|
||||
// Reconstructed image window.
|
||||
ScrollView* recon_win_;
|
||||
// How often to display a debug image.
|
||||
int debug_interval_;
|
||||
// Iteration at which the last checkpoint was dumped.
|
||||
int checkpoint_iteration_;
|
||||
// Basename of files to save best models to.
|
||||
STRING model_base_;
|
||||
// Checkpoint filename.
|
||||
STRING checkpoint_name_;
|
||||
// Training data.
|
||||
DocumentCache training_data_;
|
||||
// A hack to serialize less data for batch training and record file version.
|
||||
mutable SerializeAmount serialize_amount_;
|
||||
// Name to use when saving best_trainer_.
|
||||
STRING best_model_name_;
|
||||
// Number of available training stages.
|
||||
int num_training_stages_;
|
||||
// Checkpointing callbacks.
|
||||
FileReader file_reader_;
|
||||
FileWriter file_writer_;
|
||||
// TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
|
||||
// when we can commit to c++11.
|
||||
CheckPointReader checkpoint_reader_;
|
||||
CheckPointWriter checkpoint_writer_;
|
||||
|
||||
// ===Serialized data to ensure that a restart produces the same results.===
|
||||
// These members are only serialized when serialize_amount_ != LIGHT.
|
||||
// Best error rate so far.
|
||||
double best_error_rate_;
|
||||
// Snapshot of all error rates at best_iteration_.
|
||||
double best_error_rates_[ET_COUNT];
|
||||
// Iteration of best_error_rate_.
|
||||
int best_iteration_;
|
||||
// Worst error rate since best_error_rate_.
|
||||
double worst_error_rate_;
|
||||
// Snapshot of all error rates at worst_iteration_.
|
||||
double worst_error_rates_[ET_COUNT];
|
||||
// Iteration of worst_error_rate_.
|
||||
int worst_iteration_;
|
||||
// Iteration at which the process will be thought stalled.
|
||||
int stall_iteration_;
|
||||
// Saved recognition models for computing test error for graph points.
|
||||
GenericVector<char> best_model_data_;
|
||||
GenericVector<char> worst_model_data_;
|
||||
// Saved trainer for reverting back to last known best.
|
||||
GenericVector<char> best_trainer_;
|
||||
// A subsidiary trainer running with a different learning rate until either
|
||||
// *this or sub_trainer_ hits a new best.
|
||||
LSTMTrainer* sub_trainer_;
|
||||
// Error rate at which last best model was dumped.
|
||||
float error_rate_of_last_saved_best_;
|
||||
// Current stage of training.
|
||||
int training_stage_;
|
||||
// History of best error rate against iteration. Used for computing the
|
||||
// number of steps to each 2% improvement.
|
||||
GenericVector<double> best_error_history_;
|
||||
GenericVector<int> best_error_iterations_;
|
||||
// Number of iterations since the best_error_rate_ was 2% more than it is now.
|
||||
int improvement_steps_;
|
||||
// Number of iterations that yielded a non-zero delta error and thus provided
|
||||
// significant learning. learning_iteration_ <= training_iteration_.
|
||||
// learning_iteration_ is used to measure rate of learning progress.
|
||||
int learning_iteration_;
|
||||
// Saved value of sample_iteration_ before looking for the the next sample.
|
||||
int prev_sample_iteration_;
|
||||
// How often to include a PERFECT training sample in backprop.
|
||||
// A PERFECT training sample is used if the current
|
||||
// training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
|
||||
// so with perfect_delay_ == 0, all samples are used, and with
|
||||
// perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
|
||||
int perfect_delay_;
|
||||
// Value of training_iteration_ at which the last PERFECT training sample
|
||||
// was used in back prop.
|
||||
int last_perfect_training_iteration_;
|
||||
// Rolling buffers storing recent training errors are indexed by
|
||||
// training_iteration % kRollingBufferSize_.
|
||||
static const int kRollingBufferSize_ = 1000;
|
||||
GenericVector<double> error_buffers_[ET_COUNT];
|
||||
// Rounded mean percent trailing training errors in the buffers.
|
||||
double error_rates_[ET_COUNT]; // RMS training error.
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_LSTMTRAINER_H_
|
87
lstm/maxpool.cpp
Normal file
87
lstm/maxpool.cpp
Normal file
@ -0,0 +1,87 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: maxpool.h
|
||||
// Description: Standard Max-Pooling layer.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Mar 18 16:28:18 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "maxpool.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale)
|
||||
: Reconfig(name, ni, x_scale, y_scale) {
|
||||
type_ = NT_MAXPOOL;
|
||||
no_ = ni;
|
||||
}
|
||||
|
||||
Maxpool::~Maxpool() {
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool Maxpool::DeSerialize(bool swap, TFile* fp) {
|
||||
bool result = Reconfig::DeSerialize(swap, fp);
|
||||
no_ = ni_;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
void Maxpool::Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output) {
|
||||
output->ResizeScaled(input, x_scale_, y_scale_, no_);
|
||||
maxes_.ResizeNoInit(output->Width(), ni_);
|
||||
back_map_ = input.stride_map();
|
||||
|
||||
StrideMap::Index dest_index(output->stride_map());
|
||||
do {
|
||||
int out_t = dest_index.t();
|
||||
StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
|
||||
dest_index.index(FD_HEIGHT) * y_scale_,
|
||||
dest_index.index(FD_WIDTH) * x_scale_);
|
||||
// Find the max input out of x_scale_ groups of y_scale_ inputs.
|
||||
// Do it independently for each input dimension.
|
||||
int* max_line = maxes_[out_t];
|
||||
int in_t = src_index.t();
|
||||
output->CopyTimeStepFrom(out_t, input, in_t);
|
||||
for (int i = 0; i < ni_; ++i) {
|
||||
max_line[i] = in_t;
|
||||
}
|
||||
for (int x = 0; x < x_scale_; ++x) {
|
||||
for (int y = 0; y < y_scale_; ++y) {
|
||||
StrideMap::Index src_xy(src_index);
|
||||
if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
|
||||
output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line);
|
||||
}
|
||||
}
|
||||
}
|
||||
} while (dest_index.Increment());
|
||||
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas) {
|
||||
back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
|
||||
back_deltas->MaxpoolBackward(fwd_deltas, maxes_);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
} // namespace tesseract.
|
||||
|
71
lstm/maxpool.h
Normal file
71
lstm/maxpool.h
Normal file
@ -0,0 +1,71 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: maxpool.h
|
||||
// Description: Standard Max-Pooling layer.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Mar 18 16:28:18 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_MAXPOOL_H_
|
||||
#define TESSERACT_LSTM_MAXPOOL_H_
|
||||
|
||||
#include "reconfig.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Maxpooling reduction. Independently for each input, selects the location
|
||||
// in the rectangle that contains the max value.
|
||||
// Backprop propagates only to the position that was the max.
|
||||
class Maxpool : public Reconfig {
|
||||
public:
|
||||
Maxpool(const STRING& name, int ni, int x_scale, int y_scale);
|
||||
virtual ~Maxpool();
|
||||
|
||||
// Accessors.
|
||||
virtual STRING spec() const {
|
||||
STRING spec;
|
||||
spec.add_str_int("Mp", y_scale_);
|
||||
spec.add_str_int(",", x_scale_);
|
||||
return spec;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
virtual bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual void Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output);
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas);
|
||||
|
||||
private:
|
||||
// Memory of which input was the max.
|
||||
GENERIC_2D_ARRAY<int> maxes_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#endif // TESSERACT_LSTM_MAXPOOL_H_
|
||||
|
309
lstm/network.cpp
Normal file
309
lstm/network.cpp
Normal file
@ -0,0 +1,309 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: network.cpp
|
||||
// Description: Base class for neural network implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed May 01 17:25:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "network.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
// This base class needs to know about all its sub-classes because of the
|
||||
// factory deserializing method: CreateFromFile.
|
||||
#include "allheaders.h"
|
||||
#include "convolve.h"
|
||||
#include "fullyconnected.h"
|
||||
#include "input.h"
|
||||
#include "lstm.h"
|
||||
#include "maxpool.h"
|
||||
#include "parallel.h"
|
||||
#include "reconfig.h"
|
||||
#include "reversed.h"
|
||||
#include "scrollview.h"
|
||||
#include "series.h"
|
||||
#include "statistc.h"
|
||||
#ifdef INCLUDE_TENSORFLOW
|
||||
#include "tfnetwork.h"
|
||||
#endif
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Min and max window sizes.
|
||||
const int kMinWinSize = 500;
|
||||
const int kMaxWinSize = 2000;
|
||||
// Window frame sizes need adding on to make the content fit.
|
||||
const int kXWinFrameSize = 30;
|
||||
const int kYWinFrameSize = 80;
|
||||
|
||||
// String names corresponding to the NetworkType enum. Keep in sync.
|
||||
// Names used in Serialization to allow re-ordering/addition/deletion of
|
||||
// layer types in NetworkType without invalidating existing network files.
|
||||
char const* const Network::kTypeNames[NT_COUNT] = {
|
||||
"Invalid", "Input",
|
||||
"Convolve", "Maxpool",
|
||||
"Parallel", "Replicated",
|
||||
"ParBidiLSTM", "DepParUDLSTM",
|
||||
"Par2dLSTM", "Series",
|
||||
"Reconfig", "RTLReversed",
|
||||
"TTBReversed", "XYTranspose",
|
||||
"LSTM", "SummLSTM",
|
||||
"Logistic", "LinLogistic",
|
||||
"LinTanh", "Tanh",
|
||||
"Relu", "Linear",
|
||||
"Softmax", "SoftmaxNoCTC",
|
||||
"LSTMSoftmax", "LSTMBinarySoftmax",
|
||||
"TensorFlow",
|
||||
};
|
||||
|
||||
Network::Network()
|
||||
: type_(NT_NONE), training_(true), needs_to_backprop_(true),
|
||||
network_flags_(0), ni_(0), no_(0), num_weights_(0),
|
||||
forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
|
||||
}
|
||||
Network::Network(NetworkType type, const STRING& name, int ni, int no)
|
||||
: type_(type), training_(true), needs_to_backprop_(true),
|
||||
network_flags_(0), ni_(ni), no_(no), num_weights_(0),
|
||||
name_(name), forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
|
||||
}
|
||||
|
||||
Network::~Network() {
|
||||
}
|
||||
|
||||
// Ends training by setting the training_ flag to false. Serialize and
|
||||
// DeSerialize will now only operate on the run-time data.
|
||||
void Network::SetEnableTraining(bool state) {
|
||||
training_ = state;
|
||||
}
|
||||
|
||||
// Sets flags that control the action of the network. See NetworkFlags enum
|
||||
// for bit values.
|
||||
void Network::SetNetworkFlags(uinT32 flags) {
|
||||
network_flags_ = flags;
|
||||
}
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
|
||||
// scale `range` picked according to the random number generator `randomizer`.
|
||||
int Network::InitWeights(float range, TRand* randomizer) {
|
||||
randomizer_ = randomizer;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Provides a pointer to a TRand for any networks that care to use it.
|
||||
// Note that randomizer is a borrowed pointer that should outlive the network
|
||||
// and should not be deleted by any of the networks.
|
||||
void Network::SetRandomizer(TRand* randomizer) {
|
||||
randomizer_ = randomizer;
|
||||
}
|
||||
|
||||
// Sets needs_to_backprop_ to needs_backprop and returns true if
|
||||
// needs_backprop || any weights in this network so the next layer forward
|
||||
// can be told to produce backprop for this layer if needed.
|
||||
bool Network::SetupNeedsBackprop(bool needs_backprop) {
|
||||
needs_to_backprop_ = needs_backprop;
|
||||
return needs_backprop || num_weights_ > 0;
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Network::Serialize(TFile* fp) const {
|
||||
inT8 data = NT_NONE;
|
||||
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
|
||||
STRING type_name = kTypeNames[type_];
|
||||
if (!type_name.Serialize(fp)) return false;
|
||||
data = training_;
|
||||
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
|
||||
data = needs_to_backprop_;
|
||||
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
|
||||
if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
|
||||
if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false;
|
||||
if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false;
|
||||
if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
|
||||
if (!name_.Serialize(fp)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
// Should be overridden by subclasses, but NOT called by their DeSerialize.
|
||||
bool Network::DeSerialize(bool swap, TFile* fp) {
|
||||
inT8 data = 0;
|
||||
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
|
||||
if (data == NT_NONE) {
|
||||
STRING type_name;
|
||||
if (!type_name.DeSerialize(swap, fp)) return false;
|
||||
for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
|
||||
}
|
||||
if (data == NT_COUNT) {
|
||||
tprintf("Invalid network layer type:%s\n", type_name.string());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
type_ = static_cast<NetworkType>(data);
|
||||
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
|
||||
training_ = data != 0;
|
||||
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
|
||||
needs_to_backprop_ = data != 0;
|
||||
if (fp->FRead(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
|
||||
if (fp->FRead(&ni_, sizeof(ni_), 1) != 1) return false;
|
||||
if (fp->FRead(&no_, sizeof(no_), 1) != 1) return false;
|
||||
if (fp->FRead(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
|
||||
if (!name_.DeSerialize(swap, fp)) return false;
|
||||
if (swap) {
|
||||
ReverseN(&network_flags_, sizeof(network_flags_));
|
||||
ReverseN(&ni_, sizeof(ni_));
|
||||
ReverseN(&no_, sizeof(no_));
|
||||
ReverseN(&num_weights_, sizeof(num_weights_));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns NULL in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
// Determines the type of the serialized class and calls its DeSerialize
|
||||
// on a new object of the appropriate type, which is returned.
|
||||
Network* Network::CreateFromFile(bool swap, TFile* fp) {
|
||||
Network stub;
|
||||
if (!stub.DeSerialize(swap, fp)) return NULL;
|
||||
Network* network = NULL;
|
||||
switch (stub.type_) {
|
||||
case NT_CONVOLVE:
|
||||
network = new Convolve(stub.name_, stub.ni_, 0, 0);
|
||||
break;
|
||||
case NT_INPUT:
|
||||
network = new Input(stub.name_, stub.ni_, stub.no_);
|
||||
break;
|
||||
case NT_LSTM:
|
||||
case NT_LSTM_SOFTMAX:
|
||||
case NT_LSTM_SOFTMAX_ENCODED:
|
||||
case NT_LSTM_SUMMARY:
|
||||
network =
|
||||
new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
|
||||
break;
|
||||
case NT_MAXPOOL:
|
||||
network = new Maxpool(stub.name_, stub.ni_, 0, 0);
|
||||
break;
|
||||
// All variants of Parallel.
|
||||
case NT_PARALLEL:
|
||||
case NT_REPLICATED:
|
||||
case NT_PAR_RL_LSTM:
|
||||
case NT_PAR_UD_LSTM:
|
||||
case NT_PAR_2D_LSTM:
|
||||
network = new Parallel(stub.name_, stub.type_);
|
||||
break;
|
||||
case NT_RECONFIG:
|
||||
network = new Reconfig(stub.name_, stub.ni_, 0, 0);
|
||||
break;
|
||||
// All variants of reversed.
|
||||
case NT_XREVERSED:
|
||||
case NT_YREVERSED:
|
||||
case NT_XYTRANSPOSE:
|
||||
network = new Reversed(stub.name_, stub.type_);
|
||||
break;
|
||||
case NT_SERIES:
|
||||
network = new Series(stub.name_);
|
||||
break;
|
||||
case NT_TENSORFLOW:
|
||||
#ifdef INCLUDE_TENSORFLOW
|
||||
network = new TFNetwork(stub.name_);
|
||||
#else
|
||||
tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
|
||||
return NULL;
|
||||
#endif
|
||||
break;
|
||||
// All variants of FullyConnected.
|
||||
case NT_SOFTMAX:
|
||||
case NT_SOFTMAX_NO_CTC:
|
||||
case NT_RELU:
|
||||
case NT_TANH:
|
||||
case NT_LINEAR:
|
||||
case NT_LOGISTIC:
|
||||
case NT_POSCLIP:
|
||||
case NT_SYMCLIP:
|
||||
network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
|
||||
break;
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
network->training_ = stub.training_;
|
||||
network->needs_to_backprop_ = stub.needs_to_backprop_;
|
||||
network->network_flags_ = stub.network_flags_;
|
||||
network->num_weights_ = stub.num_weights_;
|
||||
if (!network->DeSerialize(swap, fp)) {
|
||||
delete network;
|
||||
return NULL;
|
||||
}
|
||||
return network;
|
||||
}
|
||||
|
||||
// Returns a random number in [-range, range].
|
||||
double Network::Random(double range) {
|
||||
ASSERT_HOST(randomizer_ != NULL);
|
||||
return randomizer_->SignedRand(range);
|
||||
}
|
||||
|
||||
#ifndef GRAPHICS_DISABLED
|
||||
// === Debug image display methods. ===
|
||||
// Displays the image of the matrix to the forward window.
|
||||
void Network::DisplayForward(const NetworkIO& matrix) {
|
||||
Pix* image = matrix.ToPix();
|
||||
ClearWindow(false, name_.string(), pixGetWidth(image),
|
||||
pixGetHeight(image), &forward_win_);
|
||||
DisplayImage(image, forward_win_);
|
||||
forward_win_->Update();
|
||||
}
|
||||
|
||||
// Displays the image of the matrix to the backward window.
|
||||
void Network::DisplayBackward(const NetworkIO& matrix) {
|
||||
Pix* image = matrix.ToPix();
|
||||
STRING window_name = name_ + "-back";
|
||||
ClearWindow(false, window_name.string(), pixGetWidth(image),
|
||||
pixGetHeight(image), &backward_win_);
|
||||
DisplayImage(image, backward_win_);
|
||||
backward_win_->Update();
|
||||
}
|
||||
|
||||
// Creates the window if needed, otherwise clears it.
|
||||
void Network::ClearWindow(bool tess_coords, const char* window_name,
|
||||
int width, int height, ScrollView** window) {
|
||||
if (*window == NULL) {
|
||||
int min_size = MIN(width, height);
|
||||
if (min_size < kMinWinSize) {
|
||||
if (min_size < 1) min_size = 1;
|
||||
width = width * kMinWinSize / min_size;
|
||||
height = height * kMinWinSize / min_size;
|
||||
}
|
||||
width += kXWinFrameSize;
|
||||
height += kYWinFrameSize;
|
||||
if (width > kMaxWinSize) width = kMaxWinSize;
|
||||
if (height > kMaxWinSize) height = kMaxWinSize;
|
||||
*window = new ScrollView(window_name, 80, 100, width, height, width, height,
|
||||
tess_coords);
|
||||
tprintf("Created window %s of size %d, %d\n", window_name, width, height);
|
||||
} else {
|
||||
(*window)->Clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Displays the pix in the given window. and returns the height of the pix.
|
||||
// The pix is pixDestroyed.
|
||||
int Network::DisplayImage(Pix* pix, ScrollView* window) {
|
||||
int height = pixGetHeight(pix);
|
||||
window->Image(pix, 0, 0);
|
||||
pixDestroy(&pix);
|
||||
return height;
|
||||
}
|
||||
#endif // GRAPHICS_DISABLED
|
||||
|
||||
} // namespace tesseract.
|
292
lstm/network.h
Normal file
292
lstm/network.h
Normal file
@ -0,0 +1,292 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: network.h
|
||||
// Description: Base class for neural network implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed May 01 16:38:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_NETWORK_H_
|
||||
#define TESSERACT_LSTM_NETWORK_H_
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "helpers.h"
|
||||
#include "matrix.h"
|
||||
#include "networkio.h"
|
||||
#include "serialis.h"
|
||||
#include "static_shape.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
struct Pix;
|
||||
class ScrollView;
|
||||
class TBOX;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
class ImageData;
|
||||
class NetworkScratch;
|
||||
|
||||
// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
// NOTE(review): type_ appears to be serialized via the kTypeNames mapping
// (see Network::kTypeNames below), so append new values before NT_COUNT and
// never reorder existing ones — TODO confirm against the serialization code.
enum NetworkType {
  NT_NONE,   // The naked base class.
  NT_INPUT,  // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,     // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,      // Chooses the max result from a rectangle.
  NT_PARALLEL,     // Runs networks in parallel.
  NT_REPLICATED,   // Runs identical networks in parallel.
  NT_PAR_RL_LSTM,  // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM,  // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM,  // Runs 4 LSTMs in parallel.
  NT_SERIES,       // Executes a sequence of layers.
  NT_RECONFIG,     // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,    // Reverses the x direction of the inputs/outputs.
  NT_YREVERSED,    // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE,  // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,          // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,  // LSTM that only keeps its last output.
  NT_LOGISTIC,      // Fully connected logistic nonlinearity.
  NT_POSCLIP,       // Fully connected rect lin version of logistic.
  NT_SYMCLIP,       // Fully connected rect lin version of tanh.
  NT_TANH,          // Fully connected with tanh nonlinearity.
  NT_RELU,          // Fully connected with rectifier nonlinearity.
  NT_LINEAR,        // Fully connected with no nonlinearity.
  NT_SOFTMAX,       // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC,  // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with
  // the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
  NT_LSTM_SOFTMAX,          // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED,  // 1-d LSTM with built-in binary encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT  // Array size.
};
|
||||
|
||||
// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
// Values are powers of two so they can be combined in the network_flags_
// bitmask and tested individually via Network::TestFlag.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64,  // Separate learning rate for each layer.
  NF_ADA_GRAD = 128,          // Weight-specific learning rate.
};
|
||||
|
||||
// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class Network {
 public:
  // Default constructor for use prior to deserialization.
  Network();
  // Constructs a network of the given type/name with ni inputs, no outputs.
  Network(NetworkType type, const STRING& name, int ni, int no);
  virtual ~Network();

  // Accessors.
  NetworkType type() const {
    return type_;
  }
  bool training() const {
    return training_;
  }
  bool needs_to_backprop() const {
    return needs_to_backprop_;
  }
  int num_weights() const { return num_weights_; }
  int NumInputs() const {
    return ni_;
  }
  int NumOutputs() const {
    return no_;
  }
  // Returns the required shape input to the network.
  // Base-class default: an all-unknown (default-constructed) shape.
  virtual StaticShape InputShape() const {
    StaticShape result;
    return result;
  }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown ie zero).
  // Base-class default: same shape as the input, with depth set to no_.
  virtual StaticShape OutputShape(const StaticShape& input_shape) const {
    StaticShape result(input_shape);
    result.set_depth(no_);
    return result;
  }
  const STRING& name() const {
    return name_;
  }
  // Returns a string describing the network spec; "?" unless overridden.
  virtual STRING spec() const {
    return "?";
  }
  // Tests a single NetworkFlags bit in network_flags_.
  bool TestFlag(NetworkFlags flag) const {
    return (network_flags_ & flag) != 0;
  }

  // Initialization and administrative functions that are mostly provided
  // by Plumbing.
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  virtual bool IsPlumbingType() const { return false; }

  // Suspends/Enables training by setting the training_ flag. Serialize and
  // DeSerialize only operate on the run-time data if state is false.
  virtual void SetEnableTraining(bool state);

  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  virtual void SetNetworkFlags(uinT32 flags);

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand* randomizer);

  // Converts a float network to an int network.
  virtual void ConvertToInt() {}

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  virtual void SetRandomizer(TRand* randomizer);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop || any weights in this network so the next layer forward
  // can be told to produce backprop for this layer if needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns the most recent reduction factor that the network applied to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data and calculating result bounding boxes.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const {
    return 1;
  }

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor(int factor) {}

  // Provides debug output on the weights.
  virtual void DebugWeights() {
    tprintf("Must override Network::DebugWeights for type %d\n", type_);
  }

  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(bool swap, TFile* fp);

  // Updates the weights using the given learning rate and momentum.
  // num_samples is the quotient to be used in the adagrad computation iff
  // use_ada_grad_ is true.
  virtual void Update(float learning_rate, float momentum, int num_samples) {}
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators(const Network& other, double* same,
                                double* changed) const {}

  // Reads from the given file. Returns NULL in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  // Determines the type of the serialized class and calls its DeSerialize
  // on a new object of the appropriate type, which is returned.
  static Network* CreateFromFile(bool swap, TFile* fp);

  // Runs forward propagation of activations on the input line.
  // Note that input and output are both 2-d arrays.
  // The 1st index is the time element. In a 1-d network, it might be the pixel
  // position on the textline. In a 2-d network, the linearization is defined
  // by the stride_map. (See networkio.h).
  // The 2nd index of input is the network inputs/outputs, and the dimension
  // of the input must match NumInputs() of this network.
  // The output array will be resized as needed so that its 1st dimension is
  // always equal to the number of output values, and its second dimension is
  // always NumOutputs(). Note that all this detail is encapsulated away inside
  // NetworkIO, as are the internals of the scratch memory space used by the
  // network. See networkscratch.h for that.
  // If input_transpose is not NULL, then it contains the transpose of input,
  // and the caller guarantees that it will still be valid on the next call to
  // backward. The callee is therefore at liberty to save the pointer and
  // reference it on a call to backward. This is a bit ugly, but it makes it
  // possible for a replicating parallel to calculate the input transpose once
  // instead of all the replicated networks having to do it.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output) {
    tprintf("Must override Network::Forward for type %d\n", type_);
  }

  // Runs backward propagation of errors on fwdX_deltas.
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
  // Returns false if back_deltas was not set, due to there being no point in
  // propagating further backwards. Thus most complete networks will always
  // return false from Backward!
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas) {
    tprintf("Must override Network::Backward for type %d\n", type_);
    return false;
  }

  // === Debug image display methods. ===
  // Displays the image of the matrix to the forward window.
  void DisplayForward(const NetworkIO& matrix);
  // Displays the image of the matrix to the backward window.
  void DisplayBackward(const NetworkIO& matrix);

  // Creates the window if needed, otherwise clears it.
  static void ClearWindow(bool tess_coords, const char* window_name,
                          int width, int height, ScrollView** window);

  // Displays the pix in the given window. and returns the height of the pix.
  // The pix is pixDestroyed.
  static int DisplayImage(Pix* pix, ScrollView* window);

 protected:
  // Returns a random number in [-range, range].
  double Random(double range);

 protected:
  NetworkType type_;        // Type of the derived network class.
  bool training_;           // Are we currently training?
  bool needs_to_backprop_;  // This network needs to output back_deltas.
  inT32 network_flags_;     // Behavior control flags in NetworkFlags.
  inT32 ni_;                // Number of input values.
  inT32 no_;                // Number of output values.
  inT32 num_weights_;       // Number of weights in this and sub-network.
  STRING name_;             // A unique name for this layer.

  // NOT-serialized debug data.
  ScrollView* forward_win_;   // Recognition debug display window.
  ScrollView* backward_win_;  // Training debug display window.
  TRand* randomizer_;         // Random number generator.

  // Static serialized name/type_ mapping. Keep in sync with NetworkType.
  static char const* const kTypeNames[NT_COUNT];
};
|
||||
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_NETWORK_H_
|
488
lstm/networkbuilder.cpp
Normal file
488
lstm/networkbuilder.cpp
Normal file
@ -0,0 +1,488 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: networkbuilder.h
|
||||
// Description: Class to parse the network description language and
|
||||
// build a corresponding network.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 16 18:35:38 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "networkbuilder.h"
|
||||
#include "convolve.h"
|
||||
#include "fullyconnected.h"
|
||||
#include "input.h"
|
||||
#include "lstm.h"
|
||||
#include "maxpool.h"
|
||||
#include "network.h"
|
||||
#include "parallel.h"
|
||||
#include "reconfig.h"
|
||||
#include "reversed.h"
|
||||
#include "series.h"
|
||||
#include "unicharset.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Builds a network with a network_spec in the network description
|
||||
// language, to recognize a character set of num_outputs size.
|
||||
// If append_index is non-negative, then *network must be non-null and the
|
||||
// given network_spec will be appended to *network AFTER append_index, with
|
||||
// the top of the input *network discarded.
|
||||
// Note that network_spec is call by value to allow a non-const char* pointer
|
||||
// into the string for BuildFromString.
|
||||
// net_flags control network behavior according to the NetworkFlags enum.
|
||||
// The resulting network is returned via **network.
|
||||
// Returns false if something failed.
|
||||
bool NetworkBuilder::InitNetwork(int num_outputs, STRING network_spec,
|
||||
int append_index, int net_flags,
|
||||
float weight_range, TRand* randomizer,
|
||||
Network** network) {
|
||||
NetworkBuilder builder(num_outputs);
|
||||
Series* bottom_series = NULL;
|
||||
StaticShape input_shape;
|
||||
if (append_index >= 0) {
|
||||
// Split the current network after the given append_index.
|
||||
ASSERT_HOST(*network != NULL && (*network)->type() == NT_SERIES);
|
||||
Series* series = reinterpret_cast<Series*>(*network);
|
||||
Series* top_series = NULL;
|
||||
series->SplitAt(append_index, &bottom_series, &top_series);
|
||||
if (bottom_series == NULL || top_series == NULL) {
|
||||
tprintf("Yikes! Splitting current network failed!!\n");
|
||||
return false;
|
||||
}
|
||||
input_shape = bottom_series->OutputShape(input_shape);
|
||||
delete top_series;
|
||||
}
|
||||
char* str_ptr = &network_spec[0];
|
||||
*network = builder.BuildFromString(input_shape, &str_ptr);
|
||||
if (*network == NULL) return false;
|
||||
(*network)->SetNetworkFlags(net_flags);
|
||||
(*network)->InitWeights(weight_range, randomizer);
|
||||
(*network)->SetupNeedsBackprop(false);
|
||||
if (bottom_series != NULL) {
|
||||
bottom_series->AppendSeries(*network);
|
||||
*network = bottom_series;
|
||||
}
|
||||
(*network)->CacheXScaleFactor((*network)->XScaleFactor());
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helper advances *str past any run of spaces, tabs, and newlines.
// Note: only ' ', '\t' and '\n' count as whitespace here (not '\r' etc.).
static void SkipWhitespace(char** str) {
  char* p = *str;
  while (*p == ' ' || *p == '\t' || *p == '\n') {
    ++p;
  }
  *str = p;
}
|
||||
|
||||
// Parses the given string and returns a network according to the network
// description language in networkbuilder.h
// *str is advanced past the consumed portion of the spec; each Parse*
// helper advances it further, so statement order here is significant.
// Returns nullptr (and logs via tprintf) on any parse error.
Network* NetworkBuilder::BuildFromString(const StaticShape& input_shape,
                                         char** str) {
  SkipWhitespace(str);
  char code_ch = **str;
  if (code_ch == '[') {
    return ParseSeries(input_shape, nullptr, str);
  }
  if (input_shape.depth() == 0) {
    // There must be an input at this point.
    return ParseInput(str);
  }
  // Dispatch on the leading character of the next element in the spec.
  switch (code_ch) {
    case '(':
      return ParseParallel(input_shape, str);
    case 'R':
      return ParseR(input_shape, str);
    case 'S':
      return ParseS(input_shape, str);
    case 'C':
      return ParseC(input_shape, str);
    case 'M':
      return ParseM(input_shape, str);
    case 'L':
      return ParseLSTM(input_shape, str);
    case 'F':
      return ParseFullyConnected(input_shape, str);
    case 'O':
      return ParseOutput(input_shape, str);
    default:
      tprintf("Invalid network spec:%s\n", *str);
      return nullptr;
  }
  // Unreachable: every switch case returns.
  return nullptr;
}
|
||||
|
||||
// Parses an input specification and returns the result, which may include a
|
||||
// series.
|
||||
Network* NetworkBuilder::ParseInput(char** str) {
|
||||
// There must be an input at this point.
|
||||
int length = 0;
|
||||
int batch, height, width, depth;
|
||||
int num_converted =
|
||||
sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
|
||||
StaticShape shape;
|
||||
shape.SetShape(batch, height, width, depth);
|
||||
// num_converted may or may not include the length.
|
||||
if (num_converted != 4 && num_converted != 5) {
|
||||
tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
|
||||
return nullptr;
|
||||
}
|
||||
*str += length;
|
||||
Input* input = new Input("Input", shape);
|
||||
// We want to allow [<input>rest of net... or <input>[rest of net... so we
|
||||
// have to check explicitly for '[' here.
|
||||
SkipWhitespace(str);
|
||||
if (**str == '[') return ParseSeries(shape, input, str);
|
||||
return input;
|
||||
}
|
||||
|
||||
// Parses a sequential series of networks, defined by [<net><net>...].
// If input_layer is non-null it becomes the first element of the series
// (ownership is transferred to the series). On error the partially-built
// series (including input_layer) is deleted and NULL is returned.
Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape,
                                     Input* input_layer, char** str) {
  StaticShape shape = input_shape;
  Series* series = new Series("Series");
  // Consume the opening '['.
  ++*str;
  if (input_layer != nullptr) {
    series->AddToStack(input_layer);
    shape = input_layer->OutputShape(shape);
  }
  Network* network = NULL;
  // Each element's output shape feeds the next element's parse.
  while (**str != '\0' && **str != ']' &&
         (network = BuildFromString(shape, str)) != NULL) {
    shape = network->OutputShape(shape);
    series->AddToStack(network);
  }
  if (**str != ']') {
    tprintf("Missing ] at end of [Series]!\n");
    delete series;
    return NULL;
  }
  // Consume the closing ']'.
  ++*str;
  return series;
}
|
||||
|
||||
// Parses a parallel set of networks, defined by (<net><net>...).
// All elements see the same input_shape; their output depths are added.
// On error the partially-built parallel is deleted and nullptr is returned.
Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape,
                                       char** str) {
  Parallel* parallel = new Parallel("Parallel", NT_PARALLEL);
  // Consume the opening '('.
  ++*str;
  Network* network = NULL;
  while (**str != '\0' && **str != ')' &&
         (network = BuildFromString(input_shape, str)) != NULL) {
    parallel->AddToStack(network);
  }
  if (**str != ')') {
    tprintf("Missing ) at end of (Parallel)!\n");
    delete parallel;
    return nullptr;
  }
  // Consume the closing ')'.
  ++*str;
  return parallel;
}
|
||||
|
||||
// Parses a network that begins with 'R'.
// Rx<net>/Ry<net> wrap <net> in an x/y reversal; R<d><net> builds d
// replicas of the SAME spec text running in parallel.
Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) {
  char dir = (*str)[1];
  if (dir == 'x' || dir == 'y') {
    STRING name = "Reverse";
    name += dir;
    *str += 2;  // Consume "Rx" or "Ry".
    Network* network = BuildFromString(input_shape, str);
    if (network == nullptr) return nullptr;
    Reversed* rev =
        new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
    rev->SetNetwork(network);
    return rev;
  }
  // Replicated form: parse the replica count after 'R'.
  int replicas = strtol(*str + 1, str, 10);
  if (replicas <= 0) {
    tprintf("Invalid R spec!:%s\n", *str);
    return nullptr;
  }
  Parallel* parallel = new Parallel("Replicated", NT_REPLICATED);
  char* str_copy = *str;
  for (int i = 0; i < replicas; ++i) {
    // Re-parse the same spec text for every replica, so each replica gets
    // its own independent network instance.
    str_copy = *str;
    Network* network = BuildFromString(input_shape, &str_copy);
    if (network == NULL) {
      tprintf("Invalid replicated network!\n");
      delete parallel;
      return nullptr;
    }
    parallel->AddToStack(network);
  }
  // Only now advance past the (single) replicated spec.
  *str = str_copy;
  return parallel;
}
|
||||
|
||||
// Parses a network that begins with 'S'.
// S<y>,<x> rescales by shrink factors (y,x), moving the shrunk pixels into
// the depth dimension via a Reconfig. S(...) is a reserved, unimplemented
// generic-reshape form.
Network* NetworkBuilder::ParseS(const StaticShape& input_shape, char** str) {
  int y = strtol(*str + 1, str, 10);
  if (**str == ',') {
    int x = strtol(*str + 1, str, 10);
    if (y <= 0 || x <= 0) {
      tprintf("Invalid S spec!:%s\n", *str);
      return nullptr;
    }
    return new Reconfig("Reconfig", input_shape.depth(), x, y);
  } else if (**str == '(') {
    // TODO(rays) Add Generic reshape.
    tprintf("Generic reshape not yet implemented!!\n");
    return nullptr;
  }
  tprintf("Invalid S spec!:%s\n", *str);
  return nullptr;
}
|
||||
|
||||
// Helper returns the fully-connected type for the character code.
|
||||
static NetworkType NonLinearity(char func) {
|
||||
switch (func) {
|
||||
case 's':
|
||||
return NT_LOGISTIC;
|
||||
case 't':
|
||||
return NT_TANH;
|
||||
case 'r':
|
||||
return NT_RELU;
|
||||
case 'l':
|
||||
return NT_LINEAR;
|
||||
case 'm':
|
||||
return NT_SOFTMAX;
|
||||
case 'p':
|
||||
return NT_POSCLIP;
|
||||
case 'n':
|
||||
return NT_SYMCLIP;
|
||||
default:
|
||||
return NT_NONE;
|
||||
}
|
||||
}
|
||||
|
||||
// Parses a network that begins with 'C'.
// C(s|t|r|l|m|p|n)<y>,<x>,<d>: convolve with a (y,x) window producing depth
// d, followed by the given nonlinearity. A 1x1 window degenerates to a
// plain FullyConnected.
Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) {
  NetworkType type = NonLinearity((*str)[1]);
  if (type == NT_NONE) {
    tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
    return nullptr;
  }
  int y = 0, x = 0, d = 0;
  // Short-circuit evaluation is load-bearing here: each strtol advances
  // *str only if the previous field and separator matched.
  if ((y = strtol(*str + 2, str, 10)) <= 0 || **str != ',' ||
      (x = strtol(*str + 1, str, 10)) <= 0 || **str != ',' ||
      (d = strtol(*str + 1, str, 10)) <= 0) {
    tprintf("Invalid C spec!:%s\n", *str);
    return nullptr;
  }
  if (x == 1 && y == 1) {
    // No actual convolution. Just a FullyConnected on the current depth, to
    // be slid over all batch,y,x.
    return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
  }
  // Convolve gathers the window into depth, then a FullyConnected applies
  // the weights + nonlinearity at each position.
  Series* series = new Series("ConvSeries");
  Convolve* convolve =
      new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
  series->AddToStack(convolve);
  StaticShape fc_input = convolve->OutputShape(input_shape);
  series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
  return series;
}
|
||||
|
||||
// Parses a network that begins with 'M'.
// Mp<y>,<x>: maxpool over a (y,x) rectangle, shrinking the output.
Network* NetworkBuilder::ParseM(const StaticShape& input_shape, char** str) {
  int y = 0, x = 0;
  // Short-circuit order matters: 'p' must follow 'M', then y ',' x.
  if ((*str)[1] != 'p' || (y = strtol(*str + 2, str, 10)) <= 0 ||
      **str != ',' || (x = strtol(*str + 1, str, 10)) <= 0) {
    tprintf("Invalid Mp spec!:%s\n", *str);
    return nullptr;
  }
  return new Maxpool("Maxpool", input_shape.depth(), x, y);
}
|
||||
|
||||
// Parses an LSTM network, either individual, bi- or quad-directional.
// Spec forms after the leading 'L':
//   LS<n>/LE<n>       : 1-d LSTM with built-in (encoded) softmax on top.
//   L2xy<n>/L2yx<n>   : quad-directional 2-d LSTM (BuildLSTMXYQuad).
//   L(f|r|b)(x|y)[s]<n>: forward/reversed/bidirectional in x or y;
//                        trailing 's' keeps only the last output (summary).
Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) {
  bool two_d = false;
  NetworkType type = NT_LSTM;
  // Remember the spec start so the layer can be named after its spec text.
  char* spec_start = *str;
  int chars_consumed = 1;
  int num_outputs = 0;
  char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
  if (key == 'S') {
    type = NT_LSTM_SOFTMAX;
    num_outputs = num_softmax_outputs_;
    ++chars_consumed;
  } else if (key == 'E') {
    type = NT_LSTM_SOFTMAX_ENCODED;
    num_outputs = num_softmax_outputs_;
    ++chars_consumed;
  } else if (key == '2' && (((*str)[2] == 'x' && (*str)[3] == 'y') ||
                            ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
    chars_consumed = 4;
    dim = (*str)[3];
    two_d = true;
  } else if (key == 'f' || key == 'r' || key == 'b') {
    dir = key;
    dim = (*str)[2];
    if (dim != 'x' && dim != 'y') {
      tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
      return nullptr;
    }
    chars_consumed = 3;
    if ((*str)[chars_consumed] == 's') {
      ++chars_consumed;
      type = NT_LSTM_SUMMARY;
    }
  } else {
    tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
    return nullptr;
  }
  int num_states = strtol(*str + chars_consumed, str, 10);
  if (num_states <= 0) {
    tprintf("Invalid number of states in L Spec!:%s\n", *str);
    return nullptr;
  }
  Network* lstm = nullptr;
  if (two_d) {
    lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
  } else {
    if (num_outputs == 0) num_outputs = num_states;
    // Name the layer after the exact spec text that was consumed.
    STRING name(spec_start, *str - spec_start);
    lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false,
                    type);
    if (dir != 'f') {
      // Reversed ('r') and bidirectional ('b') both need an RTL copy.
      Reversed* rev = new Reversed("RevLSTM", NT_XREVERSED);
      rev->SetNetwork(lstm);
      lstm = rev;
    }
    if (dir == 'b') {
      // Bidirectional: add an un-reversed (LTR) twin in parallel.
      name += "LTR";
      Parallel* parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
      parallel->AddToStack(new LSTM(name, input_shape.depth(), num_states,
                                    num_outputs, false, type));
      parallel->AddToStack(lstm);
      lstm = parallel;
    }
  }
  if (dim == 'y') {
    // Run in y by transposing x/y around the (x-operating) LSTM.
    Reversed* rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
    rev->SetNetwork(lstm);
    lstm = rev;
  }
  return lstm;
}
|
||||
|
||||
// Builds a set of 4 lstms with x and y reversal, running in true parallel.
// Each direction is an LSTM wrapped in the appropriate Reversed layer(s) so
// all four scan directions (down/up x left-to-right/right-to-left) are
// covered.
// NOTE(review): the first and last LSTMs share the name "L2DLTRDown" — the
// last one presumably should be something like "L2DLTRUp"; not changed here
// since layer names may participate in serialization — TODO confirm.
Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
  Parallel* parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
  parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states,
                                num_states, true, NT_LSTM));
  Reversed* rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
  rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states,
                           true, NT_LSTM));
  parallel->AddToStack(rev);
  rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
  rev->SetNetwork(
      new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
  Reversed* rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
  rev2->SetNetwork(rev);
  parallel->AddToStack(rev2);
  rev = new Reversed("L2DXRevY", NT_YREVERSED);
  rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states,
                           true, NT_LSTM));
  parallel->AddToStack(rev);
  return parallel;
}
|
||||
|
||||
// Helper builds a truly (0-d) fully connected layer of the given type.
// Requires a fixed (non-zero) height and width; the whole h*w*d input is
// flattened into depth (via a Reconfig when h*w > 1) and fed to one
// FullyConnected of the given output depth. Returns nullptr on error.
static Network* BuildFullyConnected(const StaticShape& input_shape,
                                    NetworkType type, const STRING& name,
                                    int depth) {
  if (input_shape.height() == 0 || input_shape.width() == 0) {
    tprintf("Fully connected requires positive height and width, had %d,%d\n",
            input_shape.height(), input_shape.width());
    return nullptr;
  }
  int input_size = input_shape.height() * input_shape.width();
  int input_depth = input_size * input_shape.depth();
  Network* fc = new FullyConnected(name, input_depth, depth, type);
  if (input_size > 1) {
    // Collapse the spatial dims into depth before the fully connected.
    Series* series = new Series("FCSeries");
    series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(),
                                    input_shape.width(), input_shape.height()));
    series->AddToStack(fc);
    fc = series;
  }
  return fc;
}
|
||||
|
||||
// Parses a Fully connected network.
// F(s|t|r|l|m|p|n)<d>: fully connected with the given nonlinearity and
// output depth d; the layer is named after the consumed spec text.
Network* NetworkBuilder::ParseFullyConnected(const StaticShape& input_shape,
                                             char** str) {
  char* spec_start = *str;
  NetworkType type = NonLinearity((*str)[1]);
  if (type == NT_NONE) {
    tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
    return nullptr;
  }
  int depth = strtol(*str + 1, str, 10);
  if (depth <= 0) {
    tprintf("Invalid F spec!:%s\n", *str);
    return nullptr;
  }
  // Name the layer after the exact spec text that was consumed.
  STRING name(spec_start, *str - spec_start);
  return BuildFullyConnected(input_shape, type, name, depth);
}
|
||||
|
||||
// Parses an Output spec.
// O(2|1|0)(l|s|c)<d>: output layer over d classes. The first char is the
// output dimensionality, the second the output function (logistic /
// softmax-no-CTC / softmax-with-CTC). A depth that disagrees with the
// unicharset size is overridden (with a warning) to num_softmax_outputs_.
Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape,
                                     char** str) {
  char dims_ch = (*str)[1];
  if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
    tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
    return nullptr;
  }
  char type_ch = (*str)[2];
  if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
    tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
    return nullptr;
  }
  int depth = strtol(*str + 3, str, 10);
  if (depth != num_softmax_outputs_) {
    tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
            num_softmax_outputs_);
    depth = num_softmax_outputs_;
  }
  // 'c' (the remaining case) keeps the default NT_SOFTMAX (with CTC).
  NetworkType type = NT_SOFTMAX;
  if (type_ch == 'l')
    type = NT_LOGISTIC;
  else if (type_ch == 's')
    type = NT_SOFTMAX_NO_CTC;
  if (dims_ch == '0') {
    // Same as standard fully connected.
    return BuildFullyConnected(input_shape, type, "Output", depth);
  } else if (dims_ch == '2') {
    // We don't care if x and/or y are variable.
    return new FullyConnected("Output2d", input_shape.depth(), depth, type);
  }
  // For 1-d y has to be fixed, and if not 1, moved to depth.
  if (input_shape.height() == 0) {
    tprintf("Fully connected requires fixed height!\n");
    return nullptr;
  }
  int input_size = input_shape.height();
  int input_depth = input_size * input_shape.depth();
  Network* fc = new FullyConnected("Output", input_depth, depth, type);
  if (input_size > 1) {
    // Collapse the fixed height into depth before the fully connected.
    Series* series = new Series("FCSeries");
    series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1,
                                    input_shape.height()));
    series->AddToStack(fc);
    fc = series;
  }
  return fc;
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
160
lstm/networkbuilder.h
Normal file
160
lstm/networkbuilder.h
Normal file
@ -0,0 +1,160 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: networkbuilder.h
|
||||
// Description: Class to parse the network description language and
|
||||
// build a corresponding network.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Jul 16 18:35:38 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_NETWORKBUILDER_H_
|
||||
#define TESSERACT_LSTM_NETWORKBUILDER_H_
|
||||
|
||||
#include "static_shape.h"
|
||||
#include "stridemap.h"
|
||||
|
||||
class STRING;
|
||||
class UNICHARSET;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
class Input;
|
||||
class Network;
|
||||
class Parallel;
|
||||
class TRand;
|
||||
|
||||
class NetworkBuilder {
|
||||
public:
|
||||
explicit NetworkBuilder(int num_softmax_outputs)
|
||||
: num_softmax_outputs_(num_softmax_outputs) {}
|
||||
|
||||
// Builds a network with a network_spec in the network description
|
||||
// language, to recognize a character set of num_outputs size.
|
||||
// If append_index is non-negative, then *network must be non-null and the
|
||||
// given network_spec will be appended to *network AFTER append_index, with
|
||||
// the top of the input *network discarded.
|
||||
// Note that network_spec is call by value to allow a non-const char* pointer
|
||||
// into the string for BuildFromString.
|
||||
// net_flags control network behavior according to the NetworkFlags enum.
|
||||
// The resulting network is returned via **network.
|
||||
// Returns false if something failed.
|
||||
static bool InitNetwork(int num_outputs, STRING network_spec,
|
||||
int append_index, int net_flags, float weight_range,
|
||||
TRand* randomizer, Network** network);
|
||||
|
||||
// Parses the given string and returns a network according to the following
|
||||
// language:
|
||||
// ============ Syntax of description below: ============
|
||||
// <d> represents a number.
|
||||
// <net> represents any single network element, including (recursively) a
|
||||
// [...] series or (...) parallel construct.
|
||||
// (s|t|r|l|m) (regex notation) represents a single required letter.
|
||||
// NOTE THAT THROUGHOUT, x and y are REVERSED from conventional mathematics,
|
||||
// to use the same convention as Tensor Flow. The reason TF adopts this
|
||||
// convention is to eliminate the need to transpose images on input, since
|
||||
// adjacent memory locations in images increase x and then y, while adjacent
|
||||
// memory locations in tensors in TF, and NetworkIO in tesseract increase the
|
||||
// rightmost index first, then the next-left and so-on, like C arrays.
|
||||
// ============ INPUTS ============
|
||||
// <b>,<h>,<w>,<d> A batch of b images with height h, width w, and depth d.
|
||||
// b, h and/or w may be zero, to indicate variable size. Some network layer
|
||||
// (summarizing LSTM) must be used to make a variable h known.
|
||||
// d may be 1 for greyscale, 3 for color.
|
||||
// NOTE that throughout the constructed network, the inputs/outputs are all of
|
||||
// the same [batch,height,width,depth] dimensions, even if a different size.
|
||||
// ============ PLUMBING ============
|
||||
// [...] Execute ... networks in series (layers).
|
||||
// (...) Execute ... networks in parallel, with their output depths added.
|
||||
// R<d><net> Execute d replicas of net in parallel, with their output depths
|
||||
// added.
|
||||
// Rx<net> Execute <net> with x-dimension reversal.
|
||||
// Ry<net> Execute <net> with y-dimension reversal.
|
||||
// S<y>,<x> Rescale 2-D input by shrink factor x,y, rearranging the data by
|
||||
// increasing the depth of the input by factor xy.
|
||||
// Mp<y>,<x> Maxpool the input, reducing the size by an (x,y) rectangle.
|
||||
// ============ FUNCTIONAL UNITS ============
|
||||
// C(s|t|r|l|m)<y>,<x>,<d> Convolves using a (x,y) window, with no shrinkage,
|
||||
// random infill, producing d outputs, then applies a non-linearity:
|
||||
// s: Sigmoid, t: Tanh, r: Relu, l: Linear, m: Softmax.
|
||||
// F(s|t|r|l|m)<d> Truly fully-connected with s|t|r|l|m non-linearity and d
|
||||
// outputs. Connects to every x,y,depth position of the input, reducing
|
||||
// height, width to 1, producing a single <d> vector as the output.
|
||||
// Input height and width must be constant.
|
||||
// For a sliding-window linear or non-linear map that connects just to the
|
||||
// input depth, and leaves the input image size as-is, use a 1x1 convolution
|
||||
// eg. Cr1,1,64 instead of Fr64.
|
||||
// L(f|r|b)(x|y)[s]<n> LSTM cell with n states/outputs.
|
||||
// The LSTM must have one of:
|
||||
// f runs the LSTM forward only.
|
||||
// r runs the LSTM reversed only.
|
||||
// b runs the LSTM bidirectionally.
|
||||
// It will operate on either the x- or y-dimension, treating the other
|
||||
// dimension independently (as if part of the batch).
|
||||
// s (optional) summarizes the output in the requested dimension,
|
||||
// outputting only the final step, collapsing the dimension to a
|
||||
// single element.
|
||||
// LS<n> Forward-only LSTM cell in the x-direction, with built-in Softmax.
|
||||
// LE<n> Forward-only LSTM cell in the x-direction, with built-in softmax,
|
||||
// with binary Encoding.
|
||||
// L2xy<n> Full 2-d LSTM operating in quad-directions (bidi in x and y) and
|
||||
// all the output depths added.
|
||||
// ============ OUTPUTS ============
|
||||
// The network description must finish with an output specification:
|
||||
// O(2|1|0)(l|s|c)<n> output layer with n classes
|
||||
// 2 (heatmap) Output is a 2-d vector map of the input (possibly at
|
||||
// different scale).
|
||||
// 1 (sequence) Output is a 1-d sequence of vector values.
|
||||
// 0 (category) Output is a 0-d single vector value.
|
||||
// l uses a logistic non-linearity on the output, allowing multiple
|
||||
// hot elements in any output vector value.
|
||||
// s uses a softmax non-linearity, with one-hot output in each value.
|
||||
// c uses a softmax with CTC. Can only be used with s (sequence).
|
||||
// NOTE1: Only O1s and O1c are currently supported.
|
||||
// NOTE2: n is totally ignored, and for compatibility purposes only. The
|
||||
// output number of classes is obtained automatically from the
|
||||
// unicharset.
|
||||
Network* BuildFromString(const StaticShape& input_shape, char** str);
|
||||
|
||||
private:
|
||||
// Parses an input specification and returns the result, which may include a
|
||||
// series.
|
||||
Network* ParseInput(char** str);
|
||||
// Parses a sequential series of networks, defined by [<net><net>...].
|
||||
Network* ParseSeries(const StaticShape& input_shape, Input* input_layer,
|
||||
char** str);
|
||||
// Parses a parallel set of networks, defined by (<net><net>...).
|
||||
Network* ParseParallel(const StaticShape& input_shape, char** str);
|
||||
// Parses a network that begins with 'R'.
|
||||
Network* ParseR(const StaticShape& input_shape, char** str);
|
||||
// Parses a network that begins with 'S'.
|
||||
Network* ParseS(const StaticShape& input_shape, char** str);
|
||||
// Parses a network that begins with 'C'.
|
||||
Network* ParseC(const StaticShape& input_shape, char** str);
|
||||
// Parses a network that begins with 'M'.
|
||||
Network* ParseM(const StaticShape& input_shape, char** str);
|
||||
// Parses an LSTM network, either individual, bi- or quad-directional.
|
||||
Network* ParseLSTM(const StaticShape& input_shape, char** str);
|
||||
// Builds a set of 4 lstms with t and y reversal, running in true parallel.
|
||||
static Network* BuildLSTMXYQuad(int num_inputs, int num_states);
|
||||
// Parses a Fully connected network.
|
||||
Network* ParseFullyConnected(const StaticShape& input_shape, char** str);
|
||||
// Parses an Output spec.
|
||||
Network* ParseOutput(const StaticShape& input_shape, char** str);
|
||||
|
||||
private:
|
||||
int num_softmax_outputs_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_NETWORKBUILDER_H_
|
981
lstm/networkio.cpp
Normal file
981
lstm/networkio.cpp
Normal file
@ -0,0 +1,981 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: networkio.cpp
|
||||
// Description: Network input/output data, allowing float/int implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu Jun 19 13:01:31 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "networkio.h"
|
||||
|
||||
#include "allheaders.h"
|
||||
#include "functions.h"
|
||||
#include "statistc.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Minimum value to output for certainty.
|
||||
const float kMinCertainty = -20.0f;
|
||||
// Probability corresponding to kMinCertainty.
|
||||
const float kMinProb = exp(kMinCertainty);
|
||||
|
||||
// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
|
||||
void NetworkIO::Resize2d(bool int_mode, int width, int num_features) {
|
||||
stride_map_ = StrideMap();
|
||||
int_mode_ = int_mode;
|
||||
if (int_mode_) {
|
||||
i_.ResizeNoInit(width, num_features);
|
||||
} else {
|
||||
f_.ResizeNoInit(width, num_features);
|
||||
}
|
||||
}
|
||||
|
||||
// Resizes to a specific stride_map.
|
||||
void NetworkIO::ResizeToMap(bool int_mode, const StrideMap& stride_map,
|
||||
int num_features) {
|
||||
// If this assert fails, it most likely got here through an uninitialized
|
||||
// scratch element, ie call NetworkScratch::IO::Resizexxx() not
|
||||
// NetworkIO::Resizexxx()!!
|
||||
ASSERT_HOST(this != NULL);
|
||||
stride_map_ = stride_map;
|
||||
int_mode_ = int_mode;
|
||||
if (int_mode_) {
|
||||
i_.ResizeNoInit(stride_map.Width(), num_features);
|
||||
} else {
|
||||
f_.ResizeNoInit(stride_map.Width(), num_features);
|
||||
}
|
||||
ZeroInvalidElements();
|
||||
}
|
||||
|
||||
// Shrinks image size by x_scale,y_scale, and use given number of features.
|
||||
void NetworkIO::ResizeScaled(const NetworkIO& src,
|
||||
int x_scale, int y_scale, int num_features) {
|
||||
StrideMap stride_map = src.stride_map_;
|
||||
stride_map.ScaleXY(x_scale, y_scale);
|
||||
ResizeToMap(src.int_mode_, stride_map, num_features);
|
||||
}
|
||||
|
||||
// Resizes to just 1 x-coord, whatever the input.
|
||||
void NetworkIO::ResizeXTo1(const NetworkIO& src, int num_features) {
|
||||
StrideMap stride_map = src.stride_map_;
|
||||
stride_map.ReduceWidthTo1();
|
||||
ResizeToMap(src.int_mode_, stride_map, num_features);
|
||||
}
|
||||
|
||||
// Initialize all the array to zero.
|
||||
void NetworkIO::Zero() {
|
||||
int width = Width();
|
||||
// Zero out the everything. Column-by-column in case it is aligned.
|
||||
for (int t = 0; t < width; ++t) {
|
||||
ZeroTimeStep(t);
|
||||
}
|
||||
}
|
||||
|
||||
// Initializes to zero all elements of the array that do not correspond to
|
||||
// valid image positions. (If a batch of different-sized images are packed
|
||||
// together, then there will be padding pixels.)
|
||||
void NetworkIO::ZeroInvalidElements() {
|
||||
int num_features = NumFeatures();
|
||||
int full_width = stride_map_.Size(FD_WIDTH);
|
||||
int full_height = stride_map_.Size(FD_HEIGHT);
|
||||
StrideMap::Index b_index(stride_map_);
|
||||
do {
|
||||
int end_x = b_index.MaxIndexOfDim(FD_WIDTH) + 1;
|
||||
if (end_x < full_width) {
|
||||
// The width is small, so fill for every valid y.
|
||||
StrideMap::Index y_index(b_index);
|
||||
int fill_size = num_features * (full_width - end_x);
|
||||
do {
|
||||
StrideMap::Index z_index(y_index);
|
||||
z_index.AddOffset(end_x, FD_WIDTH);
|
||||
if (int_mode_) {
|
||||
ZeroVector(fill_size, i_[z_index.t()]);
|
||||
} else {
|
||||
ZeroVector(fill_size, f_[z_index.t()]);
|
||||
}
|
||||
} while (y_index.AddOffset(1, FD_HEIGHT));
|
||||
}
|
||||
int end_y = b_index.MaxIndexOfDim(FD_HEIGHT) + 1;
|
||||
if (end_y < full_height) {
|
||||
// The height is small, so fill in the space in one go.
|
||||
StrideMap::Index y_index(b_index);
|
||||
y_index.AddOffset(end_y, FD_HEIGHT);
|
||||
int fill_size = num_features * full_width * (full_height - end_y);
|
||||
if (int_mode_) {
|
||||
ZeroVector(fill_size, i_[y_index.t()]);
|
||||
} else {
|
||||
ZeroVector(fill_size, f_[y_index.t()]);
|
||||
}
|
||||
}
|
||||
} while (b_index.AddOffset(1, FD_BATCH));
|
||||
}
|
||||
|
||||
// Helper computes a black point and white point to contrast-enhance an image.
|
||||
// The computation is based on the assumption that the image is of a single line
|
||||
// of text, so a horizontal line through the middle of the image passes through
|
||||
// at least some of it, so local minima and maxima are a good proxy for black
|
||||
// and white pixel samples.
|
||||
static void ComputeBlackWhite(Pix* pix, float* black, float* white) {
|
||||
int width = pixGetWidth(pix);
|
||||
int height = pixGetHeight(pix);
|
||||
STATS mins(0, 256), maxes(0, 256);
|
||||
if (width >= 3) {
|
||||
int y = height / 2;
|
||||
const l_uint32* line = pixGetData(pix) + pixGetWpl(pix) * y;
|
||||
int prev = GET_DATA_BYTE(line, 0);
|
||||
int curr = GET_DATA_BYTE(line, 1);
|
||||
for (int x = 1; x + 1 < width; ++x) {
|
||||
int next = GET_DATA_BYTE(line, x + 1);
|
||||
if ((curr < prev && curr <= next) || (curr <= prev && curr < next)) {
|
||||
// Local minimum.
|
||||
mins.add(curr, 1);
|
||||
}
|
||||
if ((curr > prev && curr >= next) || (curr >= prev && curr > next)) {
|
||||
// Local maximum.
|
||||
maxes.add(curr, 1);
|
||||
}
|
||||
prev = curr;
|
||||
curr = next;
|
||||
}
|
||||
}
|
||||
if (mins.get_total() == 0) mins.add(0, 1);
|
||||
if (maxes.get_total() == 0) maxes.add(255, 1);
|
||||
*black = mins.ile(0.25);
|
||||
*white = maxes.ile(0.75);
|
||||
}
|
||||
|
||||
// Sets up the array from the given image, using the currently set int_mode_.
|
||||
// If the image width doesn't match the shape, the image is truncated or padded
|
||||
// with noise to match.
|
||||
void NetworkIO::FromPix(const StaticShape& shape, const Pix* pix,
|
||||
TRand* randomizer) {
|
||||
std::vector<const Pix*> pixes(1, pix);
|
||||
FromPixes(shape, pixes, randomizer);
|
||||
}
|
||||
|
||||
// Sets up the array from the given set of images, using the currently set
|
||||
// int_mode_. If the image width doesn't match the shape, the images are
|
||||
// truncated or padded with noise to match.
|
||||
void NetworkIO::FromPixes(const StaticShape& shape,
|
||||
const std::vector<const Pix*>& pixes,
|
||||
TRand* randomizer) {
|
||||
int target_height = shape.height();
|
||||
int target_width = shape.width();
|
||||
std::vector<std::pair<int, int>> h_w_pairs;
|
||||
for (auto pix : pixes) {
|
||||
Pix* var_pix = const_cast<Pix*>(pix);
|
||||
int width = pixGetWidth(var_pix);
|
||||
if (target_width != 0) width = target_width;
|
||||
int height = pixGetHeight(var_pix);
|
||||
if (target_height != 0) height = target_height;
|
||||
h_w_pairs.emplace_back(height, width);
|
||||
}
|
||||
stride_map_.SetStride(h_w_pairs);
|
||||
ResizeToMap(int_mode(), stride_map_, shape.depth());
|
||||
// Iterate over the images again to copy the data.
|
||||
for (int b = 0; b < pixes.size(); ++b) {
|
||||
Pix* pix = const_cast<Pix*>(pixes[b]);
|
||||
float black = 0.0f, white = 255.0f;
|
||||
if (shape.depth() != 3) ComputeBlackWhite(pix, &black, &white);
|
||||
float contrast = (white - black) / 2.0f;
|
||||
if (contrast <= 0.0f) contrast = 1.0f;
|
||||
if (shape.height() == 1) {
|
||||
Copy1DGreyImage(b, pix, black, contrast, randomizer);
|
||||
} else {
|
||||
Copy2DImage(b, pix, black, contrast, randomizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copies the given pix to *this at the given batch index, stretching and
|
||||
// clipping the pixel values so that [black, black + 2*contrast] maps to the
|
||||
// dynamic range of *this, ie [-1,1] for a float and (-127,127) for int.
|
||||
// This is a 2-d operation in the sense that the output depth is the number
|
||||
// of input channels, the height is the height of the image, and the width
|
||||
// is the width of the image, or truncated/padded with noise if the width
|
||||
// is a fixed size.
|
||||
void NetworkIO::Copy2DImage(int batch, Pix* pix, float black, float contrast,
|
||||
TRand* randomizer) {
|
||||
int width = pixGetWidth(pix);
|
||||
int height = pixGetHeight(pix);
|
||||
int wpl = pixGetWpl(pix);
|
||||
StrideMap::Index index(stride_map_);
|
||||
index.AddOffset(batch, FD_BATCH);
|
||||
int t = index.t();
|
||||
int target_height = stride_map_.Size(FD_HEIGHT);
|
||||
int target_width = stride_map_.Size(FD_WIDTH);
|
||||
int num_features = NumFeatures();
|
||||
bool color = num_features == 3;
|
||||
if (width > target_width) width = target_width;
|
||||
const uinT32* line = pixGetData(pix);
|
||||
for (int y = 0; y < target_height; ++y, line += wpl) {
|
||||
int x = 0;
|
||||
if (y < height) {
|
||||
for (x = 0; x < width; ++x, ++t) {
|
||||
if (color) {
|
||||
int f = 0;
|
||||
for (int c = COLOR_RED; c <= COLOR_BLUE; ++c) {
|
||||
int pixel = GET_DATA_BYTE(line + x, c);
|
||||
SetPixel(t, f++, pixel, black, contrast);
|
||||
}
|
||||
} else {
|
||||
int pixel = GET_DATA_BYTE(line, x);
|
||||
SetPixel(t, 0, pixel, black, contrast);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; x < target_width; ++x) Randomize(t++, 0, num_features, randomizer);
|
||||
}
|
||||
}
|
||||
|
||||
// Copies the given pix to *this at the given batch index, as Copy2DImage
|
||||
// above, except that the output depth is the height of the input image, the
|
||||
// output height is 1, and the output width as for Copy2DImage.
|
||||
// The image is thus treated as a 1-d set of vertical pixel strips.
|
||||
void NetworkIO::Copy1DGreyImage(int batch, Pix* pix, float black,
|
||||
float contrast, TRand* randomizer) {
|
||||
int width = pixGetWidth(pix);
|
||||
int height = pixGetHeight(pix);
|
||||
ASSERT_HOST(height == NumFeatures());
|
||||
int wpl = pixGetWpl(pix);
|
||||
StrideMap::Index index(stride_map_);
|
||||
index.AddOffset(batch, FD_BATCH);
|
||||
int t = index.t();
|
||||
int target_width = stride_map_.Size(FD_WIDTH);
|
||||
if (width > target_width) width = target_width;
|
||||
int x;
|
||||
for (x = 0; x < width; ++x, ++t) {
|
||||
for (int y = 0; y < height; ++y) {
|
||||
const uinT32* line = pixGetData(pix) + wpl * y;
|
||||
int pixel = GET_DATA_BYTE(line, x);
|
||||
SetPixel(t, y, pixel, black, contrast);
|
||||
}
|
||||
}
|
||||
for (; x < target_width; ++x) Randomize(t++, 0, height, randomizer);
|
||||
}
|
||||
|
||||
// Helper stores the pixel value in i_ or f_ according to int_mode_.
|
||||
// t: is the index from the StrideMap corresponding to the current
|
||||
// [batch,y,x] position
|
||||
// f: is the index into the depth/channel
|
||||
// pixel: the value of the pixel from the image (in one channel)
|
||||
// black: the pixel value to map to the lowest of the range of *this
|
||||
// contrast: the range of pixel values to stretch to half the range of *this.
|
||||
void NetworkIO::SetPixel(int t, int f, int pixel, float black, float contrast) {
|
||||
float float_pixel = (pixel - black) / contrast - 1.0f;
|
||||
if (int_mode_) {
|
||||
i_[t][f] = ClipToRange(IntCastRounded((MAX_INT8 + 1) * float_pixel),
|
||||
-MAX_INT8, MAX_INT8);
|
||||
} else {
|
||||
f_[t][f] = float_pixel;
|
||||
}
|
||||
}
|
||||
|
||||
// Converts the array to a Pix. Must be pixDestroyed after use.
|
||||
Pix* NetworkIO::ToPix() const {
|
||||
// Count the width of the image, and find the max multiplication factor.
|
||||
int im_width = stride_map_.Size(FD_WIDTH);
|
||||
int im_height = stride_map_.Size(FD_HEIGHT);
|
||||
int num_features = NumFeatures();
|
||||
int feature_factor = 1;
|
||||
if (num_features == 3) {
|
||||
// Special hack for color.
|
||||
num_features = 1;
|
||||
feature_factor = 3;
|
||||
}
|
||||
Pix* pix = pixCreate(im_width, im_height * num_features, 32);
|
||||
StrideMap::Index index(stride_map_);
|
||||
do {
|
||||
int im_x = index.index(FD_WIDTH);
|
||||
int top_im_y = index.index(FD_HEIGHT);
|
||||
int im_y = top_im_y;
|
||||
int t = index.t();
|
||||
if (int_mode_) {
|
||||
const inT8* features = i_[t];
|
||||
for (int y = 0; y < num_features; ++y, im_y += im_height) {
|
||||
int pixel = features[y * feature_factor];
|
||||
// 1 or 2 features use greyscale.
|
||||
int red = ClipToRange(pixel + 128, 0, 255);
|
||||
int green = red, blue = red;
|
||||
if (feature_factor == 3) {
|
||||
// With 3 features assume RGB color.
|
||||
green = ClipToRange(features[y * feature_factor + 1] + 128, 0, 255);
|
||||
blue = ClipToRange(features[y * feature_factor + 2] + 128, 0, 255);
|
||||
} else if (num_features > 3) {
|
||||
// More than 3 features use false yellow/blue color, assuming a signed
|
||||
// input in the range [-1,1].
|
||||
red = abs(pixel) * 2;
|
||||
if (pixel >= 0) {
|
||||
green = red;
|
||||
blue = 0;
|
||||
} else {
|
||||
blue = red;
|
||||
green = red = 0;
|
||||
}
|
||||
}
|
||||
pixSetPixel(pix, im_x, im_y, (red << L_RED_SHIFT) |
|
||||
(green << L_GREEN_SHIFT) |
|
||||
(blue << L_BLUE_SHIFT));
|
||||
}
|
||||
} else {
|
||||
const float* features = f_[t];
|
||||
for (int y = 0; y < num_features; ++y, im_y += im_height) {
|
||||
float pixel = features[y * feature_factor];
|
||||
// 1 or 2 features use greyscale.
|
||||
int red = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
|
||||
int green = red, blue = red;
|
||||
if (feature_factor == 3) {
|
||||
// With 3 features assume RGB color.
|
||||
pixel = features[y * feature_factor + 1];
|
||||
green = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
|
||||
pixel = features[y * feature_factor + 2];
|
||||
blue = ClipToRange(IntCastRounded((pixel + 1.0f) * 127.5f), 0, 255);
|
||||
} else if (num_features > 3) {
|
||||
// More than 3 features use false yellow/blue color, assuming a signed
|
||||
// input in the range [-1,1].
|
||||
red = ClipToRange(IntCastRounded(fabs(pixel) * 255), 0, 255);
|
||||
if (pixel >= 0) {
|
||||
green = red;
|
||||
blue = 0;
|
||||
} else {
|
||||
blue = red;
|
||||
green = red = 0;
|
||||
}
|
||||
}
|
||||
pixSetPixel(pix, im_x, im_y, (red << L_RED_SHIFT) |
|
||||
(green << L_GREEN_SHIFT) |
|
||||
(blue << L_BLUE_SHIFT));
|
||||
}
|
||||
}
|
||||
} while (index.Increment());
|
||||
return pix;
|
||||
}
|
||||
|
||||
// Prints the first and last num timesteps of the array for each feature.
|
||||
void NetworkIO::Print(int num) const {
|
||||
int num_features = NumFeatures();
|
||||
for (int y = 0; y < num_features; ++y) {
|
||||
for (int t = 0; t < Width(); ++t) {
|
||||
if (num == 0 || t < num || t + num >= Width()) {
|
||||
if (int_mode_) {
|
||||
tprintf(" %g", static_cast<float>(i_[t][y]) / MAX_INT8);
|
||||
} else {
|
||||
tprintf(" %g", f_[t][y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
tprintf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Copies a single time step from src.
|
||||
void NetworkIO::CopyTimeStepFrom(int dest_t, const NetworkIO& src, int src_t) {
|
||||
ASSERT_HOST(int_mode_ == src.int_mode_);
|
||||
if (int_mode_) {
|
||||
memcpy(i_[dest_t], src.i_[src_t], i_.dim2() * sizeof(i_[0][0]));
|
||||
} else {
|
||||
memcpy(f_[dest_t], src.f_[src_t], f_.dim2() * sizeof(f_[0][0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Copies a part of single time step from src.
|
||||
void NetworkIO::CopyTimeStepGeneral(int dest_t, int dest_offset,
|
||||
int num_features, const NetworkIO& src,
|
||||
int src_t, int src_offset) {
|
||||
ASSERT_HOST(int_mode_ == src.int_mode_);
|
||||
if (int_mode_) {
|
||||
memcpy(i_[dest_t] + dest_offset, src.i_[src_t] + src_offset,
|
||||
num_features * sizeof(i_[0][0]));
|
||||
} else {
|
||||
memcpy(f_[dest_t] + dest_offset, src.f_[src_t] + src_offset,
|
||||
num_features * sizeof(f_[0][0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Zeroes a single time step.
|
||||
void NetworkIO::ZeroTimeStepGeneral(int t, int offset, int num_features) {
|
||||
if (int_mode_) {
|
||||
ZeroVector(num_features, i_[t] + offset);
|
||||
} else {
|
||||
ZeroVector(num_features, f_[t] + offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the given range to random values.
|
||||
void NetworkIO::Randomize(int t, int offset, int num_features,
|
||||
TRand* randomizer) {
|
||||
if (int_mode_) {
|
||||
inT8* line = i_[t] + offset;
|
||||
for (int i = 0; i < num_features; ++i)
|
||||
line[i] = IntCastRounded(randomizer->SignedRand(MAX_INT8));
|
||||
} else {
|
||||
// float mode.
|
||||
float* line = f_[t] + offset;
|
||||
for (int i = 0; i < num_features; ++i)
|
||||
line[i] = randomizer->SignedRand(1.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper returns the label and score of the best choice over a range.
|
||||
int NetworkIO::BestChoiceOverRange(int t_start, int t_end, int not_this,
|
||||
int null_ch, float* rating,
|
||||
float* certainty) const {
|
||||
if (t_end <= t_start) return -1;
|
||||
int max_char = -1;
|
||||
float min_score = 0.0f;
|
||||
for (int c = 0; c < NumFeatures(); ++c) {
|
||||
if (c == not_this || c == null_ch) continue;
|
||||
ScoresOverRange(t_start, t_end, c, null_ch, rating, certainty);
|
||||
if (max_char < 0 || *rating < min_score) {
|
||||
min_score = *rating;
|
||||
max_char = c;
|
||||
}
|
||||
}
|
||||
ScoresOverRange(t_start, t_end, max_char, null_ch, rating, certainty);
|
||||
return max_char;
|
||||
}
|
||||
|
||||
// Helper returns the rating and certainty of the choice over a range in output.
|
||||
void NetworkIO::ScoresOverRange(int t_start, int t_end, int choice, int null_ch,
|
||||
float* rating, float* certainty) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
*rating = 0.0f;
|
||||
*certainty = 0.0f;
|
||||
if (t_end <= t_start || t_end <= 0) return;
|
||||
float ratings[3] = {0.0f, 0.0f, 0.0f};
|
||||
float certs[3] = {0.0f, 0.0f, 0.0f};
|
||||
for (int t = t_start; t < t_end; ++t) {
|
||||
const float* line = f_[t];
|
||||
float score = ProbToCertainty(line[choice]);
|
||||
float zero = ProbToCertainty(line[null_ch]);
|
||||
if (t == t_start) {
|
||||
ratings[2] = MAX_FLOAT32;
|
||||
ratings[1] = -score;
|
||||
certs[1] = score;
|
||||
} else {
|
||||
for (int i = 2; i >= 1; --i) {
|
||||
if (ratings[i] > ratings[i - 1]) {
|
||||
ratings[i] = ratings[i - 1];
|
||||
certs[i] = certs[i - 1];
|
||||
}
|
||||
}
|
||||
ratings[2] -= zero;
|
||||
if (zero < certs[2]) certs[2] = zero;
|
||||
ratings[1] -= score;
|
||||
if (score < certs[1]) certs[1] = score;
|
||||
}
|
||||
ratings[0] -= zero;
|
||||
if (zero < certs[0]) certs[0] = zero;
|
||||
}
|
||||
int best_i = ratings[2] < ratings[1] ? 2 : 1;
|
||||
*rating = ratings[best_i] + t_end - t_start;
|
||||
*certainty = certs[best_i];
|
||||
}
|
||||
|
||||
// Returns the index (label) of the best value at the given timestep,
|
||||
// excluding not_this and not_that, and if not null, sets the score to the
|
||||
// log of the corresponding value.
|
||||
int NetworkIO::BestLabel(int t, int not_this, int not_that,
|
||||
float* score) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
int best_index = -1;
|
||||
float best_score = -MAX_FLOAT32;
|
||||
const float* line = f_[t];
|
||||
for (int i = 0; i < f_.dim2(); ++i) {
|
||||
if (line[i] > best_score && i != not_this && i != not_that) {
|
||||
best_score = line[i];
|
||||
best_index = i;
|
||||
}
|
||||
}
|
||||
if (score != NULL) *score = ProbToCertainty(best_score);
|
||||
return best_index;
|
||||
}
|
||||
|
||||
// Returns the best start position out of [start, end) (into which all labels
|
||||
// must fit) to obtain the highest cumulative score for the given labels.
|
||||
int NetworkIO::PositionOfBestMatch(const GenericVector<int>& labels, int start,
|
||||
int end) const {
|
||||
int length = labels.size();
|
||||
int last_start = end - length;
|
||||
int best_start = -1;
|
||||
double best_score = 0.0;
|
||||
for (int s = start; s <= last_start; ++s) {
|
||||
double score = ScoreOfLabels(labels, s);
|
||||
if (score > best_score || best_start < 0) {
|
||||
best_score = score;
|
||||
best_start = s;
|
||||
}
|
||||
}
|
||||
return best_start;
|
||||
}
|
||||
|
||||
// Returns the cumulative score of the given labels starting at start, and
|
||||
// using one label per time-step.
|
||||
double NetworkIO::ScoreOfLabels(const GenericVector<int>& labels,
|
||||
int start) const {
|
||||
int length = labels.size();
|
||||
double score = 0.0;
|
||||
for (int i = 0; i < length; ++i) {
|
||||
score += f_(start + i, labels[i]);
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
// Helper function sets all the outputs for a single timestep, such that
|
||||
// label has value ok_score, and the other labels share 1 - ok_score.
|
||||
void NetworkIO::SetActivations(int t, int label, float ok_score) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
int num_classes = NumFeatures();
|
||||
float bad_score = (1.0f - ok_score) / (num_classes - 1);
|
||||
float* targets = f_[t];
|
||||
for (int i = 0; i < num_classes; ++i)
|
||||
targets[i] = bad_score;
|
||||
targets[label] = ok_score;
|
||||
}
|
||||
|
||||
// Modifies the values, only if needed, so that the given label is
|
||||
// the winner at the given time step t.
|
||||
void NetworkIO::EnsureBestLabel(int t, int label) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
if (BestLabel(t, NULL) != label) {
|
||||
// Output value needs enhancing. Third all the other elements and add the
|
||||
// remainder to best_label.
|
||||
int num_classes = NumFeatures();
|
||||
float* targets = f_[t];
|
||||
float enhancement = (1.0f - targets[label]) / 3.0f;
|
||||
for (int c = 0; c < num_classes; ++c) {
|
||||
if (c == label) {
|
||||
targets[c] += (1.0 - targets[c]) * (2 / 3.0);
|
||||
} else {
|
||||
targets[c] /= 3.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function converts prob to certainty taking the minimum into account.
|
||||
/* static */
|
||||
float NetworkIO::ProbToCertainty(float prob) {
|
||||
return prob > kMinProb ? log(prob) : kMinCertainty;
|
||||
}
|
||||
|
||||
// Returns true if there is any bad value that is suspiciously like a GT
|
||||
// error. Assuming that *this is the difference(gradient) between target
|
||||
// and forward output, returns true if there is a large negative value
|
||||
// (correcting a very confident output) for which there is no corresponding
|
||||
// positive value in an adjacent timestep for the same feature index. This
|
||||
// allows the box-truthed samples to make fine adjustments to position while
|
||||
// stopping other disagreements of confident output with ground truth.
|
||||
bool NetworkIO::AnySuspiciousTruth(float confidence_thr) const {
|
||||
int num_features = NumFeatures();
|
||||
for (int t = 0; t < Width(); ++t) {
|
||||
const float* features = f_[t];
|
||||
for (int y = 0; y < num_features; ++y) {
|
||||
float grad = features[y];
|
||||
if (grad < -confidence_thr) {
|
||||
// Correcting strong output. Check for movement.
|
||||
if ((t == 0 || f_[t - 1][y] < confidence_thr / 2) &&
|
||||
(t + 1 == Width() || f_[t + 1][y] < confidence_thr / 2)) {
|
||||
return true; // No strong positive on either side.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reads a single timestep to floats in the range [-1, 1].
|
||||
void NetworkIO::ReadTimeStep(int t, double* output) const {
|
||||
if (int_mode_) {
|
||||
const inT8* line = i_[t];
|
||||
for (int i = 0; i < i_.dim2(); ++i) {
|
||||
output[i] = static_cast<double>(line[i]) / MAX_INT8;
|
||||
}
|
||||
} else {
|
||||
const float* line = f_[t];
|
||||
for (int i = 0; i < f_.dim2(); ++i) {
|
||||
output[i] = static_cast<double>(line[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a single timestep to floats.
|
||||
void NetworkIO::AddTimeStep(int t, double* inout) const {
|
||||
int num_features = NumFeatures();
|
||||
if (int_mode_) {
|
||||
const inT8* line = i_[t];
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
inout[i] += static_cast<double>(line[i]) / MAX_INT8;
|
||||
}
|
||||
} else {
|
||||
const float* line = f_[t];
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
inout[i] += line[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds part of a single timestep to floats.
|
||||
void NetworkIO::AddTimeStepPart(int t, int offset, int num_features,
|
||||
float* inout) const {
|
||||
if (int_mode_) {
|
||||
const inT8* line = i_[t] + offset;
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
inout[i] += static_cast<float>(line[i]) / MAX_INT8;
|
||||
}
|
||||
} else {
|
||||
const float* line = f_[t] + offset;
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
inout[i] += line[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Writes a single timestep from floats in the range [-1, 1].
|
||||
void NetworkIO::WriteTimeStep(int t, const double* input) {
|
||||
WriteTimeStepPart(t, 0, NumFeatures(), input);
|
||||
}
|
||||
|
||||
// Writes a single timestep from floats in the range [-1, 1] writing only
|
||||
// num_features elements of input to (*this)[t], starting at offset.
|
||||
void NetworkIO::WriteTimeStepPart(int t, int offset, int num_features,
|
||||
const double* input) {
|
||||
if (int_mode_) {
|
||||
inT8* line = i_[t] + offset;
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
line[i] = ClipToRange(IntCastRounded(input[i] * MAX_INT8),
|
||||
-MAX_INT8, MAX_INT8);
|
||||
}
|
||||
} else {
|
||||
float* line = f_[t] + offset;
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
line[i] = static_cast<float>(input[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Maxpools a single time step from src.
|
||||
void NetworkIO::MaxpoolTimeStep(int dest_t, const NetworkIO& src, int src_t,
|
||||
int* max_line) {
|
||||
ASSERT_HOST(int_mode_ == src.int_mode_);
|
||||
if (int_mode_) {
|
||||
int dim = i_.dim2();
|
||||
inT8* dest_line = i_[dest_t];
|
||||
const inT8* src_line = src.i_[src_t];
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
if (dest_line[i] < src_line[i]) {
|
||||
dest_line[i] = src_line[i];
|
||||
max_line[i] = src_t;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int dim = f_.dim2();
|
||||
float* dest_line = f_[dest_t];
|
||||
const float* src_line = src.f_[src_t];
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
if (dest_line[i] < src_line[i]) {
|
||||
dest_line[i] = src_line[i];
|
||||
max_line[i] = src_t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs maxpool backward, using maxes to index timesteps in *this.
|
||||
void NetworkIO::MaxpoolBackward(const NetworkIO& fwd,
|
||||
const GENERIC_2D_ARRAY<int>& maxes) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
int width = fwd.Width();
|
||||
Zero();
|
||||
StrideMap::Index index(fwd.stride_map_);
|
||||
do {
|
||||
int t = index.t();
|
||||
const int* max_line = maxes[t];
|
||||
const float* fwd_line = fwd.f_[t];
|
||||
int num_features = fwd.f_.dim2();
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
f_[max_line[i]][i] = fwd_line[i];
|
||||
}
|
||||
} while (index.Increment());
|
||||
}
|
||||
|
||||
// Returns the min over time of the maxes over features of the outputs.
|
||||
float NetworkIO::MinOfMaxes() const {
|
||||
float min_max = 0.0f;
|
||||
int width = Width();
|
||||
int num_features = NumFeatures();
|
||||
for (int t = 0; t < width; ++t) {
|
||||
float max_value = -MAX_FLOAT32;
|
||||
if (int_mode_) {
|
||||
const inT8* column = i_[t];
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
if (column[i] > max_value) max_value = column[i];
|
||||
}
|
||||
} else {
|
||||
const float* column = f_[t];
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
if (column[i] > max_value) max_value = column[i];
|
||||
}
|
||||
}
|
||||
if (t == 0 || max_value < min_max) min_max = max_value;
|
||||
}
|
||||
return min_max;
|
||||
}
|
||||
|
||||
// Computes combined results for a combiner that chooses between an existing
|
||||
// input and itself, with an additional output to indicate the choice.
|
||||
void NetworkIO::CombineOutputs(const NetworkIO& base_output,
|
||||
const NetworkIO& combiner_output) {
|
||||
int no = base_output.NumFeatures();
|
||||
ASSERT_HOST(combiner_output.NumFeatures() == no + 1);
|
||||
Resize(base_output, no);
|
||||
int width = Width();
|
||||
if (int_mode_) {
|
||||
// Number of outputs from base and final result.
|
||||
for (int t = 0; t < width; ++t) {
|
||||
inT8* out_line = i_[t];
|
||||
const inT8* base_line = base_output.i_[t];
|
||||
const inT8* comb_line = combiner_output.i_[t];
|
||||
float base_weight = static_cast<float>(comb_line[no]) / MAX_INT8;
|
||||
float boost_weight = 1.0f - base_weight;
|
||||
for (int i = 0; i < no; ++i) {
|
||||
out_line[i] = IntCastRounded(base_line[i] * base_weight +
|
||||
comb_line[i] * boost_weight);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int t = 0; t < width; ++t) {
|
||||
float* out_line = f_[t];
|
||||
const float* base_line = base_output.f_[t];
|
||||
const float* comb_line = combiner_output.f_[t];
|
||||
float base_weight = comb_line[no];
|
||||
float boost_weight = 1.0f - base_weight;
|
||||
for (int i = 0; i < no; ++i) {
|
||||
out_line[i] = base_line[i] * base_weight + comb_line[i] * boost_weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Computes deltas for a combiner that chooses between 2 sets of inputs.
// On entry, *this holds the combiner's forward output: NumFeatures()-1 real
// outputs plus one final weighting output (the weight given to base_output).
// On exit, *this holds the back-propagated deltas for the combiner.
// Must be in float mode.
void NetworkIO::ComputeCombinerDeltas(const NetworkIO& fwd_deltas,
                                      const NetworkIO& base_output) {
  ASSERT_HOST(!int_mode_);
  // Compute the deltas for the combiner.
  int width = Width();
  int no = NumFeatures() - 1;
  ASSERT_HOST(fwd_deltas.NumFeatures() == no);
  ASSERT_HOST(base_output.NumFeatures() == no);
  // Number of outputs from base and final result.
  for (int t = 0; t < width; ++t) {
    const float* delta_line = fwd_deltas.f_[t];
    const float* base_line = base_output.f_[t];
    float* comb_line = f_[t];
    float base_weight = comb_line[no];
    float boost_weight = 1.0f - base_weight;
    float max_base_delta = 0.0;
    for (int i = 0; i < no; ++i) {
      // What did the combiner actually produce?
      float output = base_line[i] * base_weight + comb_line[i] * boost_weight;
      // Reconstruct the target from the delta.
      float comb_target = delta_line[i] + output;
      // comb_line[i] is overwritten with its own delta here; its forward value
      // has already been consumed in computing output above.
      comb_line[i] = comb_target - comb_line[i];
      float base_delta = fabs(comb_target - base_line[i]);
      if (base_delta > max_base_delta) max_base_delta = base_delta;
    }
    if (max_base_delta >= 0.5) {
      // The base network got it wrong. The combiner should output the right
      // answer and 0 for the base network.
      comb_line[no] = 0.0 - base_weight;
    } else {
      // The base network was right. The combiner should flag that.
      for (int i = 0; i < no; ++i) {
        // All other targets are 0.
        if (comb_line[i] > 0.0) comb_line[i] -= 1.0;
      }
      comb_line[no] = 1.0 - base_weight;
    }
  }
}
|
||||
|
||||
// Copies the array checking that the types match.
void NetworkIO::CopyAll(const NetworkIO& src) {
  ASSERT_HOST(src.int_mode_ == int_mode_);
  // NOTE(review): only the float array is assigned here, even when both sides
  // are in int mode -- confirm no caller expects i_ to be copied as well.
  f_ = src.f_;
}
|
||||
|
||||
// Checks that both are floats and adds the src array to *this.
void NetworkIO::AddAllToFloat(const NetworkIO& src) {
  // Both sides must be in float mode; there is no int-mode path here.
  ASSERT_HOST(!int_mode_);
  ASSERT_HOST(!src.int_mode_);
  f_ += src.f_;  // Element-wise accumulate via GENERIC_2D_ARRAY::operator+=.
}
|
||||
|
||||
// Subtracts the array from a float array. src must also be float.
void NetworkIO::SubtractAllFromFloat(const NetworkIO& src) {
  // Both sides must be in float mode; there is no int-mode path here.
  ASSERT_HOST(!int_mode_);
  ASSERT_HOST(!src.int_mode_);
  f_ -= src.f_;  // Element-wise subtract via GENERIC_2D_ARRAY::operator-=.
}
|
||||
|
||||
// Copies src to *this, with maxabs normalization to match scale.
|
||||
void NetworkIO::CopyWithNormalization(const NetworkIO& src,
|
||||
const NetworkIO& scale) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
ASSERT_HOST(!src.int_mode_);
|
||||
ASSERT_HOST(!scale.int_mode_);
|
||||
float src_max = src.f_.MaxAbs();
|
||||
ASSERT_HOST(std::isfinite(src_max));
|
||||
float scale_max = scale.f_.MaxAbs();
|
||||
ASSERT_HOST(std::isfinite(scale_max));
|
||||
if (src_max > 0.0f) {
|
||||
float factor = scale_max / src_max;
|
||||
for (int t = 0; t < src.Width(); ++t) {
|
||||
const float* src_ptr = src.f_[t];
|
||||
float* dest_ptr = f_[t];
|
||||
for (int i = 0; i < src.f_.dim2(); ++i) dest_ptr[i] = src_ptr[i] * factor;
|
||||
}
|
||||
} else {
|
||||
f_.Clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Copies src to *this with independent reversal of the y dimension.
// Each batch element has its rows (FD_HEIGHT) mirrored; the x order within
// each row is preserved.
void NetworkIO::CopyWithYReversal(const NetworkIO& src) {
  int num_features = src.NumFeatures();
  Resize(src, num_features);
  StrideMap::Index b_index(src.stride_map_);
  do {
    // Walk a pair of row cursors towards each other: fwd_index from the top,
    // rev_index from the bottom, copying one full row per iteration.
    int width = b_index.MaxIndexOfDim(FD_WIDTH) + 1;
    StrideMap::Index fwd_index(b_index);
    StrideMap::Index rev_index(b_index);
    rev_index.AddOffset(rev_index.MaxIndexOfDim(FD_HEIGHT), FD_HEIGHT);
    do {
      int fwd_t = fwd_index.t();
      int rev_t = rev_index.t();
      // Copies width consecutive timesteps -- assumes a row occupies
      // consecutive t values in the stride map.
      for (int x = 0; x < width; ++x) CopyTimeStepFrom(rev_t++, src, fwd_t++);
    } while (fwd_index.AddOffset(1, FD_HEIGHT) &&
             rev_index.AddOffset(-1, FD_HEIGHT));
  } while (b_index.AddOffset(1, FD_BATCH));
}
|
||||
|
||||
// Copies src to *this with independent reversal of the x dimension.
// Within each row of each batch element, timesteps are copied in mirrored
// x order; y order is preserved.
void NetworkIO::CopyWithXReversal(const NetworkIO& src) {
  int num_features = src.NumFeatures();
  Resize(src, num_features);
  StrideMap::Index b_index(src.stride_map_);
  do {
    StrideMap::Index y_index(b_index);
    do {
      // Two column cursors walk towards each other from the ends of the row.
      StrideMap::Index fwd_index(y_index);
      StrideMap::Index rev_index(y_index);
      rev_index.AddOffset(rev_index.MaxIndexOfDim(FD_WIDTH), FD_WIDTH);
      do {
        CopyTimeStepFrom(rev_index.t(), src, fwd_index.t());
      } while (fwd_index.AddOffset(1, FD_WIDTH) &&
               rev_index.AddOffset(-1, FD_WIDTH));
    } while (y_index.AddOffset(1, FD_HEIGHT));
  } while (b_index.AddOffset(1, FD_BATCH));
}
|
||||
|
||||
// Copies src to *this with independent transpose of the x and y dimensions.
void NetworkIO::CopyWithXYTranspose(const NetworkIO& src) {
  int num_features = src.NumFeatures();
  // Build a stride map with x and y swapped and resize to it.
  stride_map_ = src.stride_map_;
  stride_map_.TransposeXY();
  ResizeToMap(src.int_mode(), stride_map_, num_features);
  // Walk src in (y, x) order while walking the destination in (x, y) order,
  // copying one timestep at a time.
  StrideMap::Index src_b_index(src.stride_map_);
  StrideMap::Index dest_b_index(stride_map_);
  do {
    StrideMap::Index src_y_index(src_b_index);
    StrideMap::Index dest_x_index(dest_b_index);
    do {
      StrideMap::Index src_x_index(src_y_index);
      StrideMap::Index dest_y_index(dest_x_index);
      do {
        CopyTimeStepFrom(dest_y_index.t(), src, src_x_index.t());
      } while (src_x_index.AddOffset(1, FD_WIDTH) &&
               dest_y_index.AddOffset(1, FD_HEIGHT));
    } while (src_y_index.AddOffset(1, FD_HEIGHT) &&
             dest_x_index.AddOffset(1, FD_WIDTH));
  } while (src_b_index.AddOffset(1, FD_BATCH) &&
           dest_b_index.AddOffset(1, FD_BATCH));
}
|
||||
|
||||
// Copies src to *this, at the given feature_offset, returning the total
|
||||
// feature offset after the copy. Multiple calls will stack outputs from
|
||||
// multiple sources in feature space.
|
||||
int NetworkIO::CopyPacking(const NetworkIO& src, int feature_offset) {
|
||||
ASSERT_HOST(int_mode_ == src.int_mode_);
|
||||
int width = src.Width();
|
||||
ASSERT_HOST(width <= Width());
|
||||
int num_features = src.NumFeatures();
|
||||
ASSERT_HOST(num_features + feature_offset <= NumFeatures());
|
||||
if (int_mode_) {
|
||||
for (int t = 0; t < width; ++t) {
|
||||
memcpy(i_[t] + feature_offset, src.i_[t],
|
||||
num_features * sizeof(i_[t][0]));
|
||||
}
|
||||
for (int t = width; t < i_.dim1(); ++t) {
|
||||
memset(i_[t], 0, num_features * sizeof(i_[t][0]));
|
||||
}
|
||||
} else {
|
||||
for (int t = 0; t < width; ++t) {
|
||||
memcpy(f_[t] + feature_offset, src.f_[t],
|
||||
num_features * sizeof(f_[t][0]));
|
||||
}
|
||||
for (int t = width; t < f_.dim1(); ++t) {
|
||||
memset(f_[t], 0, num_features * sizeof(f_[t][0]));
|
||||
}
|
||||
}
|
||||
return num_features + feature_offset;
|
||||
}
|
||||
|
||||
// Opposite of CopyPacking, fills *this with a part of src, starting at
|
||||
// feature_offset, and picking num_features.
|
||||
void NetworkIO::CopyUnpacking(const NetworkIO& src, int feature_offset,
|
||||
int num_features) {
|
||||
Resize(src, num_features);
|
||||
int width = src.Width();
|
||||
ASSERT_HOST(num_features + feature_offset <= src.NumFeatures());
|
||||
if (int_mode_) {
|
||||
for (int t = 0; t < width; ++t) {
|
||||
memcpy(i_[t], src.i_[t] + feature_offset,
|
||||
num_features * sizeof(i_[t][0]));
|
||||
}
|
||||
} else {
|
||||
for (int t = 0; t < width; ++t) {
|
||||
memcpy(f_[t], src.f_[t] + feature_offset,
|
||||
num_features * sizeof(f_[t][0]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transposes the float part of *this into dest.
|
||||
void NetworkIO::Transpose(TransposedArray* dest) const {
|
||||
int width = Width();
|
||||
dest->ResizeNoInit(NumFeatures(), width);
|
||||
for (int t = 0; t < width; ++t) dest->WriteStrided(t, f_[t]);
|
||||
}
|
||||
|
||||
// Clips the content of a single time-step to +/-range.
|
||||
void NetworkIO::ClipVector(int t, float range) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
float* v = f_[t];
|
||||
int dim = f_.dim2();
|
||||
for (int i = 0; i < dim; ++i)
|
||||
v[i] = ClipToRange(v[i], -range, range);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
341
lstm/networkio.h
Normal file
341
lstm/networkio.h
Normal file
@ -0,0 +1,341 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: networkio.h
|
||||
// Description: Network input/output data, allowing float/int implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Jun 17 08:43:11 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_NETWORKIO_H_
|
||||
#define TESSERACT_LSTM_NETWORKIO_H_
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "helpers.h"
|
||||
#include "static_shape.h"
|
||||
#include "stridemap.h"
|
||||
#include "weightmatrix.h"
|
||||
|
||||
struct Pix;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Class to contain all the input/output of a network, allowing for fixed or
|
||||
// variable-strided 2d to 1d mapping, and float or inT8 values. Provides
|
||||
// enough calculating functions to hide the detail of the implementation.
|
||||
class NetworkIO {
|
||||
public:
|
||||
NetworkIO() : int_mode_(false) {}
|
||||
// Resizes the array (and stride), avoiding realloc if possible, to the given
|
||||
// size from various size specs:
|
||||
// Same stride size, but given number of features.
|
||||
void Resize(const NetworkIO& src, int num_features) {
|
||||
ResizeToMap(src.int_mode(), src.stride_map(), num_features);
|
||||
}
|
||||
// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
|
||||
void Resize2d(bool int_mode, int width, int num_features);
|
||||
// Resizes forcing a float representation with the stridemap of src and the
|
||||
// given number of features.
|
||||
void ResizeFloat(const NetworkIO& src, int num_features) {
|
||||
ResizeToMap(false, src.stride_map(), num_features);
|
||||
}
|
||||
// Resizes to a specific stride_map.
|
||||
void ResizeToMap(bool int_mode, const StrideMap& stride_map,
|
||||
int num_features);
|
||||
// Shrinks image size by x_scale,y_scale, and use given number of features.
|
||||
void ResizeScaled(const NetworkIO& src, int x_scale, int y_scale,
|
||||
int num_features);
|
||||
// Resizes to just 1 x-coord, whatever the input.
|
||||
void ResizeXTo1(const NetworkIO& src, int num_features);
|
||||
// Initialize all the array to zero.
|
||||
void Zero();
|
||||
// Initializes to zero all elements of the array that do not correspond to
|
||||
// valid image positions. (If a batch of different-sized images are packed
|
||||
// together, then there will be padding pixels.)
|
||||
void ZeroInvalidElements();
|
||||
// Sets up the array from the given image, using the currently set int_mode_.
|
||||
// If the image width doesn't match the shape, the image is truncated or
|
||||
// padded with noise to match.
|
||||
void FromPix(const StaticShape& shape, const Pix* pix, TRand* randomizer);
|
||||
// Sets up the array from the given set of images, using the currently set
|
||||
// int_mode_. If the image width doesn't match the shape, the images are
|
||||
// truncated or padded with noise to match.
|
||||
void FromPixes(const StaticShape& shape, const std::vector<const Pix*>& pixes,
|
||||
TRand* randomizer);
|
||||
// Copies the given pix to *this at the given batch index, stretching and
|
||||
// clipping the pixel values so that [black, black + 2*contrast] maps to the
|
||||
// dynamic range of *this, ie [-1,1] for a float and (-127,127) for int.
|
||||
// This is a 2-d operation in the sense that the output depth is the number
|
||||
// of input channels, the height is the height of the image, and the width
|
||||
// is the width of the image, or truncated/padded with noise if the width
|
||||
// is a fixed size.
|
||||
void Copy2DImage(int batch, Pix* pix, float black, float contrast,
|
||||
TRand* randomizer);
|
||||
// Copies the given pix to *this at the given batch index, as Copy2DImage
|
||||
// above, except that the output depth is the height of the input image, the
|
||||
// output height is 1, and the output width as for Copy2DImage.
|
||||
// The image is thus treated as a 1-d set of vertical pixel strips.
|
||||
void Copy1DGreyImage(int batch, Pix* pix, float black, float contrast,
|
||||
TRand* randomizer);
|
||||
// Helper stores the pixel value in i_ or f_ according to int_mode_.
|
||||
// t: is the index from the StrideMap corresponding to the current
|
||||
// [batch,y,x] position
|
||||
// f: is the index into the depth/channel
|
||||
// pixel: the value of the pixel from the image (in one channel)
|
||||
// black: the pixel value to map to the lowest of the range of *this
|
||||
// contrast: the range of pixel values to stretch to half the range of *this.
|
||||
void SetPixel(int t, int f, int pixel, float black, float contrast);
|
||||
// Converts the array to a Pix. Must be pixDestroyed after use.
|
||||
Pix* ToPix() const;
|
||||
// Prints the first and last num timesteps of the array for each feature.
|
||||
void Print(int num) const;
|
||||
|
||||
// Returns the timestep width.
|
||||
int Width() const {
|
||||
return int_mode_ ? i_.dim1() : f_.dim1();
|
||||
}
|
||||
// Returns the number of features.
|
||||
int NumFeatures() const {
|
||||
return int_mode_ ? i_.dim2() : f_.dim2();
|
||||
}
|
||||
// Accessor to a timestep of the float matrix.
|
||||
float* f(int t) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
return f_[t];
|
||||
}
|
||||
const float* f(int t) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
return f_[t];
|
||||
}
|
||||
const inT8* i(int t) const {
|
||||
ASSERT_HOST(int_mode_);
|
||||
return i_[t];
|
||||
}
|
||||
bool int_mode() const {
|
||||
return int_mode_;
|
||||
}
|
||||
void set_int_mode(bool is_quantized) {
|
||||
int_mode_ = is_quantized;
|
||||
}
|
||||
const StrideMap& stride_map() const {
|
||||
return stride_map_;
|
||||
}
|
||||
void set_stride_map(const StrideMap& map) {
|
||||
stride_map_ = map;
|
||||
}
|
||||
const GENERIC_2D_ARRAY<float>& float_array() const { return f_; }
|
||||
GENERIC_2D_ARRAY<float>* mutable_float_array() { return &f_; }
|
||||
|
||||
// Copies a single time step from src.
|
||||
void CopyTimeStepFrom(int dest_t, const NetworkIO& src, int src_t);
|
||||
// Copies a part of single time step from src.
|
||||
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features,
|
||||
const NetworkIO& src, int src_t, int src_offset);
|
||||
// Zeroes a single time step.
|
||||
void ZeroTimeStep(int t) { ZeroTimeStepGeneral(t, 0, NumFeatures()); }
|
||||
void ZeroTimeStepGeneral(int t, int offset, int num_features);
|
||||
// Sets the given range to random values.
|
||||
void Randomize(int t, int offset, int num_features, TRand* randomizer);
|
||||
|
||||
// Helper returns the label and score of the best choice over a range.
|
||||
int BestChoiceOverRange(int t_start, int t_end, int not_this, int null_ch,
|
||||
float* rating, float* certainty) const;
|
||||
// Helper returns the rating and certainty of the choice over a range in t.
|
||||
void ScoresOverRange(int t_start, int t_end, int choice, int null_ch,
|
||||
float* rating, float* certainty) const;
|
||||
// Returns the index (label) of the best value at the given timestep,
|
||||
// and if not null, sets the score to the log of the corresponding value.
|
||||
int BestLabel(int t, float* score) const {
|
||||
return BestLabel(t, -1, -1, score);
|
||||
}
|
||||
// Returns the index (label) of the best value at the given timestep,
|
||||
// excluding not_this and not_that, and if not null, sets the score to the
|
||||
// log of the corresponding value.
|
||||
int BestLabel(int t, int not_this, int not_that, float* score) const;
|
||||
// Returns the best start position out of range (into which both start and end
|
||||
// must fit) to obtain the highest cumulative score for the given labels.
|
||||
int PositionOfBestMatch(const GenericVector<int>& labels, int start,
|
||||
int end) const;
|
||||
// Returns the cumulative score of the given labels starting at start, and
|
||||
// using one label per time-step.
|
||||
double ScoreOfLabels(const GenericVector<int>& labels, int start) const;
|
||||
// Helper function sets all the outputs for a single timestep, such that
|
||||
// label has value ok_score, and the other labels share 1 - ok_score.
|
||||
// Assumes float mode.
|
||||
void SetActivations(int t, int label, float ok_score);
|
||||
// Modifies the values, only if needed, so that the given label is
|
||||
// the winner at the given time step t.
|
||||
// Assumes float mode.
|
||||
void EnsureBestLabel(int t, int label);
|
||||
// Helper function converts prob to certainty taking the minimum into account.
|
||||
static float ProbToCertainty(float prob);
|
||||
// Returns true if there is any bad value that is suspiciously like a GT
|
||||
// error. Assuming that *this is the difference(gradient) between target
|
||||
// and forward output, returns true if there is a large negative value
|
||||
// (correcting a very confident output) for which there is no corresponding
|
||||
// positive value in an adjacent timestep for the same feature index. This
|
||||
// allows the box-truthed samples to make fine adjustments to position while
|
||||
// stopping other disagreements of confident output with ground truth.
|
||||
bool AnySuspiciousTruth(float confidence_thr) const;
|
||||
|
||||
// Reads a single timestep to floats in the range [-1, 1].
|
||||
void ReadTimeStep(int t, double* output) const;
|
||||
// Adds a single timestep to floats.
|
||||
void AddTimeStep(int t, double* inout) const;
|
||||
// Adds part of a single timestep to floats.
|
||||
void AddTimeStepPart(int t, int offset, int num_features, float* inout) const;
|
||||
// Writes a single timestep from floats in the range [-1, 1].
|
||||
void WriteTimeStep(int t, const double* input);
|
||||
// Writes a single timestep from floats in the range [-1, 1] writing only
|
||||
// num_features elements of input to (*this)[t], starting at offset.
|
||||
void WriteTimeStepPart(int t, int offset, int num_features,
|
||||
const double* input);
|
||||
// Maxpools a single time step from src.
|
||||
void MaxpoolTimeStep(int dest_t, const NetworkIO& src, int src_t,
|
||||
int* max_line);
|
||||
// Runs maxpool backward, using maxes to index timesteps in *this.
|
||||
void MaxpoolBackward(const NetworkIO& fwd,
|
||||
const GENERIC_2D_ARRAY<int>& maxes);
|
||||
// Returns the min over time of the maxes over features of the outputs.
|
||||
float MinOfMaxes() const;
|
||||
// Returns the min over time.
|
||||
float Max() const { return int_mode_ ? i_.Max() : f_.Max(); }
|
||||
// Computes combined results for a combiner that chooses between an existing
|
||||
// input and itself, with an additional output to indicate the choice.
|
||||
void CombineOutputs(const NetworkIO& base_output,
|
||||
const NetworkIO& combiner_output);
|
||||
// Computes deltas for a combiner that chooses between 2 sets of inputs.
|
||||
void ComputeCombinerDeltas(const NetworkIO& fwd_deltas,
|
||||
const NetworkIO& base_output);
|
||||
|
||||
// Copies the array checking that the types match.
|
||||
void CopyAll(const NetworkIO& src);
|
||||
// Adds the array to a float array, with scaling to [-1, 1] if the src is int.
|
||||
void AddAllToFloat(const NetworkIO& src);
|
||||
// Subtracts the array from a float array. src must also be float.
|
||||
void SubtractAllFromFloat(const NetworkIO& src);
|
||||
|
||||
// Copies src to *this, with maxabs normalization to match scale.
|
||||
void CopyWithNormalization(const NetworkIO& src, const NetworkIO& scale);
|
||||
// Multiplies the float data by the given factor.
|
||||
void ScaleFloatBy(float factor) { f_ *= factor; }
|
||||
// Copies src to *this with independent reversal of the y dimension.
|
||||
void CopyWithYReversal(const NetworkIO& src);
|
||||
// Copies src to *this with independent reversal of the x dimension.
|
||||
void CopyWithXReversal(const NetworkIO& src);
|
||||
// Copies src to *this with independent transpose of the x and y dimensions.
|
||||
void CopyWithXYTranspose(const NetworkIO& src);
|
||||
// Copies src to *this, at the given feature_offset, returning the total
|
||||
// feature offset after the copy. Multiple calls will stack outputs from
|
||||
// multiple sources in feature space.
|
||||
int CopyPacking(const NetworkIO& src, int feature_offset);
|
||||
// Opposite of CopyPacking, fills *this with a part of src, starting at
|
||||
// feature_offset, and picking num_features. Resizes *this to match.
|
||||
void CopyUnpacking(const NetworkIO& src, int feature_offset,
|
||||
int num_features);
|
||||
// Transposes the float part of *this into dest.
|
||||
void Transpose(TransposedArray* dest) const;
|
||||
|
||||
// Clips the content of a single time-step to +/-range.
|
||||
void ClipVector(int t, float range);
|
||||
|
||||
// Applies Func to timestep t of *this (u) and multiplies the result by v
|
||||
// component-wise, putting the product in *product.
|
||||
// *this and v may be int or float, but must match. The outputs are double.
|
||||
template <class Func>
|
||||
void FuncMultiply(const NetworkIO& v_io, int t, double* product) {
|
||||
Func f;
|
||||
ASSERT_HOST(!int_mode_);
|
||||
ASSERT_HOST(!v_io.int_mode_);
|
||||
int dim = f_.dim2();
|
||||
if (int_mode_) {
|
||||
const inT8* u = i_[t];
|
||||
const inT8* v = v_io.i_[t];
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
product[i] = f(u[i] / static_cast<double>(MAX_INT8)) * v[i] /
|
||||
static_cast<double>(MAX_INT8);
|
||||
}
|
||||
} else {
|
||||
const float* u = f_[t];
|
||||
const float* v = v_io.f_[t];
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
product[i] = f(u[i]) * v[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
// Applies Func to *this (u) at u_t, and multiplies the result by v[v_t] * w,
|
||||
// component-wise, putting the product in *product.
|
||||
// All NetworkIOs are assumed to be float.
|
||||
template <class Func>
|
||||
void FuncMultiply3(int u_t, const NetworkIO& v_io, int v_t, const double* w,
|
||||
double* product) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
ASSERT_HOST(!v_io.int_mode_);
|
||||
Func f;
|
||||
const float* u = f_[u_t];
|
||||
const float* v = v_io.f_[v_t];
|
||||
int dim = f_.dim2();
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
product[i] = f(u[i]) * v[i] * w[i];
|
||||
}
|
||||
}
|
||||
// Applies Func to *this (u) at u_t, and multiplies the result by v[v_t] * w,
|
||||
// component-wise, adding the product to *product.
|
||||
// All NetworkIOs are assumed to be float.
|
||||
template <class Func>
|
||||
void FuncMultiply3Add(const NetworkIO& v_io, int t, const double* w,
|
||||
double* product) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
ASSERT_HOST(!v_io.int_mode_);
|
||||
Func f;
|
||||
const float* u = f_[t];
|
||||
const float* v = v_io.f_[t];
|
||||
int dim = f_.dim2();
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
product[i] += f(u[i]) * v[i] * w[i];
|
||||
}
|
||||
}
|
||||
// Applies Func1 to *this (u), Func2 to v, and multiplies the result by w,
|
||||
// component-wise, putting the product in product, all at timestep t, except
|
||||
// w, which is a simple array. All NetworkIOs are assumed to be float.
|
||||
template <class Func1, class Func2>
|
||||
void Func2Multiply3(const NetworkIO& v_io, int t, const double* w,
|
||||
double* product) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
ASSERT_HOST(!v_io.int_mode_);
|
||||
Func1 f;
|
||||
Func2 g;
|
||||
const float* u = f_[t];
|
||||
const float* v = v_io.f_[t];
|
||||
int dim = f_.dim2();
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
product[i] = f(u[i]) * g(v[i]) * w[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Choice of float vs 8 bit int for data.
|
||||
GENERIC_2D_ARRAY<float> f_;
|
||||
GENERIC_2D_ARRAY<inT8> i_;
|
||||
// Which of f_ and i_ are we actually using.
|
||||
bool int_mode_;
|
||||
// Stride for 2d input data.
|
||||
StrideMap stride_map_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_NETWORKIO_H_
|
257
lstm/networkscratch.h
Normal file
257
lstm/networkscratch.h
Normal file
@ -0,0 +1,257 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: networkscratch.h
|
||||
// Description: Scratch space for Network layers that hides distinction
|
||||
// between float/int implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu Jun 19 10:50:29 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
|
||||
#define TESSERACT_LSTM_NETWORKSCRATCH_H_
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "matrix.h"
|
||||
#include "networkio.h"
|
||||
#include "svutil.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Generic scratch space for network layers. Provides NetworkIO that can store
|
||||
// a complete set (over time) of intermediates, and GenericVector<float>
|
||||
// scratch space that auto-frees after use. The aim here is to provide a set
|
||||
// of temporary buffers to network layers that can be reused between layers
|
||||
// and don't have to be reallocated on each call.
|
||||
class NetworkScratch {
|
||||
public:
|
||||
NetworkScratch() : int_mode_(false) {}  // Defaults to the float representation.
|
||||
~NetworkScratch() {}  // Members release their own storage via their destructors.
|
||||
|
||||
// Sets the network representation. If the representation is integer, then
// default (integer) NetworkIOs are separated from the always-float variety.
// This saves memory by having separate int-specific and float-specific
// stacks. If the network representation is float, then all NetworkIOs go
// to the float stack.
void set_int_mode(bool int_mode) {
  int_mode_ = int_mode;
}
|
||||
|
||||
// Class that acts like a NetworkIO (by having an implicit cast operator),
|
||||
// yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
|
||||
// and knows how to unstack the borrowed pointers on destruction.
|
||||
class IO {
|
||||
public:
|
||||
// The NetworkIO should be sized after construction.
// Borrows a NetworkIO from scratch: from the int stack only when both the
// scratch space and src are in int mode, otherwise from the float stack.
IO(const NetworkIO& src, NetworkScratch* scratch)
  : int_mode_(scratch->int_mode_ && src.int_mode()),
    scratch_space_(scratch) {
  network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
                          : scratch_space_->float_stack_.Borrow();
}
|
||||
// Default constructor for arrays. Use one of the Resize functions
// below to initialize and size. Holds no borrowed NetworkIO until a
// Resize* call attaches one.
IO() : int_mode_(false), network_io_(NULL), scratch_space_(NULL) {}
|
||||
|
||||
// Returns the borrowed NetworkIO to the stack it was borrowed from, if any.
~IO() {
  if (scratch_space_ == NULL) {
    // Never attached by a constructor or Resize* call: nothing was borrowed.
    ASSERT_HOST(network_io_ == NULL);
  } else if (int_mode_) {
    scratch_space_->int_stack_.Return(network_io_);
  } else {
    scratch_space_->float_stack_.Return(network_io_);
  }
}
|
||||
// Resizes the array (and stride), avoiding realloc if possible, to the
|
||||
// size from various size specs:
|
||||
// Same time size, given number of features.
|
||||
void Resize(const NetworkIO& src, int num_features,
|
||||
NetworkScratch* scratch) {
|
||||
if (scratch_space_ == NULL) {
|
||||
int_mode_ = scratch->int_mode_ && src.int_mode();
|
||||
scratch_space_ = scratch;
|
||||
network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
|
||||
: scratch_space_->float_stack_.Borrow();
|
||||
}
|
||||
network_io_->Resize(src, num_features);
|
||||
}
|
||||
// Resizes to a specific size as a temp buffer. No batches, no y-dim.
|
||||
void Resize2d(bool int_mode, int width, int num_features,
|
||||
NetworkScratch* scratch) {
|
||||
if (scratch_space_ == NULL) {
|
||||
int_mode_ = scratch->int_mode_ && int_mode;
|
||||
scratch_space_ = scratch;
|
||||
network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
|
||||
: scratch_space_->float_stack_.Borrow();
|
||||
}
|
||||
network_io_->Resize2d(int_mode, width, num_features);
|
||||
}
|
||||
// Resize forcing a float representation with the width of src and the given
|
||||
// number of features.
|
||||
void ResizeFloat(const NetworkIO& src, int num_features,
|
||||
NetworkScratch* scratch) {
|
||||
if (scratch_space_ == NULL) {
|
||||
int_mode_ = false;
|
||||
scratch_space_ = scratch;
|
||||
network_io_ = scratch_space_->float_stack_.Borrow();
|
||||
}
|
||||
network_io_->ResizeFloat(src, num_features);
|
||||
}
|
||||
|
||||
// Returns a ref to a NetworkIO that enables *this to be treated as if
|
||||
// it were just a NetworkIO*.
|
||||
NetworkIO& operator*() {
|
||||
return *network_io_;
|
||||
}
|
||||
NetworkIO* operator->() {
|
||||
return network_io_;
|
||||
}
|
||||
operator NetworkIO*() {
|
||||
return network_io_;
|
||||
}
|
||||
|
||||
private:
|
||||
// True if this is from the always-float stack, otherwise the default stack.
|
||||
bool int_mode_;
|
||||
// The NetworkIO that we have borrowed from the scratch_space_.
|
||||
NetworkIO* network_io_;
|
||||
// The source scratch_space_. Borrowed pointer, used to free the
|
||||
// NetworkIO. Don't delete!
|
||||
NetworkScratch* scratch_space_;
|
||||
}; // class IO.
|
||||
|
||||
// Class that acts like a fixed array of float, yet actually uses space
|
||||
// from a GenericVector<float> in the source NetworkScratch, and knows how
|
||||
// to unstack the borrowed vector on destruction.
|
||||
class FloatVec {
|
||||
public:
|
||||
// The array will have size elements in it, uninitialized.
|
||||
FloatVec(int size, NetworkScratch* scratch)
|
||||
: vec_(NULL), scratch_space_(scratch) {
|
||||
Init(size, scratch);
|
||||
}
|
||||
// Default constructor is for arrays. Use Init to setup.
|
||||
FloatVec() : vec_(NULL), data_(NULL), scratch_space_(NULL) {}
|
||||
~FloatVec() {
|
||||
if (scratch_space_ != NULL) scratch_space_->vec_stack_.Return(vec_);
|
||||
}
|
||||
|
||||
void Init(int size, NetworkScratch* scratch) {
|
||||
if (scratch_space_ != NULL && vec_ != NULL)
|
||||
scratch_space_->vec_stack_.Return(vec_);
|
||||
scratch_space_ = scratch;
|
||||
vec_ = scratch_space_->vec_stack_.Borrow();
|
||||
vec_->resize_no_init(size);
|
||||
data_ = &(*vec_)[0];
|
||||
}
|
||||
|
||||
// Use the cast operator instead of operator[] so the FloatVec can be used
|
||||
// as a double* argument to a function call.
|
||||
operator double*() const { return data_; }
|
||||
double* get() { return data_; }
|
||||
|
||||
private:
|
||||
// Vector borrowed from the scratch space. Use Return to free it.
|
||||
GenericVector<double>* vec_;
|
||||
// Short-cut pointer to the underlying array.
|
||||
double* data_;
|
||||
// The source scratch_space_. Borrowed pointer, used to free the
|
||||
// vector. Don't delete!
|
||||
NetworkScratch* scratch_space_;
|
||||
}; // class FloatVec
|
||||
|
||||
// Class that acts like a 2-D array of double, yet actually uses space
|
||||
// from the source NetworkScratch, and knows how to unstack the borrowed
|
||||
// array on destruction.
|
||||
class GradientStore {
|
||||
public:
|
||||
// Default constructor is for arrays. Use Init to setup.
|
||||
GradientStore() : array_(NULL), scratch_space_(NULL) {}
|
||||
~GradientStore() {
|
||||
if (scratch_space_ != NULL) scratch_space_->array_stack_.Return(array_);
|
||||
}
|
||||
|
||||
void Init(int size1, int size2, NetworkScratch* scratch) {
|
||||
if (scratch_space_ != NULL && array_ != NULL)
|
||||
scratch_space_->array_stack_.Return(array_);
|
||||
scratch_space_ = scratch;
|
||||
array_ = scratch_space_->array_stack_.Borrow();
|
||||
array_->Resize(size1, size2, 0.0);
|
||||
}
|
||||
|
||||
// Accessors to get to the underlying TransposedArray.
|
||||
TransposedArray* get() const { return array_; }
|
||||
const TransposedArray& operator*() const { return *array_; }
|
||||
|
||||
private:
|
||||
// Array borrowed from the scratch space. Use Return to free it.
|
||||
TransposedArray* array_;
|
||||
// The source scratch_space_. Borrowed pointer, used to free the
|
||||
// vector. Don't delete!
|
||||
NetworkScratch* scratch_space_;
|
||||
}; // class GradientStore
|
||||
|
||||
// Class that does the work of holding a stack of objects, a stack pointer
|
||||
// and a vector of in-use flags, so objects can be returned out of order.
|
||||
// It is safe to attempt to Borrow/Return in multiple threads.
|
||||
template<typename T> class Stack {
|
||||
public:
|
||||
Stack() : stack_top_(0) {
|
||||
}
|
||||
|
||||
// Lends out the next free item, creating one if none available, sets
|
||||
// the used flags and increments the stack top.
|
||||
T* Borrow() {
|
||||
SVAutoLock lock(&mutex_);
|
||||
if (stack_top_ == stack_.size()) {
|
||||
stack_.push_back(new T);
|
||||
flags_.push_back(false);
|
||||
}
|
||||
flags_[stack_top_] = true;
|
||||
return stack_[stack_top_++];
|
||||
}
|
||||
// Takes back the given item, and marks it free. Item does not have to be
|
||||
// the most recently lent out, but free slots don't get re-used until the
|
||||
// blocking item is returned. The assumption is that there will only be
|
||||
// small, temporary variations from true stack use. (Determined by the order
|
||||
// of destructors within a local scope.)
|
||||
void Return(T* item) {
|
||||
SVAutoLock lock(&mutex_);
|
||||
// Linear search will do.
|
||||
int index = stack_top_ - 1;
|
||||
while (index >= 0 && stack_[index] != item) --index;
|
||||
if (index >= 0) flags_[index] = false;
|
||||
while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_;
|
||||
}
|
||||
|
||||
private:
|
||||
PointerVector<T> stack_;
|
||||
GenericVector<bool> flags_;
|
||||
int stack_top_;
|
||||
SVMutex mutex_;
|
||||
}; // class Stack.
|
||||
|
||||
private:
|
||||
// If true, the network weights are inT8, if false, float.
|
||||
bool int_mode_;
|
||||
// Stacks of NetworkIO and GenericVector<float>. Once allocated, they are not
|
||||
// deleted until the NetworkScratch is deleted.
|
||||
Stack<NetworkIO> int_stack_;
|
||||
Stack<NetworkIO> float_stack_;
|
||||
Stack<GenericVector<double> > vec_stack_;
|
||||
Stack<TransposedArray> array_stack_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_NETWORKSCRATCH_H_
|
180
lstm/parallel.cpp
Normal file
180
lstm/parallel.cpp
Normal file
@ -0,0 +1,180 @@
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
// File: parallel.cpp
|
||||
// Description: Runs networks in parallel on the same input.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:06:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "parallel.h"
|
||||
|
||||
#include <omp.h>
|
||||
|
||||
#include "functions.h" // For conditional undef of _OPENMP.
|
||||
#include "networkscratch.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// ni_ and no_ will be set by AddToStack.
|
||||
Parallel::Parallel(const STRING& name, NetworkType type) : Plumbing(name) {
|
||||
type_ = type;
|
||||
}
|
||||
|
||||
Parallel::~Parallel() {
|
||||
}
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
StaticShape Parallel::OutputShape(const StaticShape& input_shape) const {
|
||||
StaticShape result = stack_[0]->OutputShape(input_shape);
|
||||
int stack_size = stack_.size();
|
||||
for (int i = 1; i < stack_size; ++i) {
|
||||
StaticShape shape = stack_[i]->OutputShape(input_shape);
|
||||
result.set_depth(result.depth() + shape.depth());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
void Parallel::Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output) {
|
||||
bool parallel_debug = false;
|
||||
// If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair,
|
||||
// or a 2-d LSTM quad, do debug locally, and don't pass the flag on.
|
||||
if (debug && type_ != NT_PARALLEL) {
|
||||
parallel_debug = true;
|
||||
debug = false;
|
||||
}
|
||||
int stack_size = stack_.size();
|
||||
if (type_ == NT_PAR_2D_LSTM) {
|
||||
// Special case, run parallel in parallel.
|
||||
GenericVector<NetworkScratch::IO> results;
|
||||
results.init_to_size(stack_size, NetworkScratch::IO());
|
||||
for (int i = 0; i < stack_size; ++i) {
|
||||
results[i].Resize(input, stack_[i]->NumOutputs(), scratch);
|
||||
}
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for num_threads(stack_size)
|
||||
#endif
|
||||
for (int i = 0; i < stack_size; ++i) {
|
||||
stack_[i]->Forward(debug, input, NULL, scratch, results[i]);
|
||||
}
|
||||
// Now pack all the results (serially) into the output.
|
||||
int out_offset = 0;
|
||||
output->Resize(*results[0], NumOutputs());
|
||||
for (int i = 0; i < stack_size; ++i) {
|
||||
out_offset = output->CopyPacking(*results[i], out_offset);
|
||||
}
|
||||
} else {
|
||||
// Revolving intermediate result.
|
||||
NetworkScratch::IO result(input, scratch);
|
||||
// Source for divided replicated.
|
||||
NetworkScratch::IO source_part;
|
||||
TransposedArray* src_transpose = NULL;
|
||||
if (training() && type_ == NT_REPLICATED) {
|
||||
// Make a transposed copy of the input.
|
||||
input.Transpose(&transposed_input_);
|
||||
src_transpose = &transposed_input_;
|
||||
}
|
||||
// Run each network, putting the outputs into result.
|
||||
int input_offset = 0;
|
||||
int out_offset = 0;
|
||||
for (int i = 0; i < stack_size; ++i) {
|
||||
stack_[i]->Forward(debug, input, src_transpose, scratch, result);
|
||||
// All networks must have the same output width
|
||||
if (i == 0) {
|
||||
output->Resize(*result, NumOutputs());
|
||||
} else {
|
||||
ASSERT_HOST(result->Width() == output->Width());
|
||||
}
|
||||
out_offset = output->CopyPacking(*result, out_offset);
|
||||
}
|
||||
}
|
||||
if (parallel_debug) {
|
||||
DisplayForward(*output);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
bool Parallel::Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas) {
|
||||
// If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair,
|
||||
// or a 2-d LSTM quad, do debug locally, and don't pass the flag on.
|
||||
if (debug && type_ != NT_PARALLEL) {
|
||||
DisplayBackward(fwd_deltas);
|
||||
debug = false;
|
||||
}
|
||||
int stack_size = stack_.size();
|
||||
if (type_ == NT_PAR_2D_LSTM) {
|
||||
// Special case, run parallel in parallel.
|
||||
GenericVector<NetworkScratch::IO> in_deltas, out_deltas;
|
||||
in_deltas.init_to_size(stack_size, NetworkScratch::IO());
|
||||
out_deltas.init_to_size(stack_size, NetworkScratch::IO());
|
||||
// Split the forward deltas for each stack element.
|
||||
int feature_offset = 0;
|
||||
int out_offset = 0;
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
int num_features = stack_[i]->NumOutputs();
|
||||
in_deltas[i].Resize(fwd_deltas, num_features, scratch);
|
||||
out_deltas[i].Resize(fwd_deltas, stack_[i]->NumInputs(), scratch);
|
||||
in_deltas[i]->CopyUnpacking(fwd_deltas, feature_offset, num_features);
|
||||
feature_offset += num_features;
|
||||
}
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for num_threads(stack_size)
|
||||
#endif
|
||||
for (int i = 0; i < stack_size; ++i) {
|
||||
stack_[i]->Backward(debug, *in_deltas[i], scratch,
|
||||
i == 0 ? back_deltas : out_deltas[i]);
|
||||
}
|
||||
if (needs_to_backprop_) {
|
||||
for (int i = 1; i < stack_size; ++i) {
|
||||
back_deltas->AddAllToFloat(*out_deltas[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Revolving partial deltas.
|
||||
NetworkScratch::IO in_deltas(fwd_deltas, scratch);
|
||||
// The sum of deltas from different sources, which will eventually go into
|
||||
// back_deltas.
|
||||
NetworkScratch::IO out_deltas;
|
||||
int feature_offset = 0;
|
||||
int out_offset = 0;
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
int num_features = stack_[i]->NumOutputs();
|
||||
in_deltas->CopyUnpacking(fwd_deltas, feature_offset, num_features);
|
||||
feature_offset += num_features;
|
||||
if (stack_[i]->Backward(debug, *in_deltas, scratch, back_deltas)) {
|
||||
if (i == 0) {
|
||||
out_deltas.ResizeFloat(*back_deltas, back_deltas->NumFeatures(),
|
||||
scratch);
|
||||
out_deltas->CopyAll(*back_deltas);
|
||||
} else if (back_deltas->NumFeatures() == out_deltas->NumFeatures()) {
|
||||
// Widths are allowed to be different going back, as we may have
|
||||
// input nets, so only accumulate the deltas if the widths are the
|
||||
// same.
|
||||
out_deltas->AddAllToFloat(*back_deltas);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (needs_to_backprop_) back_deltas->CopyAll(*out_deltas);
|
||||
}
|
||||
if (needs_to_backprop_) back_deltas->ScaleFloatBy(1.0f / stack_size);
|
||||
return needs_to_backprop_;
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
87
lstm/parallel.h
Normal file
87
lstm/parallel.h
Normal file
@ -0,0 +1,87 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: parallel.h
|
||||
// Description: Runs networks in parallel on the same input.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:02:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_PARALLEL_H_
|
||||
#define TESSERACT_LSTM_PARALLEL_H_
|
||||
|
||||
#include "plumbing.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Runs multiple networks in parallel, interlacing their outputs.
|
||||
class Parallel : public Plumbing {
|
||||
public:
|
||||
// ni_ and no_ will be set by AddToStack.
|
||||
Parallel(const STRING& name, NetworkType type);
|
||||
virtual ~Parallel();
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
|
||||
|
||||
virtual STRING spec() const {
|
||||
STRING spec;
|
||||
if (type_ == NT_PAR_2D_LSTM) {
|
||||
// We have 4 LSTMs operating in parallel here, so the size of each is
|
||||
// the number of outputs/4.
|
||||
spec.add_str_int("L2xy", no_ / 4);
|
||||
} else if (type_ == NT_PAR_RL_LSTM) {
|
||||
// We have 2 LSTMs operating in parallel here, so the size of each is
|
||||
// the number of outputs/2.
|
||||
if (stack_[0]->type() == NT_LSTM_SUMMARY)
|
||||
spec.add_str_int("Lbxs", no_ / 2);
|
||||
else
|
||||
spec.add_str_int("Lbx", no_ / 2);
|
||||
} else {
|
||||
if (type_ == NT_REPLICATED) {
|
||||
spec.add_str_int("R", stack_.size());
|
||||
spec += "(";
|
||||
spec += stack_[0]->spec();
|
||||
} else {
|
||||
spec = "(";
|
||||
for (int i = 0; i < stack_.size(); ++i) spec += stack_[i]->spec();
|
||||
}
|
||||
spec += ")";
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual void Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output);
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas);
|
||||
|
||||
private:
|
||||
// If *this is a NT_REPLICATED, then it feeds a replicated network with
|
||||
// identical inputs, and it would be extremely wasteful for them to each
|
||||
// calculate and store the same transpose of the inputs, so Parallel does it
|
||||
// and passes a pointer to the replicated network, allowing it to use the
|
||||
// transpose on the next call to Backward.
|
||||
TransposedArray transposed_input_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_PARALLEL_H_
|
233
lstm/plumbing.cpp
Normal file
233
lstm/plumbing.cpp
Normal file
@ -0,0 +1,233 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: plumbing.cpp
|
||||
// Description: Base class for networks that organize other networks
|
||||
// eg series or parallel.
|
||||
// Author: Ray Smith
|
||||
// Created: Mon May 12 08:17:34 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "plumbing.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// ni_ and no_ will be set by AddToStack.
|
||||
Plumbing::Plumbing(const STRING& name)
|
||||
: Network(NT_PARALLEL, name, 0, 0) {
|
||||
}
|
||||
|
||||
Plumbing::~Plumbing() {
|
||||
}
|
||||
|
||||
// Suspends/Enables training by setting the training_ flag. Serialize and
|
||||
// DeSerialize only operate on the run-time data if state is false.
|
||||
void Plumbing::SetEnableTraining(bool state) {
|
||||
Network::SetEnableTraining(state);
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
stack_[i]->SetEnableTraining(state);
|
||||
}
|
||||
|
||||
// Sets flags that control the action of the network. See NetworkFlags enum
|
||||
// for bit values.
|
||||
void Plumbing::SetNetworkFlags(uinT32 flags) {
|
||||
Network::SetNetworkFlags(flags);
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
stack_[i]->SetNetworkFlags(flags);
|
||||
}
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
|
||||
// scale `range` picked according to the random number generator `randomizer`.
|
||||
// Note that randomizer is a borrowed pointer that should outlive the network
|
||||
// and should not be deleted by any of the networks.
|
||||
// Returns the number of weights initialized.
|
||||
int Plumbing::InitWeights(float range, TRand* randomizer) {
|
||||
num_weights_ = 0;
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
num_weights_ += stack_[i]->InitWeights(range, randomizer);
|
||||
return num_weights_;
|
||||
}
|
||||
|
||||
// Converts a float network to an int network.
|
||||
void Plumbing::ConvertToInt() {
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
stack_[i]->ConvertToInt();
|
||||
}
|
||||
|
||||
// Provides a pointer to a TRand for any networks that care to use it.
|
||||
// Note that randomizer is a borrowed pointer that should outlive the network
|
||||
// and should not be deleted by any of the networks.
|
||||
void Plumbing::SetRandomizer(TRand* randomizer) {
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
stack_[i]->SetRandomizer(randomizer);
|
||||
}
|
||||
|
||||
// Adds the given network to the stack.
|
||||
void Plumbing::AddToStack(Network* network) {
|
||||
if (stack_.empty()) {
|
||||
ni_ = network->NumInputs();
|
||||
no_ = network->NumOutputs();
|
||||
} else if (type_ == NT_SERIES) {
|
||||
// ni is input of first, no output of last, others match output to input.
|
||||
ASSERT_HOST(no_ == network->NumInputs());
|
||||
no_ = network->NumOutputs();
|
||||
} else {
|
||||
// All parallel types. Output is sum of outputs, inputs all match.
|
||||
ASSERT_HOST(ni_ == network->NumInputs());
|
||||
no_ += network->NumOutputs();
|
||||
}
|
||||
stack_.push_back(network);
|
||||
}
|
||||
|
||||
// Sets needs_to_backprop_ to needs_backprop and calls on sub-network
|
||||
// according to needs_backprop || any weights in this network.
|
||||
bool Plumbing::SetupNeedsBackprop(bool needs_backprop) {
|
||||
needs_to_backprop_ = needs_backprop;
|
||||
bool retval = needs_backprop;
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
if (stack_[i]->SetupNeedsBackprop(needs_backprop))
|
||||
retval = true;
|
||||
}
|
||||
return retval;
|
||||
}
|
||||
|
||||
// Returns an integer reduction factor that the network applies to the
|
||||
// time sequence. Assumes that any 2-d is already eliminated. Used for
|
||||
// scaling bounding boxes of truth data.
|
||||
// WARNING: if GlobalMinimax is used to vary the scale, this will return
|
||||
// the last used scale factor. Call it before any forward, and it will return
|
||||
// the minimum scale factor of the paths through the GlobalMinimax.
|
||||
int Plumbing::XScaleFactor() const {
|
||||
return stack_[0]->XScaleFactor();
|
||||
}
|
||||
|
||||
// Provides the (minimum) x scale factor to the network (of interest only to
|
||||
// input units) so they can determine how to scale bounding boxes.
|
||||
void Plumbing::CacheXScaleFactor(int factor) {
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
stack_[i]->CacheXScaleFactor(factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Provides debug output on the weights.
|
||||
void Plumbing::DebugWeights() {
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
stack_[i]->DebugWeights();
|
||||
}
|
||||
|
||||
// Returns a set of strings representing the layer-ids of all layers below.
|
||||
void Plumbing::EnumerateLayers(const STRING* prefix,
|
||||
GenericVector<STRING>* layers) const {
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
STRING layer_name;
|
||||
if (prefix) layer_name = *prefix;
|
||||
layer_name.add_str_int(":", i);
|
||||
if (stack_[i]->IsPlumbingType()) {
|
||||
Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[i]);
|
||||
plumbing->EnumerateLayers(&layer_name, layers);
|
||||
} else {
|
||||
layers->push_back(layer_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a pointer to the network layer corresponding to the given id.
|
||||
Network* Plumbing::GetLayer(const char* id) const {
|
||||
char* next_id;
|
||||
int index = strtol(id, &next_id, 10);
|
||||
if (index < 0 || index >= stack_.size()) return NULL;
|
||||
if (stack_[index]->IsPlumbingType()) {
|
||||
Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[index]);
|
||||
ASSERT_HOST(*next_id == ':');
|
||||
return plumbing->GetLayer(next_id + 1);
|
||||
}
|
||||
return stack_[index];
|
||||
}
|
||||
|
||||
// Returns a pointer to the learning rate for the given layer id.
|
||||
float* Plumbing::LayerLearningRatePtr(const char* id) const {
|
||||
char* next_id;
|
||||
int index = strtol(id, &next_id, 10);
|
||||
if (index < 0 || index >= stack_.size()) return NULL;
|
||||
if (stack_[index]->IsPlumbingType()) {
|
||||
Plumbing* plumbing = reinterpret_cast<Plumbing*>(stack_[index]);
|
||||
ASSERT_HOST(*next_id == ':');
|
||||
return plumbing->LayerLearningRatePtr(next_id + 1);
|
||||
}
|
||||
if (index < 0 || index >= learning_rates_.size()) return NULL;
|
||||
return &learning_rates_[index];
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Plumbing::Serialize(TFile* fp) const {
|
||||
if (!Network::Serialize(fp)) return false;
|
||||
inT32 size = stack_.size();
|
||||
// Can't use PointerVector::Serialize here as we need a special DeSerialize.
|
||||
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
|
||||
for (int i = 0; i < size; ++i)
|
||||
if (!stack_[i]->Serialize(fp)) return false;
|
||||
if ((network_flags_ & NF_LAYER_SPECIFIC_LR) &&
|
||||
!learning_rates_.Serialize(fp)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool Plumbing::DeSerialize(bool swap, TFile* fp) {
|
||||
stack_.truncate(0);
|
||||
no_ = 0; // We will be modifying this as we AddToStack.
|
||||
inT32 size;
|
||||
if (fp->FRead(&size, sizeof(size), 1) != 1) return false;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
Network* network = CreateFromFile(swap, fp);
|
||||
if (network == NULL) return false;
|
||||
AddToStack(network);
|
||||
}
|
||||
if ((network_flags_ & NF_LAYER_SPECIFIC_LR) &&
|
||||
!learning_rates_.DeSerialize(swap, fp)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Updates the weights using the given learning rate and momentum.
|
||||
// num_samples is the quotient to be used in the adagrad computation iff
|
||||
// use_ada_grad_ is true.
|
||||
void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
if (network_flags_ & NF_LAYER_SPECIFIC_LR) {
|
||||
if (i < learning_rates_.size())
|
||||
learning_rate = learning_rates_[i];
|
||||
else
|
||||
learning_rates_.push_back(learning_rate);
|
||||
}
|
||||
if (stack_[i]->training())
|
||||
stack_[i]->Update(learning_rate, momentum, num_samples);
|
||||
}
|
||||
}
|
||||
|
||||
// Sums the products of weight updates in *this and other, splitting into
|
||||
// positive (same direction) in *same and negative (different direction) in
|
||||
// *changed.
|
||||
void Plumbing::CountAlternators(const Network& other, double* same,
|
||||
double* changed) const {
|
||||
ASSERT_HOST(other.type() == type_);
|
||||
const Plumbing* plumbing = reinterpret_cast<const Plumbing*>(&other);
|
||||
ASSERT_HOST(plumbing->stack_.size() == stack_.size());
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
stack_[i]->CountAlternators(*plumbing->stack_[i], same, changed);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
143
lstm/plumbing.h
Normal file
143
lstm/plumbing.h
Normal file
@ -0,0 +1,143 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: plumbing.h
|
||||
// Description: Base class for networks that organize other networks
|
||||
// eg series or parallel.
|
||||
// Author: Ray Smith
|
||||
// Created: Mon May 12 08:11:36 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_PLUMBING_H_
|
||||
#define TESSERACT_LSTM_PLUMBING_H_
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "matrix.h"
|
||||
#include "network.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Holds a collection of other networks and forwards calls to each of them.
|
||||
class Plumbing : public Network {
|
||||
public:
|
||||
// ni_ and no_ will be set by AddToStack.
|
||||
explicit Plumbing(const STRING& name);
|
||||
virtual ~Plumbing();
|
||||
|
||||
// Returns the required shape input to the network.
|
||||
virtual StaticShape InputShape() const { return stack_[0]->InputShape(); }
|
||||
virtual STRING spec() const {
|
||||
return "Sub-classes of Plumbing must implement spec()!";
|
||||
}
|
||||
|
||||
// Returns true if the given type is derived from Plumbing, and thus contains
|
||||
// multiple sub-networks that can have their own learning rate.
|
||||
virtual bool IsPlumbingType() const { return true; }
|
||||
|
||||
// Suspends/Enables training by setting the training_ flag. Serialize and
|
||||
// DeSerialize only operate on the run-time data if state is false.
|
||||
virtual void SetEnableTraining(bool state);
|
||||
|
||||
// Sets flags that control the action of the network. See NetworkFlags enum
|
||||
// for bit values.
|
||||
virtual void SetNetworkFlags(uinT32 flags);
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
|
||||
// scale `range` picked according to the random number generator `randomizer`.
|
||||
// Note that randomizer is a borrowed pointer that should outlive the network
|
||||
// and should not be deleted by any of the networks.
|
||||
// Returns the number of weights initialized.
|
||||
virtual int InitWeights(float range, TRand* randomizer);
|
||||
|
||||
// Converts a float network to an int network.
|
||||
virtual void ConvertToInt();
|
||||
|
||||
// Provides a pointer to a TRand for any networks that care to use it.
|
||||
// Note that randomizer is a borrowed pointer that should outlive the network
|
||||
// and should not be deleted by any of the networks.
|
||||
virtual void SetRandomizer(TRand* randomizer);
|
||||
|
||||
// Adds the given network to the stack.
|
||||
virtual void AddToStack(Network* network);
|
||||
|
||||
// Sets needs_to_backprop_ to needs_backprop and returns true if
|
||||
// needs_backprop || any weights in this network so the next layer forward
|
||||
// can be told to produce backprop for this layer if needed.
|
||||
virtual bool SetupNeedsBackprop(bool needs_backprop);
|
||||
|
||||
// Returns an integer reduction factor that the network applies to the
|
||||
// time sequence. Assumes that any 2-d is already eliminated. Used for
|
||||
// scaling bounding boxes of truth data.
|
||||
// WARNING: if GlobalMinimax is used to vary the scale, this will return
|
||||
// the last used scale factor. Call it before any forward, and it will return
|
||||
// the minimum scale factor of the paths through the GlobalMinimax.
|
||||
virtual int XScaleFactor() const;
|
||||
|
||||
// Provides the (minimum) x scale factor to the network (of interest only to
|
||||
// input units) so they can determine how to scale bounding boxes.
|
||||
virtual void CacheXScaleFactor(int factor);
|
||||
|
||||
// Provides debug output on the weights.
|
||||
virtual void DebugWeights();
|
||||
|
||||
  // Returns the current stack of sub-networks (read-only accessor).
  const PointerVector<Network>& stack() const {
    return stack_;
  }
|
||||
// Returns a set of strings representing the layer-ids of all layers below.
|
||||
void EnumerateLayers(const STRING* prefix,
|
||||
GenericVector<STRING>* layers) const;
|
||||
// Returns a pointer to the network layer corresponding to the given id.
|
||||
Network* GetLayer(const char* id) const;
|
||||
  // Returns the learning rate for a specific layer of the stack.
  // Asserts (via ASSERT_HOST) if the given id does not name a known layer.
  float LayerLearningRate(const char* id) const {
    const float* lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != NULL);
    return *lr_ptr;
  }
|
||||
  // Scales the learning rate for a specific layer of the stack by the given
  // multiplicative factor. Asserts if the id does not name a known layer.
  // Note: the stored rate is float, so the double factor is narrowed.
  void ScaleLayerLearningRate(const char* id, double factor) {
    float* lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != NULL);
    *lr_ptr *= factor;
  }
|
||||
// Returns a pointer to the learning rate for the given layer id.
|
||||
float* LayerLearningRatePtr(const char* id) const;
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
virtual bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
virtual bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// Updates the weights using the given learning rate and momentum.
|
||||
// num_samples is the quotient to be used in the adagrad computation iff
|
||||
// use_ada_grad_ is true.
|
||||
virtual void Update(float learning_rate, float momentum, int num_samples);
|
||||
// Sums the products of weight updates in *this and other, splitting into
|
||||
// positive (same direction) in *same and negative (different direction) in
|
||||
// *changed.
|
||||
virtual void CountAlternators(const Network& other, double* same,
|
||||
double* changed) const;
|
||||
|
||||
protected:
|
||||
// The networks.
|
||||
PointerVector<Network> stack_;
|
||||
// Layer-specific learning rate iff network_flags_ & NF_LAYER_SPECIFIC_LR.
|
||||
// One element for each element of stack_.
|
||||
GenericVector<float> learning_rates_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_PLUMBING_H_
|
||||
|
759
lstm/recodebeam.cpp
Normal file
759
lstm/recodebeam.cpp
Normal file
@ -0,0 +1,759 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: recodebeam.cpp
|
||||
// Description: Beam search to decode from the re-encoded CJK as a sequence of
|
||||
// smaller numbers in place of a single large code.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Mar 13 09:39:01 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "recodebeam.h"
|
||||
#include "networkio.h"
|
||||
#include "pageres.h"
|
||||
#include "unicharcompress.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Clipping value for certainty inside Tesseract. Reflects the minimum value
// of certainty that will be returned by ExtractBestPathAsUnicharIds.
// Supposedly on a uniform scale that can be compared across languages and
// engines.
const float RecodeBeamSearch::kMinCertainty = -20.0f;

// The beam width at each code position. Indexed by code-prefix length:
// position 0 (start of a new code sequence) uses a narrow beam (5), deeper
// positions use progressively wider beams. Sized kMaxCodeLen + 1 so every
// legal prefix length has an entry.
const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
    5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
};
||||
|
||||
// Borrows the pointer, which is expected to survive until *this is deleted.
|
||||
RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress& recoder,
|
||||
int null_char, bool simple_text, Dict* dict)
|
||||
: recoder_(recoder),
|
||||
dict_(dict),
|
||||
space_delimited_(true),
|
||||
is_simple_text_(simple_text),
|
||||
null_char_(null_char) {
|
||||
if (dict_ != NULL && !dict_->IsSpaceDelimitedLang()) space_delimited_ = false;
|
||||
}
|
||||
|
||||
// Decodes the set of network outputs, storing the lattice internally.
|
||||
void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
|
||||
double cert_offset, double worst_dict_cert,
|
||||
const UNICHARSET* charset) {
|
||||
beam_size_ = 0;
|
||||
int width = output.Width();
|
||||
for (int t = 0; t < width; ++t) {
|
||||
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
|
||||
DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
|
||||
charset);
|
||||
}
|
||||
}
|
||||
void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
|
||||
double dict_ratio, double cert_offset,
|
||||
double worst_dict_cert,
|
||||
const UNICHARSET* charset) {
|
||||
beam_size_ = 0;
|
||||
int width = output.dim1();
|
||||
for (int t = 0; t < width; ++t) {
|
||||
ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
|
||||
DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the best path as labels/scores/xcoords similar to simple CTC.
|
||||
void RecodeBeamSearch::ExtractBestPathAsLabels(
|
||||
GenericVector<int>* labels, GenericVector<int>* xcoords) const {
|
||||
labels->truncate(0);
|
||||
xcoords->truncate(0);
|
||||
GenericVector<const RecodeNode*> best_nodes;
|
||||
ExtractBestPaths(&best_nodes, NULL);
|
||||
// Now just run CTC on the best nodes.
|
||||
int t = 0;
|
||||
int width = best_nodes.size();
|
||||
while (t < width) {
|
||||
int label = best_nodes[t]->code;
|
||||
if (label != null_char_) {
|
||||
labels->push_back(label);
|
||||
xcoords->push_back(t);
|
||||
}
|
||||
while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
|
||||
}
|
||||
}
|
||||
xcoords->push_back(width);
|
||||
}
|
||||
|
||||
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
|
||||
// duplicates, nulls and intermediate parts.
|
||||
void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
|
||||
bool debug, const UNICHARSET* unicharset, GenericVector<int>* unichar_ids,
|
||||
GenericVector<float>* certs, GenericVector<float>* ratings,
|
||||
GenericVector<int>* xcoords) const {
|
||||
GenericVector<const RecodeNode*> best_nodes;
|
||||
ExtractBestPaths(&best_nodes, NULL);
|
||||
ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
|
||||
if (debug) {
|
||||
DebugPath(unicharset, best_nodes);
|
||||
DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
|
||||
*xcoords);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the best path as a set of WERD_RES, segmenting the collapsed
// unichar sequence into words at spaces, start_of_word flags, and (for
// top-choice nodes) non-space-delimited character boundaries. Each word gets
// a diagonal ratings matrix with one BLOB_CHOICE per character, and a space
// certainty taken as the MIN of the spaces either side of the word.
// line_box/scale_factor position and scale the fake blob boxes back into
// image coordinates.
void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
                                              float scale_factor, bool debug,
                                              const UNICHARSET* unicharset,
                                              PointerVector<WERD_RES>* words) {
  words->truncate(0);
  GenericVector<int> unichar_ids;
  GenericVector<float> certs;
  GenericVector<float> ratings;
  GenericVector<int> xcoords;
  GenericVector<const RecodeNode*> best_nodes;
  GenericVector<const RecodeNode*> second_nodes;
  ExtractBestPaths(&best_nodes, &second_nodes);
  if (debug) {
    DebugPath(unicharset, best_nodes);
    // The second-choice path is only computed/printed for debugging; the
    // temporaries are overwritten below with the best path.
    ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
                            &xcoords);
    tprintf("\nSecond choice path:\n");
    DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
                     xcoords);
  }
  ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords);
  int num_ids = unichar_ids.size();
  if (debug) {
    DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
                     xcoords);
  }
  // Convert labels to unichar-ids.
  int word_end = 0;
  float prev_space_cert = 0.0f;
  for (int word_start = 0; word_start < num_ids; word_start = word_end) {
    // Advance word_end to the first position NOT in this word.
    for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
      // A word is terminated when a space character or start_of_word flag is
      // hit. We also want to force a separate word for every non
      // space-delimited character when not in a dictionary context.
      if (unichar_ids[word_end] == UNICHAR_SPACE) break;
      int index = xcoords[word_end];
      if (best_nodes[index]->start_of_word) break;
      if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
          (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
           !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1])))
        break;
    }
    // Certainty of the trailing space (if any) for this word; combined with
    // the previous boundary's space certainty via MIN below so a weak space
    // on either side lowers the word's space_certainty.
    float space_cert = 0.0f;
    if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE)
      space_cert = certs[word_end];
    bool leading_space =
        word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
    // Create a WERD_RES for the output word.
    WERD_RES* word_res = InitializeWord(
        leading_space, line_box, word_start, word_end,
        MIN(space_cert, prev_space_cert), unicharset, xcoords, scale_factor);
    // Populate the diagonal of the ratings matrix: one single-choice list
    // per character position in the word.
    for (int i = word_start; i < word_end; ++i) {
      BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
      BLOB_CHOICE_IT bc_it(choices);
      BLOB_CHOICE* choice = new BLOB_CHOICE(
          unichar_ids[i], ratings[i], certs[i], -1, 1.0f,
          static_cast<float>(MAX_INT16), 0.0f, BCC_STATIC_CLASSIFIER);
      int col = i - word_start;
      choice->set_matrix_cell(col, col);
      bc_it.add_after_then_move(choice);
      word_res->ratings->put(col, col, choices);
    }
    // The word takes its permuter from its last character's node.
    int index = xcoords[word_end - 1];
    word_res->FakeWordFromRatings(best_nodes[index]->permuter);
    words->push_back(word_res);
    prev_space_cert = space_cert;
    // Skip over the delimiting space so the next word starts after it.
    if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE)
      ++word_end;
  }
}
||||
|
||||
// Generates debug output of the content of the beams after a Decode.
|
||||
void RecodeBeamSearch::DebugBeams(const UNICHARSET& unicharset) const {
|
||||
for (int p = 0; p < beam_size_; ++p) {
|
||||
// Print all the best scoring nodes for each unichar found.
|
||||
tprintf("Position %d: Nondict beam\n", p);
|
||||
DebugBeamPos(unicharset, beam_[p]->beams_[0]);
|
||||
tprintf("Position %d: Dict beam\n", p);
|
||||
DebugBeamPos(unicharset, beam_[p]->dawg_beams_[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Generates debug output of the content of a single beam position: for each
// unichar present in the heap, prints only its best-scoring node, plus (last)
// the best-scoring null_char node if any.
void RecodeBeamSearch::DebugBeamPos(const UNICHARSET& unicharset,
                                    const RecodeHeap& heap) const {
  // Best node seen so far per unichar-id; NULL until one is found.
  GenericVector<const RecodeNode*> unichar_bests;
  unichar_bests.init_to_size(unicharset.size(), NULL);
  const RecodeNode* null_best = NULL;
  int heap_size = heap.size();
  // Single pass over the (unordered) heap storage, keeping max-score winners.
  for (int i = 0; i < heap_size; ++i) {
    const RecodeNode* node = &heap.get(i).data;
    if (node->unichar_id == INVALID_UNICHAR_ID) {
      // Nulls and partial codes have no unichar-id; track them separately.
      if (null_best == NULL || null_best->score < node->score) null_best = node;
    } else {
      if (unichar_bests[node->unichar_id] == NULL ||
          unichar_bests[node->unichar_id]->score < node->score) {
        unichar_bests[node->unichar_id] = node;
      }
    }
  }
  for (int u = 0; u < unichar_bests.size(); ++u) {
    if (unichar_bests[u] != NULL) {
      const RecodeNode& node = *unichar_bests[u];
      tprintf("label=%d, uid=%d=%s score=%g, c=%g, s=%d, e=%d, perm=%d\n",
              node.code, node.unichar_id,
              unicharset.debug_str(node.unichar_id).string(), node.score,
              node.certainty, node.start_of_word, node.end_of_word,
              node.permuter);
    }
  }
  if (null_best != NULL) {
    tprintf("null_char score=%g, c=%g, s=%d, e=%d, perm=%d\n", null_best->score,
            null_best->certainty, null_best->start_of_word,
            null_best->end_of_word, null_best->permuter);
  }
}
||||
|
||||
// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts. Each emitted unichar carries:
//  - cert: the worst (minimum) certainty across the node run that produced
//    it, including any preceding null run (with a NO_PERM-space exception),
//  - rating: the negated sum of certainties over that run,
//  - xcoord: the timestep where the unichar's code run starts.
// A trailing xcoord equal to the path width is appended as an end sentinel.
/* static */
void RecodeBeamSearch::ExtractPathAsUnicharIds(
    const GenericVector<const RecodeNode*>& best_nodes,
    GenericVector<int>* unichar_ids, GenericVector<float>* certs,
    GenericVector<float>* ratings, GenericVector<int>* xcoords) {
  unichar_ids->truncate(0);
  certs->truncate(0);
  ratings->truncate(0);
  xcoords->truncate(0);
  // Backtrack extracting only valid, non-duplicate unichar-ids.
  int t = 0;
  int width = best_nodes.size();
  while (t < width) {
    double certainty = 0.0;
    double rating = 0.0;
    // Absorb any run of nulls/partial codes preceding the next unichar into
    // its certainty (min) and rating (negated sum).
    while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
      double cert = best_nodes[t++]->certainty;
      if (cert < certainty) certainty = cert;
      rating -= cert;
    }
    if (t < width) {
      int unichar_id = best_nodes[t]->unichar_id;
      unichar_ids->push_back(unichar_id);
      xcoords->push_back(t);
      // Consume the unichar's node and any duplicate continuations.
      do {
        double cert = best_nodes[t++]->certainty;
        // Special-case NO-PERM space to forget the certainty of the previous
        // nulls. See long comment in ContinueContext.
        if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
                                 best_nodes[t - 1]->permuter == NO_PERM)) {
          certainty = cert;
        }
        rating -= cert;
      } while (t < width && best_nodes[t]->duplicate);
      certs->push_back(certainty);
      ratings->push_back(rating);
    } else if (!certs->empty()) {
      // Trailing nulls after the last unichar: fold their certainty/rating
      // into the final emitted unichar instead of emitting a new entry.
      if (certainty < certs->back()) certs->back() = certainty;
      ratings->back() += rating;
    }
  }
  xcoords->push_back(width);
}
||||
|
||||
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places: one blob per character position in [word_start, word_end),
// centered on its xcoord with a half-width derived from the gaps to its
// neighbors, scaled by scale_factor and translated into line_box
// coordinates. The returned WERD_RES owns its WERD and an empty diagonal
// ratings MATRIX sized to the word length (filled in by the caller).
WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
                                           const TBOX& line_box, int word_start,
                                           int word_end, float space_certainty,
                                           const UNICHARSET* unicharset,
                                           const GenericVector<int>& xcoords,
                                           float scale_factor) {
  // Make a fake blob for each non-zero label.
  C_BLOB_LIST blobs;
  C_BLOB_IT b_it(&blobs);
  for (int i = word_start; i < word_end; ++i) {
    // Half-width = min gap to either neighbor, clamped to at least 1.
    int min_half_width = xcoords[i + 1] - xcoords[i];
    if (i > 0 && xcoords[i] - xcoords[i - 1] < min_half_width)
      min_half_width = xcoords[i] - xcoords[i - 1];
    if (min_half_width < 1) min_half_width = 1;
    // Make a fake blob. Built in timestep coordinates, then scaled and moved
    // into the line's image coordinates (order matters: scale before move).
    TBOX box(xcoords[i] - min_half_width, 0, xcoords[i] + min_half_width,
             line_box.height());
    box.scale(scale_factor);
    box.move(ICOORD(line_box.left(), line_box.bottom()));
    box.set_top(line_box.top());
    b_it.add_after_then_move(C_BLOB::FakeBlob(box));
  }
  // Make a fake word from the blobs.
  WERD* word = new WERD(&blobs, leading_space, NULL);
  // Make a WERD_RES from the word.
  WERD_RES* word_res = new WERD_RES(word);
  word_res->uch_set = unicharset;
  word_res->combination = true;  // Give it ownership of the word.
  word_res->space_certainty = space_certainty;
  word_res->ratings = new MATRIX(word_end - word_start, 1);
  return word_res;
}
||||
|
||||
// Fills top_n_flags_ with bools that are true iff the corresponding output
|
||||
// is one of the top_n.
|
||||
void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
|
||||
int top_n) {
|
||||
top_n_flags_.init_to_size(num_outputs, false);
|
||||
top_heap_.clear();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
|
||||
TopPair entry(outputs[i], i);
|
||||
top_heap_.Push(&entry);
|
||||
if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
|
||||
}
|
||||
}
|
||||
while (!top_heap_.empty()) {
|
||||
TopPair entry;
|
||||
top_heap_.Pop(&entry);
|
||||
top_n_flags_[entry.data] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Adds the computation for the current time-step to the beam. Call at each
// time-step in sequence from left to right. outputs is the activation vector
// for the current timestep.
// NOTE(review): a non-NULL charset triggers per-step debug dumps of the dawg
// beam — presumably callers pass charset only when debugging; confirm.
void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
                                  double dict_ratio, double cert_offset,
                                  double worst_dict_cert,
                                  const UNICHARSET* charset) {
  // Beam steps are allocated lazily and reused across calls to Decode.
  if (t == beam_.size()) beam_.push_back(new RecodeBeam);
  RecodeBeam* step = beam_[t];
  beam_size_ = t + 1;
  step->Clear();
  if (t == 0) {
    // The first step can only use singles and initials.
    ContinueContext(NULL, 0, outputs, false, true, dict_ratio, cert_offset,
                    worst_dict_cert, step);
    if (dict_ != NULL)
      ContinueContext(NULL, 0, outputs, true, true, dict_ratio, cert_offset,
                      worst_dict_cert, step);
  } else {
    RecodeBeam* prev = beam_[t - 1];
    if (charset != NULL) {
      // Debug: dump the path leading to every node in the previous dawg beam.
      for (int i = prev->dawg_beams_[0].size() - 1; i >= 0; --i) {
        GenericVector<const RecodeNode*> path;
        ExtractPath(&prev->dawg_beams_[0].get(i).data, &path);
        tprintf("Step %d: Dawg beam %d:\n", t, i);
        DebugPath(charset, path);
      }
    }
    int total_beam = 0;
    // Try true and then false only if the beam is empty. This enables extending
    // the context using only the top-n results first, which may have an empty
    // intersection with the valid codes, so we fall back to the rest if the
    // beam is empty.
    for (int flag = 1; flag >= 0 && total_beam == 0; --flag) {
      for (int length = 0; length <= RecodedCharID::kMaxCodeLen; ++length) {
        // Working backwards through the heaps doesn't guarantee that we see the
        // best first, but it comes before a lot of the worst, so it is slightly
        // more efficient than going forwards.
        for (int i = prev->dawg_beams_[length].size() - 1; i >= 0; --i) {
          ContinueContext(&prev->dawg_beams_[length].get(i).data, length,
                          outputs, true, flag, dict_ratio, cert_offset,
                          worst_dict_cert, step);
        }
        for (int i = prev->beams_[length].size() - 1; i >= 0; --i) {
          ContinueContext(&prev->beams_[length].get(i).data, length, outputs,
                          false, flag, dict_ratio, cert_offset, worst_dict_cert,
                          step);
        }
      }
      // Count what this pass produced to decide whether the fallback pass
      // (flag == 0, non-top-n codes) is needed.
      for (int length = 0; length <= RecodedCharID::kMaxCodeLen; ++length) {
        total_beam += step->beams_[length].size();
        total_beam += step->dawg_beams_[length].size();
      }
    }
    // Special case for the best initial dawg. Push it on the heap if good
    // enough, but there is only one, so it doesn't blow up the beam.
    RecodeHeap* dawg_heap = &step->dawg_beams_[0];
    if (step->best_initial_dawg_.code >= 0 &&
        (dawg_heap->size() < kBeamWidths[0] ||
         step->best_initial_dawg_.score > dawg_heap->PeekTop().data.score)) {
      RecodePair entry(step->best_initial_dawg_.score,
                       step->best_initial_dawg_);
      dawg_heap->Push(&entry);
      if (dawg_heap->size() > kBeamWidths[0]) dawg_heap->Pop(&entry);
    }
  }
}
||||
|
||||
// Adds to the appropriate beams the legal (according to recoder)
// continuations of context prev, which is of the given length, using the
// given network outputs to provide scores to the choices. Uses only those
// choices for which top_n_flags[index] == top_n_flag.
// Three kinds of continuation are considered: (1) a duplicate of prev's own
// code (or an interior null), (2) codes that COMPLETE a unichar from this
// prefix (final codes), and (3) codes that extend the prefix without
// completing it (next codes).
void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int length,
                                       const float* outputs, bool use_dawgs,
                                       bool top_n_flag, double dict_ratio,
                                       double cert_offset,
                                       double worst_dict_cert,
                                       RecodeBeam* step) {
  // Reconstruct the current code prefix by walking back `length` real codes,
  // skipping duplicates and nulls, which are not part of the code sequence.
  RecodedCharID prefix;
  RecodedCharID full_code;
  const RecodeNode* previous = prev;
  for (int p = length - 1; p >= 0; --p, previous = previous->prev) {
    while (previous != NULL &&
           (previous->duplicate || previous->code == null_char_)) {
      previous = previous->prev;
    }
    prefix.Set(p, previous->code);
    full_code.Set(p, previous->code);
  }
  if (prev != NULL && !is_simple_text_) {
    // (1) Repeat prev's code as a CTC-style duplicate.
    float cert = NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
    if ((cert >= kMinCertainty || prev->code == null_char_) &&
        top_n_flags_[prev->code] == top_n_flag) {
      if (use_dawgs) {
        if (cert > worst_dict_cert) {
          PushDupIfBetter(kBeamWidths[length], cert, prev,
                          &step->dawg_beams_[length]);
        }
      } else {
        PushDupIfBetter(kBeamWidths[length], cert * dict_ratio, prev,
                        &step->beams_[length]);
      }
    }
    if (prev->code != null_char_ && length > 0 &&
        top_n_flags_[null_char_] == top_n_flag) {
      // Allow nulls within multi code sequences, as the nulls within are not
      // explicitly included in the code sequence.
      cert = NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
      if (cert >= kMinCertainty && (!use_dawgs || cert > worst_dict_cert)) {
        if (use_dawgs) {
          PushNoDawgIfBetter(kBeamWidths[length], null_char_,
                             INVALID_UNICHAR_ID, NO_PERM, cert, prev,
                             &step->dawg_beams_[length]);
        } else {
          PushNoDawgIfBetter(kBeamWidths[length], null_char_,
                             INVALID_UNICHAR_ID, TOP_CHOICE_PERM,
                             cert * dict_ratio, prev, &step->beams_[length]);
        }
      }
    }
  }
  // (2) Codes that complete a unichar from this prefix.
  const GenericVector<int>* final_codes = recoder_.GetFinalCodes(prefix);
  if (final_codes != NULL) {
    for (int i = 0; i < final_codes->size(); ++i) {
      int code = (*final_codes)[i];
      if (top_n_flags_[code] != top_n_flag) continue;
      // A repeat of prev's code is handled as a duplicate above.
      if (prev != NULL && prev->code == code && !is_simple_text_) continue;
      float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
      if (cert < kMinCertainty && code != null_char_) continue;
      full_code.Set(length, code);
      int unichar_id = recoder_.DecodeUnichar(full_code);
      // Map the null char to INVALID.
      if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
      if (use_dawgs) {
        if (cert > worst_dict_cert) {
          ContinueDawg(kBeamWidths[0], code, unichar_id, cert, prev,
                       &step->dawg_beams_[0], step);
        }
      } else {
        PushNoDawgIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM,
                           cert * dict_ratio, prev, &step->beams_[0]);
        if (dict_ != NULL &&
            ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
             !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
          // Any top choice position that can start a new word, ie a space or
          // any non-space-delimited character, should also be considered
          // by the dawg search, so push initial dawg to the dawg heap.
          float dawg_cert = cert;
          PermuterType permuter = TOP_CHOICE_PERM;
          // Since we use the space either side of a dictionary word in the
          // certainty of the word, (to properly handle weak spaces) and the
          // space is coming from a non-dict word, we need special conditions
          // to avoid degrading the certainty of the dict word that follows.
          // With a space we don't multiply the certainty by dict_ratio, and we
          // flag the space with NO_PERM to indicate that we should not use the
          // predecessor nulls to generate the confidence for the space, as they
          // have already been multiplied by dict_ratio, and we can't go back to
          // insert more entries in any previous heaps.
          if (unichar_id == UNICHAR_SPACE)
            permuter = NO_PERM;
          else
            dawg_cert *= dict_ratio;
          PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
                                  dawg_cert, prev, &step->best_initial_dawg_);
        }
      }
    }
  }
  // (3) Codes that extend the prefix without completing a unichar yet; these
  // go into the beams at length + 1 with no unichar-id.
  const GenericVector<int>* next_codes = recoder_.GetNextCodes(prefix);
  if (next_codes != NULL) {
    for (int i = 0; i < next_codes->size(); ++i) {
      int code = (*next_codes)[i];
      if (top_n_flags_[code] != top_n_flag) continue;
      if (prev != NULL && prev->code == code && !is_simple_text_) continue;
      float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
      if (cert < kMinCertainty && code != null_char_) continue;
      if (use_dawgs) {
        if (cert > worst_dict_cert) {
          ContinueDawg(kBeamWidths[length + 1], code, INVALID_UNICHAR_ID, cert,
                       prev, &step->dawg_beams_[length + 1], step);
        }
      } else {
        PushNoDawgIfBetter(kBeamWidths[length + 1], code, INVALID_UNICHAR_ID,
                           TOP_CHOICE_PERM, cert * dict_ratio, prev,
                           &step->beams_[length + 1]);
      }
    }
  }
}
||||
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
|
||||
// appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
|
||||
// is a valid continuation of whatever is in prev.
|
||||
void RecodeBeamSearch::ContinueDawg(int max_size, int code, int unichar_id,
|
||||
float cert, const RecodeNode* prev,
|
||||
RecodeHeap* heap, RecodeBeam* step) {
|
||||
if (unichar_id == INVALID_UNICHAR_ID) {
|
||||
PushNoDawgIfBetter(max_size, code, unichar_id, NO_PERM, cert, prev, heap);
|
||||
return;
|
||||
}
|
||||
// Avoid dictionary probe if score a total loss.
|
||||
float score = cert;
|
||||
if (prev != NULL) score += prev->score;
|
||||
if (heap->size() >= max_size && score <= heap->PeekTop().data.score) return;
|
||||
const RecodeNode* uni_prev = prev;
|
||||
// Prev may be a partial code, null_char, or duplicate, so scan back to the
|
||||
// last valid unichar_id.
|
||||
while (uni_prev != NULL &&
|
||||
(uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate))
|
||||
uni_prev = uni_prev->prev;
|
||||
if (unichar_id == UNICHAR_SPACE) {
|
||||
if (uni_prev != NULL && uni_prev->end_of_word) {
|
||||
// Space is good. Push initial state, to the dawg beam and a regular
|
||||
// space to the top choice beam.
|
||||
PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
|
||||
false, cert, prev, &step->best_initial_dawg_);
|
||||
PushNoDawgIfBetter(max_size, code, unichar_id, uni_prev->permuter, cert,
|
||||
prev, &step->beams_[0]);
|
||||
}
|
||||
return;
|
||||
} else if (uni_prev != NULL && uni_prev->start_of_dawg &&
|
||||
uni_prev->unichar_id != UNICHAR_SPACE &&
|
||||
dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
|
||||
dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
|
||||
return; // Can't break words between space delimited chars.
|
||||
}
|
||||
DawgPositionVector initial_dawgs;
|
||||
DawgPositionVector* updated_dawgs = new DawgPositionVector;
|
||||
DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
|
||||
bool word_start = false;
|
||||
if (uni_prev == NULL) {
|
||||
// Starting from beginning of line.
|
||||
dict_->default_dawgs(&initial_dawgs, false);
|
||||
word_start = true;
|
||||
} else if (uni_prev->dawgs != NULL) {
|
||||
// Continuing a previous dict word.
|
||||
dawg_args.active_dawgs = uni_prev->dawgs;
|
||||
word_start = uni_prev->start_of_dawg;
|
||||
} else {
|
||||
return; // Can't continue if not a dict word.
|
||||
}
|
||||
PermuterType permuter = static_cast<PermuterType>(
|
||||
dict_->def_letter_is_okay(&dawg_args, unichar_id, false));
|
||||
if (permuter != NO_PERM) {
|
||||
PushHeapIfBetter(max_size, code, unichar_id, permuter, false, word_start,
|
||||
dawg_args.valid_end, false, cert, prev,
|
||||
dawg_args.updated_dawgs, heap);
|
||||
if (dawg_args.valid_end && !space_delimited_) {
|
||||
// We can start another word right away, so push initial state as well,
|
||||
// to the dawg beam, and the regular character to the top choice beam,
|
||||
// since non-dict words can start here too.
|
||||
PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
|
||||
cert, prev, &step->best_initial_dawg_);
|
||||
PushHeapIfBetter(max_size, code, unichar_id, permuter, false, word_start,
|
||||
true, false, cert, prev, NULL, &step->beams_[0]);
|
||||
}
|
||||
} else {
|
||||
delete updated_dawgs;
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id,
// initial-dawg-state, prev, cert) to the given heap if/ there is room or if
// better than the current worst element if already full.
// Replaces *best_initial_dawg when the new node's cumulative score beats it
// (or when no node has been stored yet, signalled by code < 0).
void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
                                               PermuterType permuter,
                                               bool start, bool end, float cert,
                                               const RecodeNode* prev,
                                               RecodeNode* best_initial_dawg) {
  // Cumulative path score = this node's certainty plus the path so far.
  float score = cert;
  if (prev != NULL) score += prev->score;
  if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
    // Fresh set of starting dawg positions for a new word.
    // NOTE(review): assumes RecodeNode assignment takes over ownership of
    // initial_dawgs (and releases any dawgs held by *best_initial_dawg) —
    // confirm RecodeNode's copy/assignment semantics.
    DawgPositionVector* initial_dawgs = new DawgPositionVector;
    dict_->default_dawgs(initial_dawgs, false);
    RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
                    score, prev, initial_dawgs);
    *best_initial_dawg = node;
  }
}
||||
|
||||
// Adds a copy of the given prev as a duplicate of and successor to prev, if
// there is room or if better than the current worst element if already full.
// The duplicate keeps prev's code/unichar_id/permuter but clears the
// start/end/dawg flags and sets the dup flag.
/* static */
void RecodeBeamSearch::PushDupIfBetter(int max_size, float cert,
                                       const RecodeNode* prev,
                                       RecodeHeap* heap) {
  PushHeapIfBetter(max_size, prev->code, prev->unichar_id, prev->permuter,
                   false, false, false, true, cert, prev, NULL, heap);
}
||||
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
|
||||
// false, false, false, false, cert, prev, NULL) to heap if there is room
|
||||
// or if better than the current worst element if already full.
|
||||
/* static */
|
||||
void RecodeBeamSearch::PushNoDawgIfBetter(int max_size, int code,
|
||||
int unichar_id, PermuterType permuter,
|
||||
float cert, const RecodeNode* prev,
|
||||
RecodeHeap* heap) {
|
||||
float score = cert;
|
||||
if (prev != NULL) score += prev->score;
|
||||
if (heap->size() < max_size || score > heap->PeekTop().data.score) {
|
||||
RecodeNode node(code, unichar_id, permuter, false, false, false, false,
|
||||
cert, score, prev, NULL);
|
||||
RecodePair entry(score, node);
|
||||
heap->Push(&entry);
|
||||
if (heap->size() > max_size) heap->Pop(&entry);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
// or if better than the current worst element if already full.
// Takes ownership of d: it is deleted here when the node is not pushed.
/* static */
void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
                                        PermuterType permuter, bool dawg_start,
                                        bool word_start, bool end, bool dup,
                                        float cert, const RecodeNode* prev,
                                        DawgPositionVector* d,
                                        RecodeHeap* heap) {
  // Cumulative path score = this node's certainty plus the path so far.
  float score = cert;
  if (prev != NULL) score += prev->score;
  if (heap->size() < max_size || score > heap->PeekTop().data.score) {
    RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
                    dup, cert, score, prev, d);
    RecodePair entry(score, node);
    heap->Push(&entry);
    // The push must have transferred ownership of d into the heap's copy;
    // the local entry must no longer hold the dawgs pointer.
    ASSERT_HOST(entry.data.dawgs == NULL);
    if (heap->size() > max_size) heap->Pop(&entry);
  } else {
    delete d;  // Node rejected: release the dawg state we own.
  }
}
||||
|
||||
// Backtracks to extract the best path through the lattice that was built
|
||||
// during Decode. On return the best_nodes vector essentially contains the set
|
||||
// of code, score pairs that make the optimal path with the constraint that
|
||||
// the recoder can decode the code sequence back to a sequence of unichar-ids.
|
||||
void RecodeBeamSearch::ExtractBestPaths(
|
||||
GenericVector<const RecodeNode*>* best_nodes,
|
||||
GenericVector<const RecodeNode*>* second_nodes) const {
|
||||
// Scan both beams to extract the best and second best paths.
|
||||
const RecodeNode* best_node = NULL;
|
||||
const RecodeNode* second_best_node = NULL;
|
||||
const RecodeBeam* last_beam = beam_[beam_size_ - 1];
|
||||
int heap_size = last_beam->beams_[0].size();
|
||||
for (int i = 0; i < heap_size; ++i) {
|
||||
const RecodeNode* node = &last_beam->beams_[0].get(i).data;
|
||||
if (best_node == NULL || node->score > best_node->score) {
|
||||
second_best_node = best_node;
|
||||
best_node = node;
|
||||
} else if (second_best_node == NULL ||
|
||||
node->score > second_best_node->score) {
|
||||
second_best_node = node;
|
||||
}
|
||||
}
|
||||
// Scan the entire dawg heap for the best *valid* nodes, if any.
|
||||
int dawg_size = last_beam->dawg_beams_[0].size();
|
||||
for (int i = 0; i < dawg_size; ++i) {
|
||||
const RecodeNode* dawg_node = &last_beam->dawg_beams_[0].get(i).data;
|
||||
// dawg_node may be a null_char, or duplicate, so scan back to the last
|
||||
// valid unichar_id.
|
||||
const RecodeNode* back_dawg_node = dawg_node;
|
||||
while (back_dawg_node != NULL &&
|
||||
(back_dawg_node->unichar_id == INVALID_UNICHAR_ID ||
|
||||
back_dawg_node->duplicate))
|
||||
back_dawg_node = back_dawg_node->prev;
|
||||
if (back_dawg_node != NULL &&
|
||||
(back_dawg_node->end_of_word ||
|
||||
back_dawg_node->unichar_id == UNICHAR_SPACE)) {
|
||||
// Dawg node is valid. Use it in preference to back_dawg_node, as the
|
||||
// score comparison is fair that way.
|
||||
if (best_node == NULL || dawg_node->score > best_node->score) {
|
||||
second_best_node = best_node;
|
||||
best_node = dawg_node;
|
||||
} else if (second_best_node == NULL ||
|
||||
dawg_node->score > second_best_node->score) {
|
||||
second_best_node = dawg_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (second_nodes != NULL) ExtractPath(second_best_node, second_nodes);
|
||||
ExtractPath(best_node, best_nodes);
|
||||
}
|
||||
|
||||
// Helper backtracks through the lattice from the given node, storing the
|
||||
// path and reversing it.
|
||||
void RecodeBeamSearch::ExtractPath(
|
||||
const RecodeNode* node, GenericVector<const RecodeNode*>* path) const {
|
||||
path->truncate(0);
|
||||
while (node != NULL) {
|
||||
path->push_back(node);
|
||||
node = node->prev;
|
||||
}
|
||||
path->reverse();
|
||||
}
|
||||
|
||||
// Helper prints debug information on the given lattice path.
|
||||
void RecodeBeamSearch::DebugPath(
|
||||
const UNICHARSET* unicharset,
|
||||
const GenericVector<const RecodeNode*>& path) const {
|
||||
for (int c = 0; c < path.size(); ++c) {
|
||||
const RecodeNode& node = *path[c];
|
||||
tprintf("%d %d=%s score=%g, c=%g, s=%d, e=%d, perm=%d\n", c,
|
||||
node.unichar_id, unicharset->debug_str(node.unichar_id).string(),
|
||||
node.score, node.certainty, node.start_of_word, node.end_of_word,
|
||||
node.permuter);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper prints debug information on the given unichar path.
|
||||
void RecodeBeamSearch::DebugUnicharPath(
|
||||
const UNICHARSET* unicharset, const GenericVector<const RecodeNode*>& path,
|
||||
const GenericVector<int>& unichar_ids, const GenericVector<float>& certs,
|
||||
const GenericVector<float>& ratings,
|
||||
const GenericVector<int>& xcoords) const {
|
||||
int num_ids = unichar_ids.size();
|
||||
double total_rating = 0.0;
|
||||
for (int c = 0; c < num_ids; ++c) {
|
||||
int coord = xcoords[c];
|
||||
tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
|
||||
unicharset->debug_str(unichar_ids[c]).string(), ratings[c],
|
||||
certs[c], path[coord]->start_of_word, path[coord]->end_of_word,
|
||||
path[coord]->permuter);
|
||||
total_rating += ratings[c];
|
||||
}
|
||||
tprintf("Path total rating = %g\n", total_rating);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
304
lstm/recodebeam.h
Normal file
304
lstm/recodebeam.h
Normal file
@ -0,0 +1,304 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: recodebeam.h
|
||||
// Description: Beam search to decode from the re-encoded CJK as a sequence of
|
||||
// smaller numbers in place of a single large code.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Mar 13 09:12:01 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
|
||||
#define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
|
||||
|
||||
#include "dawg.h"
|
||||
#include "dict.h"
|
||||
#include "genericheap.h"
|
||||
#include "kdpair.h"
|
||||
#include "networkio.h"
|
||||
#include "ratngs.h"
|
||||
#include "unicharcompress.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Lattice element for Re-encode beam search.
|
||||
struct RecodeNode {
|
||||
RecodeNode()
|
||||
: code(-1),
|
||||
unichar_id(INVALID_UNICHAR_ID),
|
||||
permuter(TOP_CHOICE_PERM),
|
||||
start_of_dawg(false),
|
||||
start_of_word(false),
|
||||
end_of_word(false),
|
||||
duplicate(false),
|
||||
certainty(0.0f),
|
||||
score(0.0f),
|
||||
prev(NULL),
|
||||
dawgs(NULL) {}
|
||||
RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start,
|
||||
bool word_start, bool end, bool dup, float cert, float s,
|
||||
const RecodeNode* p, DawgPositionVector* d)
|
||||
: code(c),
|
||||
unichar_id(uni_id),
|
||||
permuter(perm),
|
||||
start_of_dawg(dawg_start),
|
||||
start_of_word(word_start),
|
||||
end_of_word(end),
|
||||
duplicate(dup),
|
||||
certainty(cert),
|
||||
score(s),
|
||||
prev(p),
|
||||
dawgs(d) {}
|
||||
// NOTE: If we could use C++11, then this would be a move constructor.
|
||||
// Instead we have copy constructor that does a move!! This is because we
|
||||
// don't want to copy the whole DawgPositionVector each time, and true
|
||||
// copying isn't necessary for this struct. It does get moved around a lot
|
||||
// though inside the heap and during heap push, hence the move semantics.
|
||||
RecodeNode(RecodeNode& src) : dawgs(NULL) {
|
||||
*this = src;
|
||||
ASSERT_HOST(src.dawgs == NULL);
|
||||
}
|
||||
RecodeNode& operator=(RecodeNode& src) {
|
||||
delete dawgs;
|
||||
memcpy(this, &src, sizeof(src));
|
||||
src.dawgs = NULL;
|
||||
return *this;
|
||||
}
|
||||
~RecodeNode() { delete dawgs; }
|
||||
|
||||
// The re-encoded code here = index to network output.
|
||||
int code;
|
||||
// The decoded unichar_id is only valid for the final code of a sequence.
|
||||
int unichar_id;
|
||||
// The type of permuter active at this point. Intervals between start_of_word
|
||||
// and end_of_word make valid words of type given by permuter where
|
||||
// end_of_word is true. These aren't necessarily delimited by spaces.
|
||||
PermuterType permuter;
|
||||
// True if this is the initial dawg state. May be attached to a space or,
|
||||
// in a non-space-delimited lang, the end of the previous word.
|
||||
bool start_of_dawg;
|
||||
// True if this is the first node in a dictionary word.
|
||||
bool start_of_word;
|
||||
// True if this represents a valid candidate end of word position. Does not
|
||||
// necessarily mark the end of a word, since a word can be extended beyond a
|
||||
// candidiate end by a continuation, eg 'the' continues to 'these'.
|
||||
bool end_of_word;
|
||||
// True if this is a duplicate of prev in all respects. Some training modes
|
||||
// allow the network to output duplicate characters and crush them with CTC,
|
||||
// but that would mess up the decoding, so we just smash them together on the
|
||||
// fly using the duplicate flag.
|
||||
bool duplicate;
|
||||
// Certainty (log prob) of (just) this position.
|
||||
float certainty;
|
||||
// Total certainty of the path to this position.
|
||||
float score;
|
||||
// The previous node in this chain. Borrowed pointer.
|
||||
const RecodeNode* prev;
|
||||
// The currently active dawgs at this position. Owned pointer.
|
||||
DawgPositionVector* dawgs;
|
||||
};
|
||||
|
||||
typedef KDPairInc<double, RecodeNode> RecodePair;
|
||||
typedef GenericHeap<RecodePair> RecodeHeap;
|
||||
|
||||
// Class that holds the entire beam search for recognition of a text line.
|
||||
class RecodeBeamSearch {
|
||||
public:
|
||||
// Borrows the pointer, which is expected to survive until *this is deleted.
|
||||
RecodeBeamSearch(const UnicharCompress& recoder, int null_char,
|
||||
bool simple_text, Dict* dict);
|
||||
|
||||
// Decodes the set of network outputs, storing the lattice internally.
|
||||
// If charset is not null, it enables detailed debugging of the beam search.
|
||||
void Decode(const NetworkIO& output, double dict_ratio, double cert_offset,
|
||||
double worst_dict_cert, const UNICHARSET* charset);
|
||||
void Decode(const GENERIC_2D_ARRAY<float>& output, double dict_ratio,
|
||||
double cert_offset, double worst_dict_cert,
|
||||
const UNICHARSET* charset);
|
||||
|
||||
// Returns the best path as labels/scores/xcoords similar to simple CTC.
|
||||
void ExtractBestPathAsLabels(GenericVector<int>* labels,
|
||||
GenericVector<int>* xcoords) const;
|
||||
// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
|
||||
// duplicates, nulls and intermediate parts.
|
||||
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET* unicharset,
|
||||
GenericVector<int>* unichar_ids,
|
||||
GenericVector<float>* certs,
|
||||
GenericVector<float>* ratings,
|
||||
GenericVector<int>* xcoords) const;
|
||||
|
||||
// Returns the best path as a set of WERD_RES.
|
||||
void ExtractBestPathAsWords(const TBOX& line_box, float scale_factor,
|
||||
bool debug, const UNICHARSET* unicharset,
|
||||
PointerVector<WERD_RES>* words);
|
||||
|
||||
// Generates debug output of the content of the beams after a Decode.
|
||||
void DebugBeams(const UNICHARSET& unicharset) const;
|
||||
|
||||
// Clipping value for certainty inside Tesseract. Reflects the minimum value
|
||||
// of certainty that will be returned by ExtractBestPathAsUnicharIds.
|
||||
// Supposedly on a uniform scale that can be compared across languages and
|
||||
// engines.
|
||||
static const float kMinCertainty;
|
||||
|
||||
private:
|
||||
// Struct for the Re-encode beam search. This struct holds the data for
|
||||
// a single time-step position of the output. Use a PointerVector<RecodeBeam>
|
||||
// to hold all the timesteps and prevent reallocation of the individual heaps.
|
||||
struct RecodeBeam {
|
||||
// Resets to the initial state without deleting all the memory.
|
||||
void Clear() {
|
||||
for (int i = 0; i <= RecodedCharID::kMaxCodeLen; ++i) {
|
||||
beams_[i].clear();
|
||||
dawg_beams_[i].clear();
|
||||
}
|
||||
RecodeNode empty;
|
||||
best_initial_dawg_ = empty;
|
||||
}
|
||||
// A separate beam for each code position. Since there aren't that many
|
||||
// code positions, this allows the beam to be quite narrow, and yet still
|
||||
// have a low chance of losing the best path.
|
||||
// Each heap is stored with the WORST result at the top, so we can quickly
|
||||
// get the top-n values.
|
||||
RecodeHeap beams_[RecodedCharID::kMaxCodeLen + 1];
|
||||
// Although, we can only use complete codes in the dawg, we have to separate
|
||||
// partial code paths that lead back to a mid-dawg word from paths that are
|
||||
// not part of a dawg word, as they have a different score. Since a dawg
|
||||
// word can dead-end at any point, we need to keep the non dawg path going
|
||||
// so the dawg beams_ are totally separate set with a heap for each length
|
||||
// just like the non-dawg beams.
|
||||
RecodeHeap dawg_beams_[RecodedCharID::kMaxCodeLen + 1];
|
||||
// While the language model is only a single word dictionary, we can use
|
||||
// word starts as a choke point in the beam, and keep only a single dict
|
||||
// start node at each step, so we find the best one here and push it on
|
||||
// the heap, if it qualifies, after processing all of the step.
|
||||
RecodeNode best_initial_dawg_;
|
||||
};
|
||||
typedef KDPairInc<float, int> TopPair;
|
||||
|
||||
// Generates debug output of the content of a single beam position.
|
||||
void DebugBeamPos(const UNICHARSET& unicharset, const RecodeHeap& heap) const;
|
||||
|
||||
// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
|
||||
// duplicates, nulls and intermediate parts.
|
||||
static void ExtractPathAsUnicharIds(
|
||||
const GenericVector<const RecodeNode*>& best_nodes,
|
||||
GenericVector<int>* unichar_ids, GenericVector<float>* certs,
|
||||
GenericVector<float>* ratings, GenericVector<int>* xcoords);
|
||||
|
||||
// Sets up a word with the ratings matrix and fake blobs with boxes in the
|
||||
// right places.
|
||||
WERD_RES* InitializeWord(bool leading_space, const TBOX& line_box,
|
||||
int word_start, int word_end, float space_certainty,
|
||||
const UNICHARSET* unicharset,
|
||||
const GenericVector<int>& xcoords,
|
||||
float scale_factor);
|
||||
|
||||
// Fills top_n_flags_ with bools that are true iff the corresponding output
|
||||
// is one of the top_n.
|
||||
void ComputeTopN(const float* outputs, int num_outputs, int top_n);
|
||||
|
||||
// Adds the computation for the current time-step to the beam. Call at each
|
||||
// time-step in sequence from left to right. outputs is the activation vector
|
||||
// for the current timestep.
|
||||
void DecodeStep(const float* outputs, int t, double dict_ratio,
|
||||
double cert_offset, double worst_dict_cert,
|
||||
const UNICHARSET* charset);
|
||||
|
||||
// Adds to the appropriate beams the legal (according to recoder)
|
||||
// continuations of context prev, which is of the given length, using the
|
||||
// given network outputs to provide scores to the choices. Uses only those
|
||||
// choices for which top_n_flags[index] == top_n_flag.
|
||||
void ContinueContext(const RecodeNode* prev, int length, const float* outputs,
|
||||
bool use_dawgs, bool top_n_flag, double dict_ratio,
|
||||
double cert_offset, double worst_dict_cert,
|
||||
RecodeBeam* step);
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
|
||||
// appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
|
||||
// is a valid continuation of whatever is in prev.
|
||||
void ContinueDawg(int max_size, int code, int unichar_id, float cert,
|
||||
const RecodeNode* prev, RecodeHeap* heap, RecodeBeam* step);
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id,
|
||||
// initial-dawg-state, prev, cert) to the given heap if/ there is room or if
|
||||
// better than the current worst element if already full.
|
||||
void PushInitialDawgIfBetter(int code, int unichar_id, PermuterType permuter,
|
||||
bool start, bool end, float cert,
|
||||
const RecodeNode* prev,
|
||||
RecodeNode* best_initial_dawg);
|
||||
// Adds a copy of the given prev as a duplicate of and successor to prev, if
|
||||
// there is room or if better than the current worst element if already full.
|
||||
static void PushDupIfBetter(int max_size, float cert, const RecodeNode* prev,
|
||||
RecodeHeap* heap);
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
|
||||
// false, false, false, false, cert, prev, NULL) to heap if there is room
|
||||
// or if better than the current worst element if already full.
|
||||
static void PushNoDawgIfBetter(int max_size, int code, int unichar_id,
|
||||
PermuterType permuter, float cert,
|
||||
const RecodeNode* prev, RecodeHeap* heap);
|
||||
// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
|
||||
// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
|
||||
// or if better than the current worst element if already full.
|
||||
static void PushHeapIfBetter(int max_size, int code, int unichar_id,
|
||||
PermuterType permuter, bool dawg_start,
|
||||
bool word_start, bool end, bool dup, float cert,
|
||||
const RecodeNode* prev, DawgPositionVector* d,
|
||||
RecodeHeap* heap);
|
||||
// Backtracks to extract the best path through the lattice that was built
|
||||
// during Decode. On return the best_nodes vector essentially contains the set
|
||||
// of code, score pairs that make the optimal path with the constraint that
|
||||
// the recoder can decode the code sequence back to a sequence of unichar-ids.
|
||||
void ExtractBestPaths(GenericVector<const RecodeNode*>* best_nodes,
|
||||
GenericVector<const RecodeNode*>* second_nodes) const;
|
||||
// Helper backtracks through the lattice from the given node, storing the
|
||||
// path and reversing it.
|
||||
void ExtractPath(const RecodeNode* node,
|
||||
GenericVector<const RecodeNode*>* path) const;
|
||||
// Helper prints debug information on the given lattice path.
|
||||
void DebugPath(const UNICHARSET* unicharset,
|
||||
const GenericVector<const RecodeNode*>& path) const;
|
||||
// Helper prints debug information on the given unichar path.
|
||||
void DebugUnicharPath(const UNICHARSET* unicharset,
|
||||
const GenericVector<const RecodeNode*>& path,
|
||||
const GenericVector<int>& unichar_ids,
|
||||
const GenericVector<float>& certs,
|
||||
const GenericVector<float>& ratings,
|
||||
const GenericVector<int>& xcoords) const;
|
||||
|
||||
static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1];
|
||||
|
||||
// The encoder/decoder that we will be using.
|
||||
const UnicharCompress& recoder_;
|
||||
// The beam for each timestep in the output.
|
||||
PointerVector<RecodeBeam> beam_;
|
||||
// The number of timesteps valid in beam_;
|
||||
int beam_size_;
|
||||
// A flag to indicate which outputs are the top-n choices. Current timestep
|
||||
// only.
|
||||
GenericVector<bool> top_n_flags_;
|
||||
// Heap used to compute the top_n_flags_.
|
||||
GenericHeap<TopPair> top_heap_;
|
||||
// Borrowed pointer to the dictionary to use in the search.
|
||||
Dict* dict_;
|
||||
// True if the language is space-delimited, which is true for most languages
|
||||
// except chi*, jpn, tha.
|
||||
bool space_delimited_;
|
||||
// True if the input is simple text, ie adjacent equal chars are not to be
|
||||
// eliminated.
|
||||
bool is_simple_text_;
|
||||
// The encoded (class label) of the null/reject character.
|
||||
int null_char_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
|
128
lstm/reconfig.cpp
Normal file
128
lstm/reconfig.cpp
Normal file
@ -0,0 +1,128 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: reconfig.cpp
|
||||
// Description: Network layer that reconfigures the scaling vs feature
|
||||
// depth.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Feb 26 15:42:25 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
#include "reconfig.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
Reconfig::Reconfig(const STRING& name, int ni, int x_scale, int y_scale)
|
||||
: Network(NT_RECONFIG, name, ni, ni * x_scale * y_scale),
|
||||
x_scale_(x_scale), y_scale_(y_scale) {
|
||||
}
|
||||
|
||||
Reconfig::~Reconfig() {
|
||||
}
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
StaticShape Reconfig::OutputShape(const StaticShape& input_shape) const {
|
||||
StaticShape result = input_shape;
|
||||
result.set_height(result.height() / y_scale_);
|
||||
result.set_width(result.width() / x_scale_);
|
||||
if (type_ != NT_MAXPOOL)
|
||||
result.set_depth(result.depth() * y_scale_ * x_scale_);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns an integer reduction factor that the network applies to the
|
||||
// time sequence. Assumes that any 2-d is already eliminated. Used for
|
||||
// scaling bounding boxes of truth data.
|
||||
// WARNING: if GlobalMinimax is used to vary the scale, this will return
|
||||
// the last used scale factor. Call it before any forward, and it will return
|
||||
// the minimum scale factor of the paths through the GlobalMinimax.
|
||||
int Reconfig::XScaleFactor() const {
|
||||
return x_scale_;
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Reconfig::Serialize(TFile* fp) const {
|
||||
if (!Network::Serialize(fp)) return false;
|
||||
if (fp->FWrite(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
|
||||
if (fp->FWrite(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool Reconfig::DeSerialize(bool swap, TFile* fp) {
|
||||
if (fp->FRead(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
|
||||
if (fp->FRead(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
|
||||
if (swap) {
|
||||
ReverseN(&x_scale_, sizeof(x_scale_));
|
||||
ReverseN(&y_scale_, sizeof(y_scale_));
|
||||
}
|
||||
no_ = ni_ * x_scale_ * y_scale_;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
void Reconfig::Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output) {
|
||||
output->ResizeScaled(input, x_scale_, y_scale_, no_);
|
||||
back_map_ = input.stride_map();
|
||||
StrideMap::Index dest_index(output->stride_map());
|
||||
do {
|
||||
int out_t = dest_index.t();
|
||||
StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
|
||||
dest_index.index(FD_HEIGHT) * y_scale_,
|
||||
dest_index.index(FD_WIDTH) * x_scale_);
|
||||
// Stack x_scale_ groups of y_scale_ inputs together.
|
||||
for (int x = 0; x < x_scale_; ++x) {
|
||||
for (int y = 0; y < y_scale_; ++y) {
|
||||
StrideMap::Index src_xy(src_index);
|
||||
if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
|
||||
output->CopyTimeStepGeneral(out_t, (x * y_scale_ + y) * ni_, ni_,
|
||||
input, src_xy.t(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
} while (dest_index.Increment());
|
||||
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
bool Reconfig::Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas) {
|
||||
back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
|
||||
StrideMap::Index src_index(fwd_deltas.stride_map());
|
||||
do {
|
||||
int in_t = src_index.t();
|
||||
StrideMap::Index dest_index(back_deltas->stride_map(),
|
||||
src_index.index(FD_BATCH),
|
||||
src_index.index(FD_HEIGHT) * y_scale_,
|
||||
src_index.index(FD_WIDTH) * x_scale_);
|
||||
// Unstack x_scale_ groups of y_scale_ inputs that are together.
|
||||
for (int x = 0; x < x_scale_; ++x) {
|
||||
for (int y = 0; y < y_scale_; ++y) {
|
||||
StrideMap::Index dest_xy(dest_index);
|
||||
if (dest_xy.AddOffset(x, FD_WIDTH) && dest_xy.AddOffset(y, FD_HEIGHT)) {
|
||||
back_deltas->CopyTimeStepGeneral(dest_xy.t(), 0, ni_, fwd_deltas,
|
||||
in_t, (x * y_scale_ + y) * ni_);
|
||||
}
|
||||
}
|
||||
}
|
||||
} while (src_index.Increment());
|
||||
return needs_to_backprop_;
|
||||
}
|
||||
|
||||
|
||||
} // namespace tesseract.
|
86
lstm/reconfig.h
Normal file
86
lstm/reconfig.h
Normal file
@ -0,0 +1,86 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: reconfig.h
|
||||
// Description: Network layer that reconfigures the scaling vs feature
|
||||
// depth.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Feb 26 15:37:42 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
#ifndef TESSERACT_LSTM_RECONFIG_H_
|
||||
#define TESSERACT_LSTM_RECONFIG_H_
|
||||
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "matrix.h"
|
||||
#include "network.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Reconfigures (Shrinks) the inputs by concatenating an x_scale by y_scale tile
|
||||
// of inputs together, producing a single, deeper output per tile.
|
||||
// Note that fractional parts are truncated for efficiency, so make sure the
|
||||
// input stride is a multiple of the y_scale factor!
|
||||
class Reconfig : public Network {
|
||||
public:
|
||||
Reconfig(const STRING& name, int ni, int x_scale, int y_scale);
|
||||
virtual ~Reconfig();
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
|
||||
|
||||
virtual STRING spec() const {
|
||||
STRING spec;
|
||||
spec.add_str_int("S", y_scale_);
|
||||
spec.add_str_int(",", x_scale_);
|
||||
return spec;
|
||||
}
|
||||
|
||||
// Returns an integer reduction factor that the network applies to the
|
||||
// time sequence. Assumes that any 2-d is already eliminated. Used for
|
||||
// scaling bounding boxes of truth data.
|
||||
// WARNING: if GlobalMinimax is used to vary the scale, this will return
|
||||
// the last used scale factor. Call it before any forward, and it will return
|
||||
// the minimum scale factor of the paths through the GlobalMinimax.
|
||||
virtual int XScaleFactor() const;
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
virtual bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
virtual bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual void Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output);
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas);
|
||||
|
||||
protected:
|
||||
// Non-serialized data used to store parameters between forward and back.
|
||||
StrideMap back_map_;
|
||||
// Serialized data.
|
||||
inT32 x_scale_;
|
||||
inT32 y_scale_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
|
||||
#endif // TESSERACT_LSTM_SUBSAMPLE_H_
|
91
lstm/reversed.cpp
Normal file
91
lstm/reversed.cpp
Normal file
@ -0,0 +1,91 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: reversed.cpp
|
||||
// Description: Runs a single network on time-reversed input, reversing output.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:42:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "reversed.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "networkscratch.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
Reversed::Reversed(const STRING& name, NetworkType type) : Plumbing(name) {
|
||||
type_ = type;
|
||||
}
|
||||
Reversed::~Reversed() {
|
||||
}
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
StaticShape Reversed::OutputShape(const StaticShape& input_shape) const {
|
||||
if (type_ == NT_XYTRANSPOSE) {
|
||||
StaticShape x_shape(input_shape);
|
||||
x_shape.set_width(input_shape.height());
|
||||
x_shape.set_height(input_shape.width());
|
||||
x_shape = stack_[0]->OutputShape(x_shape);
|
||||
x_shape.SetShape(x_shape.batch(), x_shape.width(), x_shape.height(),
|
||||
x_shape.depth());
|
||||
return x_shape;
|
||||
}
|
||||
return stack_[0]->OutputShape(input_shape);
|
||||
}
|
||||
|
||||
// Takes ownership of the given network to make it the reversed one.
|
||||
void Reversed::SetNetwork(Network* network) {
|
||||
stack_.clear();
|
||||
AddToStack(network);
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
// The input is reversed/transposed according to type_, fed through the
// wrapped network, and the result reversed/transposed back, so callers see
// data in the original orientation.
void Reversed::Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output) {
  // Scratch buffer sized like the input to hold the reversed copy.
  NetworkScratch::IO rev_input(input, scratch);
  ReverseData(input, rev_input);
  NetworkScratch::IO rev_output(input, scratch);
  // input_transpose is not forwarded (NULL): it presumably describes the
  // un-reversed input and would not match the reversed data -- confirm.
  stack_[0]->Forward(debug, *rev_input, NULL, scratch, rev_output);
  // Undo the reversal on the sub-network's output.
  ReverseData(*rev_output, output);
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// Mirrors Forward: incoming deltas are reversed/transposed into the wrapped
// network's frame, propagated, and the resulting deltas reversed back.
// Returns false (leaving *back_deltas untouched) if the sub-network did not
// produce deltas.
bool Reversed::Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas) {
  NetworkScratch::IO rev_input(fwd_deltas, scratch);
  ReverseData(fwd_deltas, rev_input);
  NetworkScratch::IO rev_output(fwd_deltas, scratch);
  if (stack_[0]->Backward(debug, *rev_input, scratch, rev_output)) {
    ReverseData(*rev_output, back_deltas);
    return true;
  }
  return false;
}
|
||||
|
||||
// Copies src to *dest with the reversal according to type_.
|
||||
void Reversed::ReverseData(const NetworkIO& src, NetworkIO* dest) const {
|
||||
if (type_ == NT_XREVERSED)
|
||||
dest->CopyWithXReversal(src);
|
||||
else if (type_ == NT_YREVERSED)
|
||||
dest->CopyWithYReversal(src);
|
||||
else
|
||||
dest->CopyWithXYTranspose(src);
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
89
lstm/reversed.h
Normal file
89
lstm/reversed.h
Normal file
@ -0,0 +1,89 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: reversed.h
|
||||
// Description: Runs a single network on time-reversed input, reversing output.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:38:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_REVERSED_H_
|
||||
#define TESSERACT_LSTM_REVERSED_H_
|
||||
|
||||
#include "matrix.h"
|
||||
#include "plumbing.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// C++ Implementation of the Reversed class from lstm.py.
|
||||
class Reversed : public Plumbing {
|
||||
public:
|
||||
explicit Reversed(const STRING& name, NetworkType type);
|
||||
virtual ~Reversed();
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
virtual StaticShape OutputShape(const StaticShape& input_shape) const;
|
||||
|
||||
virtual STRING spec() const {
|
||||
STRING spec(type_ == NT_XREVERSED ? "Rx"
|
||||
: (type_ == NT_YREVERSED ? "Ry" : "Txy"));
|
||||
// For most simple cases, we will output Rx<net> or Ry<net> where <net> is
|
||||
// the network in stack_[0], but in the special case that <net> is an
|
||||
// LSTM, we will just output the LSTM's spec modified to take the reversal
|
||||
// into account. This is because when the user specified Lfy64, we actually
|
||||
// generated TxyLfx64, and if the user specified Lrx64 we actually
|
||||
// generated RxLfx64, and we want to display what the user asked for.
|
||||
STRING net_spec = stack_[0]->spec();
|
||||
if (net_spec[0] == 'L') {
|
||||
// Setup a from and to character according to the type of the reversal
|
||||
// such that the LSTM spec gets modified to the spec that the user
|
||||
// asked for
|
||||
char from = 'f';
|
||||
char to = 'r';
|
||||
if (type_ == NT_XYTRANSPOSE) {
|
||||
from = 'x';
|
||||
to = 'y';
|
||||
}
|
||||
// Change the from char to the to char.
|
||||
for (int i = 0; i < net_spec.length(); ++i) {
|
||||
if (net_spec[i] == from) net_spec[i] = to;
|
||||
}
|
||||
return net_spec;
|
||||
}
|
||||
spec += net_spec;
|
||||
return spec;
|
||||
}
|
||||
|
||||
// Takes ownership of the given network to make it the reversed one.
|
||||
void SetNetwork(Network* network);
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual void Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output);
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
NetworkScratch* scratch,
|
||||
NetworkIO* back_deltas);
|
||||
|
||||
private:
|
||||
// Copies src to *dest with the reversal according to type_.
|
||||
void ReverseData(const NetworkIO& src, NetworkIO* dest) const;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_REVERSED_H_
|
188
lstm/series.cpp
Normal file
188
lstm/series.cpp
Normal file
@ -0,0 +1,188 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: series.cpp
|
||||
// Description: Runs networks in series on the same input.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:26:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "series.h"
|
||||
|
||||
#include "fullyconnected.h"
|
||||
#include "networkscratch.h"
|
||||
#include "scrollview.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// ni_ and no_ will be set by AddToStack.
// Constructs an empty series; layers are appended later with AddToStack.
Series::Series(const STRING& name) : Plumbing(name) {
  type_ = NT_SERIES;
}
|
||||
|
||||
// Empty destructor; stacked layers are presumably released by the Plumbing
// base class -- confirm against plumbing.h.
Series::~Series() {
}
|
||||
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
StaticShape Series::OutputShape(const StaticShape& input_shape) const {
|
||||
StaticShape result(input_shape);
|
||||
int stack_size = stack_.size();
|
||||
for (int i = 0; i < stack_size; ++i) {
|
||||
result = stack_[i]->OutputShape(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
|
||||
// scale `range` picked according to the random number generator `randomizer`.
|
||||
// Note that series has its own implementation just for debug purposes.
|
||||
int Series::InitWeights(float range, TRand* randomizer) {
|
||||
num_weights_ = 0;
|
||||
tprintf("Num outputs,weights in serial:\n");
|
||||
for (int i = 0; i < stack_.size(); ++i) {
|
||||
int weights = stack_[i]->InitWeights(range, randomizer);
|
||||
tprintf(" %s:%d, %d\n",
|
||||
stack_[i]->spec().string(), stack_[i]->NumOutputs(), weights);
|
||||
num_weights_ += weights;
|
||||
}
|
||||
tprintf("Total weights = %d\n", num_weights_);
|
||||
return num_weights_;
|
||||
}
|
||||
|
||||
// Sets needs_to_backprop_ to needs_backprop and returns true if
|
||||
// needs_backprop || any weights in this network so the next layer forward
|
||||
// can be told to produce backprop for this layer if needed.
|
||||
bool Series::SetupNeedsBackprop(bool needs_backprop) {
|
||||
needs_to_backprop_ = needs_backprop;
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
needs_backprop = stack_[i]->SetupNeedsBackprop(needs_backprop);
|
||||
return needs_backprop;
|
||||
}
|
||||
|
||||
// Returns an integer reduction factor that the network applies to the
|
||||
// time sequence. Assumes that any 2-d is already eliminated. Used for
|
||||
// scaling bounding boxes of truth data.
|
||||
// WARNING: if GlobalMinimax is used to vary the scale, this will return
|
||||
// the last used scale factor. Call it before any forward, and it will return
|
||||
// the minimum scale factor of the paths through the GlobalMinimax.
|
||||
int Series::XScaleFactor() const {
|
||||
int factor = 1;
|
||||
for (int i = 0; i < stack_.size(); ++i)
|
||||
factor *= stack_[i]->XScaleFactor();
|
||||
return factor;
|
||||
}
|
||||
|
||||
// Provides the (minimum) x scale factor to the network (of interest only to
// input units) so they can determine how to scale bounding boxes.
// Only the first layer receives the raw input, so only it is told.
void Series::CacheXScaleFactor(int factor) {
  stack_[0]->CacheXScaleFactor(factor);
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void Series::Forward(bool debug, const NetworkIO& input,
                     const TransposedArray* input_transpose,
                     NetworkScratch* scratch, NetworkIO* output) {
  int stack_size = stack_.size();
  ASSERT_HOST(stack_size > 1);  // A series of one layer would be pointless.
  // Revolving intermediate buffers: layer outputs ping-pong between buffer1
  // and buffer2 with no copying; the final layer writes straight to *output.
  NetworkScratch::IO buffer1(input, scratch);
  NetworkScratch::IO buffer2(input, scratch);
  // Run each network in turn, giving the output of n as the input to n + 1,
  // with the final network providing the real output.
  // Only the first layer gets input_transpose; later layers operate on
  // already-transformed data, so they receive NULL.
  stack_[0]->Forward(debug, input, input_transpose, scratch, buffer1);
  // The loop handles two layers per iteration so each buffer alternates
  // between producer and consumer roles.
  for (int i = 1; i < stack_size; i += 2) {
    stack_[i]->Forward(debug, *buffer1, NULL, scratch,
                       i + 1 < stack_size ? buffer2 : output);
    if (i + 1 == stack_size) return;
    stack_[i + 1]->Forward(debug, *buffer2, NULL, scratch,
                           i + 2 < stack_size ? buffer1 : output);
  }
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
// Returns false (aborting the whole pass) as soon as any layer is not
// training or declines to produce deltas.
bool Series::Backward(bool debug, const NetworkIO& fwd_deltas,
                      NetworkScratch* scratch,
                      NetworkIO* back_deltas) {
  if (!training()) return false;
  int stack_size = stack_.size();
  ASSERT_HOST(stack_size > 1);
  // Revolving intermediate buffers, as in Forward, walked in reverse.
  NetworkScratch::IO buffer1(fwd_deltas, scratch);
  NetworkScratch::IO buffer2(fwd_deltas, scratch);
  // Run each network in reverse order, giving the back_deltas output of n as
  // the fwd_deltas input to n-1, with the 0 network providing the real output.
  if (!stack_.back()->training() ||
      !stack_.back()->Backward(debug, fwd_deltas, scratch, buffer1))
    return false;
  // Two layers per iteration so buffer1/buffer2 alternate producer/consumer;
  // the layer at index 0 writes straight to *back_deltas.
  for (int i = stack_size - 2; i >= 0; i -= 2) {
    if (!stack_[i]->training() ||
        !stack_[i]->Backward(debug, *buffer1, scratch,
                             i > 0 ? buffer2 : back_deltas))
      return false;
    if (i == 0) return needs_to_backprop_;
    if (!stack_[i - 1]->training() ||
        !stack_[i - 1]->Backward(debug, *buffer2, scratch,
                                 i > 1 ? buffer1 : back_deltas))
      return false;
  }
  return needs_to_backprop_;
}
|
||||
|
||||
// Splits the series after the given index, returning the two parts and
|
||||
// deletes itself. The first part, upto network with index last_start, goes
|
||||
// into start, and the rest goes into end.
|
||||
void Series::SplitAt(int last_start, Series** start, Series** end) {
|
||||
*start = NULL;
|
||||
*end = NULL;
|
||||
if (last_start < 0 || last_start >= stack_.size()) {
|
||||
tprintf("Invalid split index %d must be in range [0,%d]!\n",
|
||||
last_start, stack_.size() - 1);
|
||||
return;
|
||||
}
|
||||
Series* master_series = new Series("MasterSeries");
|
||||
Series* boosted_series = new Series("BoostedSeries");
|
||||
for (int s = 0; s <= last_start; ++s) {
|
||||
if (s + 1 == stack_.size() && stack_[s]->type() == NT_SOFTMAX) {
|
||||
// Change the softmax to a tanh.
|
||||
FullyConnected* fc = reinterpret_cast<FullyConnected*>(stack_[s]);
|
||||
fc->ChangeType(NT_TANH);
|
||||
}
|
||||
master_series->AddToStack(stack_[s]);
|
||||
stack_[s] = NULL;
|
||||
}
|
||||
for (int s = last_start + 1; s < stack_.size(); ++s) {
|
||||
boosted_series->AddToStack(stack_[s]);
|
||||
stack_[s] = NULL;
|
||||
}
|
||||
*start = master_series;
|
||||
*end = boosted_series;
|
||||
delete this;
|
||||
}
|
||||
|
||||
// Appends the elements of the src series to this, removing from src and
|
||||
// deleting it.
|
||||
void Series::AppendSeries(Network* src) {
|
||||
ASSERT_HOST(src->type() == NT_SERIES);
|
||||
Series* src_series = reinterpret_cast<Series*>(src);
|
||||
for (int s = 0; s < src_series->stack_.size(); ++s) {
|
||||
AddToStack(src_series->stack_[s]);
|
||||
src_series->stack_[s] = NULL;
|
||||
}
|
||||
delete src;
|
||||
}
|
||||
|
||||
|
||||
} // namespace tesseract.
|
91
lstm/series.h
Normal file
91
lstm/series.h
Normal file
@ -0,0 +1,91 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: series.h
|
||||
// Description: Runs networks in series on the same input.
|
||||
// Author: Ray Smith
|
||||
// Created: Thu May 02 08:20:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_SERIES_H_
|
||||
#define TESSERACT_LSTM_SERIES_H_
|
||||
|
||||
#include "plumbing.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Runs two or more networks in series (layers) on the same input.
class Series : public Plumbing {
 public:
  // ni_ and no_ will be set by AddToStack.
  explicit Series(const STRING& name);
  virtual ~Series();

  // Returns the shape output from the network given an input shape (which may
  // be partially unknown ie zero).
  virtual StaticShape OutputShape(const StaticShape& input_shape) const;

  // Returns the spec as the concatenation of the layers' specs, wrapped in
  // square brackets.
  virtual STRING spec() const {
    STRING spec("[");
    for (int i = 0; i < stack_.size(); ++i)
      spec += stack_[i]->spec();
    spec += "]";
    return spec;
  }

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand* randomizer);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop || any weights in this network so the next layer forward
  // can be told to produce backprop for this layer if needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns an integer reduction factor that the network applies to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const;

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor(int factor);

  // Runs forward propagation of activations on the input line.
  // See Network for a detailed discussion of the arguments.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output);

  // Runs backward propagation of errors on the deltas line.
  // See Network for a detailed discussion of the arguments.
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas);

  // Splits the series after the given index, returning the two parts and
  // deletes itself. The first part, upto network with index last_start, goes
  // into start, and the rest goes into end.
  void SplitAt(int last_start, Series** start, Series** end);

  // Appends the elements of the src series to this, removing from src and
  // deleting it.
  void AppendSeries(Network* src);
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_SERIES_H_
|
80
lstm/static_shape.h
Normal file
80
lstm/static_shape.h
Normal file
@ -0,0 +1,80 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: static_shape.h
|
||||
// Description: Defines the size of the 4-d tensor input/output from a network.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Oct 14 09:07:31 PST 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
#ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
|
||||
#define TESSERACT_LSTM_STATIC_SHAPE_H_
|
||||
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Enum describing the loss function to apply during training and/or the
// decoding method to apply at runtime.
// Stored in StaticShape::loss_type_ as part of a network's output contract.
enum LossType {
  LT_NONE,      // Undefined.
  LT_CTC,       // Softmax with standard CTC for training/decoding.
  LT_SOFTMAX,   // Outputs sum to 1 in fixed positions.
  LT_LOGISTIC,  // Logistic outputs with independent values.
};
|
||||
|
||||
// Simple class to hold the tensor shape that is known at network build time
// and the LossType of the loss funtion.
// A value of zero in batch/height/width means "determined at runtime";
// see the member comments below.
class StaticShape {
 public:
  // Default shape: everything zero (unknown), no loss type.
  StaticShape()
      : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
  int batch() const { return batch_; }
  void set_batch(int value) { batch_ = value; }
  int height() const { return height_; }
  void set_height(int value) { height_ = value; }
  int width() const { return width_; }
  void set_width(int value) { width_ = value; }
  int depth() const { return depth_; }
  void set_depth(int value) { depth_ = value; }
  LossType loss_type() const { return loss_type_; }
  void set_loss_type(LossType value) { loss_type_ = value; }
  // Sets all four dimensions at once (loss type is unchanged).
  void SetShape(int batch, int height, int width, int depth) {
    batch_ = batch;
    height_ = height;
    width_ = width;
    depth_ = depth;
  }

  // Debug dump to stdout. loss_type_ is an unscoped enum, so it promotes
  // to int for the %d format.
  void Print() const {
    tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_,
            height_, width_, depth_, loss_type_);
  }

 private:
  // Size of the 4-D tensor input/output to a network. A value of zero is
  // allowed for all except depth_ and means to be determined at runtime, and
  // regarded as variable.
  // Number of elements in a batch, or number of frames in a video stream.
  int batch_;
  // Height of the image.
  int height_;
  // Width of the image.
  int width_;
  // Depth of the image. (Number of "nodes").
  int depth_;
  // How to train/interpret the output.
  LossType loss_type_;
};
|
||||
|
||||
} // namespace tesseract
|
||||
|
||||
#endif // TESSERACT_LSTM_STATIC_SHAPE_H_
|
173
lstm/stridemap.cpp
Normal file
173
lstm/stridemap.cpp
Normal file
@ -0,0 +1,173 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: stridemap.cpp
|
||||
// Description: Indexing into a 4-d tensor held in a 2-d Array.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Sep 20 15:30:31 PST 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "stridemap.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Returns true if *this is a valid index.
bool StrideMap::Index::IsValid() const {
  // Cheap check first.
  // The non-negativity pass must fully complete before any call to
  // MaxIndexOfDim: MaxIndexOfDim(d) assumes the dimensions below d are
  // already valid (it indexes heights_/widths_ by the batch index), which
  // is why the two loops are not merged.
  for (int d = 0; d < FD_DIMSIZE; ++d) {
    if (indices_[d] < 0) return false;
  }
  for (int d = 0; d < FD_DIMSIZE; ++d) {
    if (indices_[d] > MaxIndexOfDim(static_cast<FlexDimensions>(d)))
      return false;
  }
  return true;
}
|
||||
|
||||
// Returns true if the index of the given dimension is the last.
|
||||
bool StrideMap::Index::IsLast(FlexDimensions dimension) const {
|
||||
return MaxIndexOfDim(dimension) == indices_[dimension];
|
||||
}
|
||||
|
||||
// Given that the dimensions upto and including dim-1 are valid, returns the
// maximum index for dimension dim.
// For height/width the bound is the per-image size for the current batch
// entry, clamped to the overall (padded) shape bound.
int StrideMap::Index::MaxIndexOfDim(FlexDimensions dim) const {
  // Bound imposed by the overall shape.
  int max_index = stride_map_->shape_[dim] - 1;
  if (dim == FD_BATCH) return max_index;
  int batch = indices_[FD_BATCH];
  if (dim == FD_HEIGHT) {
    // Fall back to the shape bound when the batch index has no stored
    // height, or clamp when the stored height exceeds the shape.
    if (batch >= stride_map_->heights_.size() ||
        stride_map_->heights_[batch] > max_index)
      return max_index;
    return stride_map_->heights_[batch] - 1;
  }
  // FD_WIDTH: same logic using the per-image widths.
  if (batch >= stride_map_->widths_.size() ||
      stride_map_->widths_[batch] > max_index)
    return max_index;
  return stride_map_->widths_[batch] - 1;
}
|
||||
|
||||
// Adds the given offset to the given dimension. Returns true if the result
// makes a valid index.
// Note that indices_ keeps the new value even when the result is invalid,
// and t_ is recomputed from scratch rather than adjusted incrementally.
bool StrideMap::Index::AddOffset(int offset, FlexDimensions dimension) {
  indices_[dimension] += offset;
  SetTFromIndices();
  return IsValid();
}
|
||||
|
||||
// Increments the index in some encapsulated way that guarantees to remain
// valid until it returns false, meaning that the iteration is complete.
// Works like an odometer: the lowest dimension is stepped first, and an
// overflowing dimension resets to 0 and carries into the next one, with t_
// kept in sync throughout.
bool StrideMap::Index::Increment() {
  for (int d = FD_DIMSIZE - 1; d >= 0; --d) {
    if (!IsLast(static_cast<FlexDimensions>(d))) {
      // Room to move in this dimension: take one step and stop.
      t_ += stride_map_->t_increments_[d];
      ++indices_[d];
      return true;
    }
    // Roll this dimension back to 0, removing its contribution to t_.
    t_ -= stride_map_->t_increments_[d] * indices_[d];
    indices_[d] = 0;
    // Now carry to the next dimension.
  }
  return false;  // Wrapped past the last valid index: iteration complete.
}
|
||||
|
||||
// Decrements the index in some encapsulated way that guarantees to remain
// valid until it returns false, meaning that the iteration (that started
// with InitToLast()) is complete.
// Mirror image of Increment: borrows from higher dimensions, setting each
// exhausted dimension back to its maximum.
bool StrideMap::Index::Decrement() {
  for (int d = FD_DIMSIZE - 1; d >= 0; --d) {
    if (indices_[d] > 0) {
      --indices_[d];
      if (d == FD_BATCH) {
        // The upper limits of the other dimensions may have changed as a result
        // of a different batch index, so they have to be reset.
        InitToLastOfBatch(indices_[FD_BATCH]);
      } else {
        t_ -= stride_map_->t_increments_[d];
      }
      return true;
    }
    // This dimension is exhausted: reset it to its maximum (which may depend
    // on the batch entry) and add its contribution back into t_.
    indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
    t_ += stride_map_->t_increments_[d] * indices_[d];
    // Now borrow from the next dimension.
  }
  return false;  // Walked past index 0: iteration complete.
}
|
||||
|
||||
// Initializes the indices to the last valid location in the given batch
|
||||
// index.
|
||||
void StrideMap::Index::InitToLastOfBatch(int batch) {
|
||||
indices_[FD_BATCH] = batch;
|
||||
for (int d = FD_BATCH + 1; d < FD_DIMSIZE; ++d) {
|
||||
indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
|
||||
}
|
||||
SetTFromIndices();
|
||||
}
|
||||
|
||||
// Computes and sets t_ from the current indices_.
|
||||
void StrideMap::Index::SetTFromIndices() {
|
||||
t_ = 0;
|
||||
for (int d = 0; d < FD_DIMSIZE; ++d) {
|
||||
t_ += stride_map_->t_increments_[d] * indices_[d];
|
||||
}
|
||||
}
|
||||
|
||||
// Sets up the stride for the given array of height, width pairs.
|
||||
void StrideMap::SetStride(const std::vector<std::pair<int, int>>& h_w_pairs) {
|
||||
int max_height = 0;
|
||||
int max_width = 0;
|
||||
for (const std::pair<int, int>& hw : h_w_pairs) {
|
||||
int height = hw.first;
|
||||
int width = hw.second;
|
||||
heights_.push_back(height);
|
||||
widths_.push_back(width);
|
||||
if (height > max_height) max_height = height;
|
||||
if (width > max_width) max_width = width;
|
||||
}
|
||||
shape_[FD_BATCH] = heights_.size();
|
||||
shape_[FD_HEIGHT] = max_height;
|
||||
shape_[FD_WIDTH] = max_width;
|
||||
ComputeTIncrements();
|
||||
}
|
||||
|
||||
// Scales width and height dimensions by the given factors.
|
||||
void StrideMap::ScaleXY(int x_factor, int y_factor) {
|
||||
for (int& height : heights_) height /= y_factor;
|
||||
for (int& width : widths_) width /= x_factor;
|
||||
shape_[FD_HEIGHT] /= y_factor;
|
||||
shape_[FD_WIDTH] /= x_factor;
|
||||
ComputeTIncrements();
|
||||
}
|
||||
|
||||
// Reduces width to 1, across the batch, whatever the input size.
|
||||
void StrideMap::ReduceWidthTo1() {
|
||||
widths_.assign(widths_.size(), 1);
|
||||
shape_[FD_WIDTH] = 1;
|
||||
ComputeTIncrements();
|
||||
}
|
||||
|
||||
// Transposes the width and height dimensions.
|
||||
void StrideMap::TransposeXY() {
|
||||
std::swap(shape_[FD_HEIGHT], shape_[FD_WIDTH]);
|
||||
std::swap(heights_, widths_);
|
||||
ComputeTIncrements();
|
||||
}
|
||||
|
||||
// Computes t_increments_ from shape_.
|
||||
void StrideMap::ComputeTIncrements() {
|
||||
t_increments_[FD_DIMSIZE - 1] = 1;
|
||||
for (int d = FD_DIMSIZE - 2; d >= 0; --d) {
|
||||
t_increments_[d] = t_increments_[d + 1] * shape_[d + 1];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tesseract
|
137
lstm/stridemap.h
Normal file
137
lstm/stridemap.h
Normal file
@ -0,0 +1,137 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: stridemap.h
|
||||
// Description: Indexing into a 4-d tensor held in a 2-d Array.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Sep 20 16:00:31 PST 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
#ifndef TESSERACT_LSTM_STRIDEMAP_H_
|
||||
#define TESSERACT_LSTM_STRIDEMAP_H_
|
||||
|
||||
#include <string.h>
|
||||
#include <vector>
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Enum describing the dimensions of the 'Tensor' in a NetworkIO.
|
||||
// A NetworkIO is analogous to a TF Tensor, except that the number of dimensions
|
||||
// is fixed (4), and they always have the same meaning. The underlying
|
||||
// representation is a 2-D array, for which the product batch*height*width
|
||||
// is always dim1 and depth is always dim2. FlexDimensions is used only for
|
||||
// batch, height, width with the StrideMap, and therefore represents the runtime
|
||||
// shape. The build-time shape is defined by StaticShape.
|
||||
// Indexes StrideMap::shape_/t_increments_ and Index::indices_; FD_DIMSIZE
// doubles as the array length.
enum FlexDimensions {
  FD_BATCH,    // Index of multiple images.
  FD_HEIGHT,   // y-coordinate in image.
  FD_WIDTH,    // x-coordinate in image.
  FD_DIMSIZE,  // Number of flexible non-depth dimensions.
};
|
||||
|
||||
// Encapsulation of information relating to the mapping from [batch][y][x] to
|
||||
// the first index into the 2-d array underlying a NetworkIO.
|
||||
class StrideMap {
|
||||
public:
|
||||
// Class holding the non-depth indices.
|
||||
class Index {
|
||||
public:
|
||||
explicit Index(const StrideMap& stride_map) : stride_map_(&stride_map) {
|
||||
InitToFirst();
|
||||
}
|
||||
Index(const StrideMap& stride_map, int batch, int y, int x)
|
||||
: stride_map_(&stride_map) {
|
||||
indices_[FD_BATCH] = batch;
|
||||
indices_[FD_HEIGHT] = y;
|
||||
indices_[FD_WIDTH] = x;
|
||||
SetTFromIndices();
|
||||
}
|
||||
// Accesses the index to the underlying array.
|
||||
int t() const { return t_; }
|
||||
int index(FlexDimensions dimension) const { return indices_[dimension]; }
|
||||
// Initializes the indices to the first valid location.
|
||||
void InitToFirst() {
|
||||
memset(indices_, 0, sizeof(indices_));
|
||||
t_ = 0;
|
||||
}
|
||||
// Initializes the indices to the last valid location.
|
||||
void InitToLast() { InitToLastOfBatch(MaxIndexOfDim(FD_BATCH)); }
|
||||
// Returns true if *this is a valid index.
|
||||
bool IsValid() const;
|
||||
// Returns true if the index of the given dimension is the last.
|
||||
bool IsLast(FlexDimensions dimension) const;
|
||||
// Given that the dimensions upto and including dim-1 are valid, returns the
|
||||
// maximum index for dimension dim.
|
||||
int MaxIndexOfDim(FlexDimensions dim) const;
|
||||
// Adds the given offset to the given dimension. Returns true if the result
|
||||
// makes a valid index.
|
||||
bool AddOffset(int offset, FlexDimensions dimension);
|
||||
// Increments the index in some encapsulated way that guarantees to remain
|
||||
// valid until it returns false, meaning that the iteration is complete.
|
||||
bool Increment();
|
||||
// Decrements the index in some encapsulated way that guarantees to remain
|
||||
// valid until it returns false, meaning that the iteration (that started
|
||||
// with InitToLast()) is complete.
|
||||
bool Decrement();
|
||||
|
||||
private:
|
||||
// Initializes the indices to the last valid location in the given batch
|
||||
// index.
|
||||
void InitToLastOfBatch(int batch);
|
||||
// Computes and sets t_ from the current indices_.
|
||||
void SetTFromIndices();
|
||||
|
||||
// Map into which *this is an index.
|
||||
const StrideMap* stride_map_;
|
||||
// Index to the first dimension of the underlying array.
|
||||
int t_;
|
||||
// Indices into the individual dimensions.
|
||||
int indices_[FD_DIMSIZE];
|
||||
};
|
||||
|
||||
StrideMap() {
|
||||
memset(shape_, 0, sizeof(shape_));
|
||||
memset(t_increments_, 0, sizeof(t_increments_));
|
||||
}
|
||||
// Default copy constructor and operator= are OK to use here!
|
||||
|
||||
// Sets up the stride for the given array of height, width pairs.
|
||||
void SetStride(const std::vector<std::pair<int, int>>& h_w_pairs);
|
||||
// Scales width and height dimensions by the given factors.
|
||||
void ScaleXY(int x_factor, int y_factor);
|
||||
// Reduces width to 1, across the batch, whatever the input size.
|
||||
void ReduceWidthTo1();
|
||||
// Transposes the width and height dimensions.
|
||||
void TransposeXY();
|
||||
// Returns the size of the given dimension.
|
||||
int Size(FlexDimensions dimension) const { return shape_[dimension]; }
|
||||
// Returns the total width required.
|
||||
int Width() const { return t_increments_[FD_BATCH] * shape_[FD_BATCH]; }
|
||||
|
||||
private:
|
||||
// Computes t_increments_ from shape_.
|
||||
void ComputeTIncrements();
|
||||
|
||||
// The size of each non-depth dimension.
|
||||
int shape_[FD_DIMSIZE];
|
||||
// Precomputed 't' increments for each dimension. This is the value of
|
||||
// the given dimension in the packed 3-d array that the shape_ represents.
|
||||
int t_increments_[FD_DIMSIZE];
|
||||
// Vector of size shape_[FD_BATCH] holds the height of each image in a batch.
|
||||
std::vector<int> heights_;
|
||||
// Vector of size shape_[FD_BATCH] holds the width of each image in a batch.
|
||||
std::vector<int> widths_;
|
||||
};
|
||||
|
||||
} // namespace tesseract
|
||||
|
||||
#endif // TESSERACT_LSTM_STRIDEMAP_H_
|
146
lstm/tfnetwork.cpp
Normal file
146
lstm/tfnetwork.cpp
Normal file
@ -0,0 +1,146 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: tfnetwork.h
|
||||
// Description: Encapsulation of an entire tensorflow graph as a
|
||||
// Tesseract Network.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Feb 26 09:35:29 PST 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
#ifdef INCLUDE_TENSORFLOW
|
||||
|
||||
#include "tfnetwork.h"
|
||||
|
||||
#include "allheaders.h"
|
||||
#include "input.h"
|
||||
#include "networkscratch.h"
|
||||
|
||||
using tensorflow::Status;
|
||||
using tensorflow::Tensor;
|
||||
using tensorflow::TensorShape;
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
TFNetwork::TFNetwork(const STRING& name) : Network(NT_TENSORFLOW, name, 0, 0) {}
|
||||
|
||||
TFNetwork::~TFNetwork() {}
|
||||
|
||||
int TFNetwork::InitFromProtoStr(const string& proto_str) {
|
||||
if (!model_proto_.ParseFromString(proto_str)) return 0;
|
||||
return InitFromProto();
|
||||
}
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
// Should be overridden by subclasses, but called by their Serialize.
|
||||
bool TFNetwork::Serialize(TFile* fp) const {
|
||||
if (!Network::Serialize(fp)) return false;
|
||||
string proto_str;
|
||||
model_proto_.SerializeToString(&proto_str);
|
||||
GenericVector<char> data;
|
||||
data.init_to_size(proto_str.size(), 0);
|
||||
memcpy(&data[0], proto_str.data(), proto_str.size());
|
||||
if (!data.Serialize(fp)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
// Should be overridden by subclasses, but NOT called by their DeSerialize.
|
||||
bool TFNetwork::DeSerialize(bool swap, TFile* fp) {
|
||||
GenericVector<char> data;
|
||||
if (!data.DeSerialize(swap, fp)) return false;
|
||||
if (!model_proto_.ParseFromArray(&data[0], data.size())) {
|
||||
return false;
|
||||
}
|
||||
return InitFromProto();
|
||||
}
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
void TFNetwork::Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output) {
|
||||
vector<std::pair<string, Tensor>> tf_inputs;
|
||||
int depth = input_shape_.depth();
|
||||
ASSERT_HOST(depth == input.NumFeatures());
|
||||
// TODO(rays) Allow batching. For now batch_size = 1.
|
||||
const StrideMap& stride_map = input.stride_map();
|
||||
// TF requires a tensor of shape float[batch, height, width, depth].
|
||||
TensorShape shape{1, stride_map.Size(FD_HEIGHT), stride_map.Size(FD_WIDTH),
|
||||
depth};
|
||||
Tensor input_tensor(tensorflow::DT_FLOAT, shape);
|
||||
// The flat() member gives a 1d array, with a data() member to get the data.
|
||||
auto eigen_tensor = input_tensor.flat<float>();
|
||||
memcpy(eigen_tensor.data(), input.f(0),
|
||||
input.Width() * depth * sizeof(input.f(0)[0]));
|
||||
// Add the tensor to the vector of inputs.
|
||||
tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);
|
||||
|
||||
// Provide tensors giving the width and/or height of the image if they are
|
||||
// required. Some tf ops require a separate tensor with knowledge of the
|
||||
// size of the input as they cannot obtain it from the input tensor. This is
|
||||
// usually true in the case of ops that process a batch of variable-sized
|
||||
// objects.
|
||||
if (!model_proto_.image_widths().empty()) {
|
||||
TensorShape size_shape{1};
|
||||
Tensor width_tensor(tensorflow::DT_INT32, size_shape);
|
||||
auto eigen_wtensor = width_tensor.flat<int32>();
|
||||
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
|
||||
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
|
||||
}
|
||||
if (!model_proto_.image_heights().empty()) {
|
||||
TensorShape size_shape{1};
|
||||
Tensor height_tensor(tensorflow::DT_INT32, size_shape);
|
||||
auto eigen_htensor = height_tensor.flat<int32>();
|
||||
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
|
||||
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
|
||||
}
|
||||
vector<string> target_layers = {model_proto_.output_layer()};
|
||||
vector<Tensor> outputs;
|
||||
Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
|
||||
ASSERT_HOST(s.ok());
|
||||
ASSERT_HOST(outputs.size() == 1);
|
||||
const Tensor& output_tensor = outputs[0];
|
||||
// Check the dimensions of the output.
|
||||
ASSERT_HOST(output_tensor.shape().dims() == 2);
|
||||
int output_dim0 = output_tensor.shape().dim_size(0);
|
||||
int output_dim1 = output_tensor.shape().dim_size(1);
|
||||
ASSERT_HOST(output_dim1 == output_shape_.depth());
|
||||
output->Resize2d(false, output_dim0, output_dim1);
|
||||
auto eigen_output = output_tensor.flat<float>();
|
||||
memcpy(output->f(0), eigen_output.data(),
|
||||
output_dim0 * output_dim1 * sizeof(output->f(0)[0]));
|
||||
}
|
||||
|
||||
int TFNetwork::InitFromProto() {
|
||||
spec_ = model_proto_.spec();
|
||||
input_shape_.SetShape(
|
||||
model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
|
||||
std::max(0, model_proto_.x_size()), model_proto_.depth());
|
||||
output_shape_.SetShape(model_proto_.batch_size(), 1, 0,
|
||||
model_proto_.num_classes());
|
||||
output_shape_.set_loss_type(model_proto_.using_ctc() ? LT_CTC : LT_SOFTMAX);
|
||||
ni_ = input_shape_.height();
|
||||
no_ = output_shape_.depth();
|
||||
// Initialize the session_ with the graph. Since we can't get the graph
|
||||
// back from the session_, we have to keep the proto as well
|
||||
tensorflow::SessionOptions options;
|
||||
session_.reset(NewSession(options));
|
||||
Status s = session_->Create(model_proto_.graph());
|
||||
if (s.ok()) return model_proto_.global_step();
|
||||
tprintf("Session_->Create returned '%s'\n", s.error_message().c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace tesseract
|
||||
|
||||
#endif // ifdef INCLUDE_TENSORFLOW
|
91
lstm/tfnetwork.h
Normal file
91
lstm/tfnetwork.h
Normal file
@ -0,0 +1,91 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: tfnetwork.h
|
||||
// Description: Encapsulation of an entire tensorflow graph as a
|
||||
// Tesseract Network.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri Feb 26 09:35:29 PST 2016
|
||||
//
|
||||
// (C) Copyright 2016, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_TFNETWORK_H_
|
||||
#define TESSERACT_LSTM_TFNETWORK_H_
|
||||
|
||||
#ifdef INCLUDE_TENSORFLOW
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "network.h"
|
||||
#include "static_shape.h"
|
||||
#include "tfnetwork.proto.h"
|
||||
#include "third_party/tensorflow/core/framework/graph.pb.h"
|
||||
#include "third_party/tensorflow/core/public/session.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
class TFNetwork : public Network {
|
||||
public:
|
||||
explicit TFNetwork(const STRING& name);
|
||||
virtual ~TFNetwork();
|
||||
|
||||
// Returns the required shape input to the network.
|
||||
virtual StaticShape InputShape() const { return input_shape_; }
|
||||
// Returns the shape output from the network given an input shape (which may
|
||||
// be partially unknown ie zero).
|
||||
virtual StaticShape OutputShape(const StaticShape& input_shape) const {
|
||||
return output_shape_;
|
||||
}
|
||||
|
||||
virtual STRING spec() const { return spec_.c_str(); }
|
||||
|
||||
// Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
|
||||
// otherwise the global step of the serialized graph.
|
||||
int InitFromProtoStr(const string& proto_str);
|
||||
// The number of classes in this network should be equal to those in the
|
||||
// recoder_ in LSTMRecognizer.
|
||||
int num_classes() const { return output_shape_.depth(); }
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
// Should be overridden by subclasses, but called by their Serialize.
|
||||
virtual bool Serialize(TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
// Should be overridden by subclasses, but NOT called by their DeSerialize.
|
||||
virtual bool DeSerialize(bool swap, TFile* fp);
|
||||
|
||||
// Runs forward propagation of activations on the input line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
virtual void Forward(bool debug, const NetworkIO& input,
|
||||
const TransposedArray* input_transpose,
|
||||
NetworkScratch* scratch, NetworkIO* output);
|
||||
|
||||
private:
|
||||
int InitFromProto();
|
||||
|
||||
// The original network definition for reference.
|
||||
string spec_;
|
||||
// Input tensor parameters.
|
||||
StaticShape input_shape_;
|
||||
// Output tensor parameters.
|
||||
StaticShape output_shape_;
|
||||
// The tensor flow graph is contained in here.
|
||||
std::unique_ptr<tensorflow::Session> session_;
|
||||
// The serialized graph is also contained in here.
|
||||
TFNetworkModel model_proto_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // ifdef INCLUDE_TENSORFLOW
|
||||
|
||||
#endif // TESSERACT_TENSORFLOW_TFNETWORK_H_
|
61
lstm/tfnetwork.proto
Normal file
61
lstm/tfnetwork.proto
Normal file
@ -0,0 +1,61 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tesseract;
|
||||
|
||||
// TODO(rays) How to make this usable both in Google and open source?
|
||||
import "third_party/tensorflow/core/framework/graph.proto";
|
||||
|
||||
// This proto is the interface between a python TF graph builder/trainer and
|
||||
// the C++ world. The writer of this proto must provide fields as documented
|
||||
// by the comments below.
|
||||
// The graph must have a placeholder for NetworkIO, Widths and Heights. The
|
||||
// following python code creates the appropriate placeholders:
|
||||
//
|
||||
// input_layer = tf.placeholder(tf.float32,
|
||||
// shape=[batch_size, xsize, ysize, depth_dim],
|
||||
// name='NetworkIO')
|
||||
// widths = tf.placeholder(tf.int32, shape=[batch_size], name='Widths')
|
||||
// heights = tf.placeholder(tf.int32, shape=[batch_size], name='Heights')
|
||||
// # Flip x and y to the TF convention.
|
||||
// input_layer = tf.transpose(input_layer, [0, 2, 1, 3])
|
||||
//
|
||||
// The widths and heights will be set to indicate the post-scaling size of the
|
||||
// input image(s).
|
||||
// For now batch_size is ignored and set to 1.
|
||||
// The graph should return a 2-dimensional float32 tensor called 'softmax' of
|
||||
// shape [sequence_length, num_classes], where sequence_length is allowed to
|
||||
// be variable, given by the tensor itself.
|
||||
// TODO(rays) determine whether it is worth providing for batch_size >1 and if
|
||||
// so, how.
|
||||
message TFNetworkModel {
|
||||
// The TF graph definition. Required.
|
||||
tensorflow.GraphDef graph = 1;
|
||||
// The training index. Required to be > 0.
|
||||
int64 global_step = 2;
|
||||
// The original network definition for reference. Optional
|
||||
string spec = 3;
|
||||
// Input tensor parameters.
|
||||
// Values per pixel. Required to be 1 or 3. Inputs assumed to be float32.
|
||||
int32 depth = 4;
|
||||
// Image size. Required. Zero implies flexible sizes, fixed if non-zero.
|
||||
// If x_size > 0, images will be cropped/padded to the given size, after
|
||||
// any scaling required by the y_size.
|
||||
// If y_size > 0, images will be scaled isotropically to the given height.
|
||||
int32 x_size = 5;
|
||||
int32 y_size = 6;
|
||||
// Number of images in a batch. Optional.
|
||||
int32 batch_size = 8;
|
||||
// Output tensor parameters.
|
||||
// Number of output classes. Required to match the depth of the softmax.
|
||||
int32 num_classes = 9;
|
||||
// True if this network needs CTC-like decoding, dropping duplicated labels.
|
||||
// The decoder always drops the null character.
|
||||
bool using_ctc = 10;
|
||||
// Name of input image tensor.
|
||||
string image_input = 11;
|
||||
// Name of image height and width tensors.
|
||||
string image_widths = 12;
|
||||
string image_heights = 13;
|
||||
// Name of output (softmax) tensor.
|
||||
string output_layer = 14;
|
||||
}
|
443
lstm/weightmatrix.cpp
Normal file
443
lstm/weightmatrix.cpp
Normal file
@ -0,0 +1,443 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: weightmatrix.h
|
||||
// Description: Hides distinction between float/int implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Jun 17 11:46:20 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "weightmatrix.h"
|
||||
|
||||
#undef NONX86_BUILD
|
||||
#if defined(ANDROID_BUILD) or defined(__PPC__) or defined(_ARCH_PPC64)
|
||||
#define NONX86_BUILD 1
|
||||
#endif
|
||||
|
||||
#ifndef NONX86_BUILD
|
||||
#include <cpuid.h>
|
||||
#endif
|
||||
#include "dotproductavx.h"
|
||||
#include "dotproductsse.h"
|
||||
#include "statistc.h"
|
||||
#include "svutil.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Architecture detector. Add code here to detect any other architectures for
|
||||
// SIMD-based faster dot product functions. Intended to be a single static
|
||||
// object, but it does no real harm to have more than one.
|
||||
class SIMDDetect {
|
||||
public:
|
||||
SIMDDetect()
|
||||
: arch_tested_(false), avx_available_(false), sse_available_(false) {}
|
||||
|
||||
// Returns true if AVX is available on this system.
|
||||
bool IsAVXAvailable() {
|
||||
if (!arch_tested_) TestArchitecture();
|
||||
return avx_available_;
|
||||
}
|
||||
// Returns true if SSE4.1 is available on this system.
|
||||
bool IsSSEAvailable() {
|
||||
if (!arch_tested_) TestArchitecture();
|
||||
return sse_available_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Tests the architecture in a system-dependent way to detect AVX, SSE and
|
||||
// any other available SIMD equipment.
|
||||
void TestArchitecture() {
|
||||
SVAutoLock lock(&arch_mutex_);
|
||||
if (arch_tested_) return;
|
||||
#if defined(__linux__) && !defined(NONX86_BUILD)
|
||||
if (__get_cpuid_max(0, NULL) >= 1) {
|
||||
unsigned int eax, ebx, ecx, edx;
|
||||
__get_cpuid(1, &eax, &ebx, &ecx, &edx);
|
||||
sse_available_ = (ecx & 0x00080000) != 0;
|
||||
avx_available_ = (ecx & 0x10000000) != 0;
|
||||
}
|
||||
#endif
|
||||
if (avx_available_) tprintf("Found AVX\n");
|
||||
if (sse_available_) tprintf("Found SSE\n");
|
||||
arch_tested_ = true;
|
||||
}
|
||||
|
||||
private:
|
||||
// Detect architecture in only a single thread.
|
||||
SVMutex arch_mutex_;
|
||||
// Flag set to true after TestArchitecture has been called.
|
||||
bool arch_tested_;
|
||||
// If true, then AVX has been detected.
|
||||
bool avx_available_;
|
||||
// If true, then SSe4.1 has been detected.
|
||||
bool sse_available_;
|
||||
};
|
||||
|
||||
static SIMDDetect detector;
|
||||
|
||||
// Copies the whole input transposed, converted to double, into *this.
|
||||
void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
|
||||
int width = input.dim1();
|
||||
int num_features = input.dim2();
|
||||
ResizeNoInit(num_features, width);
|
||||
for (int t = 0; t < width; ++t) WriteStrided(t, input[t]);
|
||||
}
|
||||
|
||||
// Sets up the network for training. Initializes weights using weights of
|
||||
// scale `range` picked according to the random number generator `randomizer`.
|
||||
int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
|
||||
float weight_range, TRand* randomizer) {
|
||||
int_mode_ = false;
|
||||
use_ada_grad_ = ada_grad;
|
||||
if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
|
||||
wf_.Resize(no, ni, 0.0);
|
||||
if (randomizer != NULL) {
|
||||
for (int i = 0; i < no; ++i) {
|
||||
for (int j = 0; j < ni; ++j) {
|
||||
wf_[i][j] = randomizer->SignedRand(weight_range);
|
||||
}
|
||||
}
|
||||
}
|
||||
InitBackward();
|
||||
return ni * no;
|
||||
}
|
||||
|
||||
// Converts a float network to an int network. Each set of input weights that
|
||||
// corresponds to a single output weight is converted independently:
|
||||
// Compute the max absolute value of the weight set.
|
||||
// Scale so the max absolute value becomes MAX_INT8.
|
||||
// Round to integer.
|
||||
// Store a multiplicative scale factor (as a double) that will reproduce
|
||||
// the original value, subject to rounding errors.
|
||||
void WeightMatrix::ConvertToInt() {
|
||||
wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
|
||||
scales_.init_to_size(wi_.dim1(), 0.0);
|
||||
int dim2 = wi_.dim2();
|
||||
for (int t = 0; t < wi_.dim1(); ++t) {
|
||||
double* f_line = wf_[t];
|
||||
inT8* i_line = wi_[t];
|
||||
double max_abs = 0.0;
|
||||
for (int f = 0; f < dim2; ++f) {
|
||||
double abs_val = fabs(f_line[f]);
|
||||
if (abs_val > max_abs) max_abs = abs_val;
|
||||
}
|
||||
double scale = max_abs / MAX_INT8;
|
||||
scales_[t] = scale;
|
||||
if (scale == 0.0) scale = 1.0;
|
||||
for (int f = 0; f < dim2; ++f) {
|
||||
i_line[f] = IntCastRounded(f_line[f] / scale);
|
||||
}
|
||||
}
|
||||
wf_.Resize(1, 1, 0.0);
|
||||
int_mode_ = true;
|
||||
}
|
||||
|
||||
// Allocates any needed memory for running Backward, and zeroes the deltas,
|
||||
// thus eliminating any existing momentum.
|
||||
void WeightMatrix::InitBackward() {
|
||||
int no = int_mode_ ? wi_.dim1() : wf_.dim1();
|
||||
int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
|
||||
dw_.Resize(no, ni, 0.0);
|
||||
updates_.Resize(no, ni, 0.0);
|
||||
wf_t_.Transpose(wf_);
|
||||
}
|
||||
|
||||
// Flag on mode to indicate that this weightmatrix uses inT8.
|
||||
const int kInt8Flag = 1;
|
||||
// Flag on mode to indicate that this weightmatrix uses ada grad.
|
||||
const int kAdaGradFlag = 4;
|
||||
// Flag on mode to indicate that this weightmatrix uses double. Set
|
||||
// independently of kInt8Flag as even in int mode the scales can
|
||||
// be float or double.
|
||||
const int kDoubleFlag = 128;
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool WeightMatrix::Serialize(bool training, TFile* fp) const {
|
||||
// For backward compatability, add kDoubleFlag to mode to indicate the doubles
|
||||
// format, without errs, so we can detect and read old format weight matrices.
|
||||
uinT8 mode = (int_mode_ ? kInt8Flag : 0) |
|
||||
(use_ada_grad_ ? kAdaGradFlag : 0) | kDoubleFlag;
|
||||
if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
|
||||
if (int_mode_) {
|
||||
if (!wi_.Serialize(fp)) return false;
|
||||
if (!scales_.Serialize(fp)) return false;
|
||||
} else {
|
||||
if (!wf_.Serialize(fp)) return false;
|
||||
if (training && !updates_.Serialize(fp)) return false;
|
||||
if (training && use_ada_grad_ && !dw_sq_sum_.Serialize(fp)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool WeightMatrix::DeSerialize(bool training, bool swap, TFile* fp) {
|
||||
uinT8 mode = 0;
|
||||
if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
|
||||
int_mode_ = (mode & kInt8Flag) != 0;
|
||||
use_ada_grad_ = (mode & kAdaGradFlag) != 0;
|
||||
if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, swap, fp);
|
||||
if (int_mode_) {
|
||||
if (!wi_.DeSerialize(swap, fp)) return false;
|
||||
if (!scales_.DeSerialize(swap, fp)) return false;
|
||||
} else {
|
||||
if (!wf_.DeSerialize(swap, fp)) return false;
|
||||
if (training) {
|
||||
InitBackward();
|
||||
if (!updates_.DeSerialize(swap, fp)) return false;
|
||||
if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(swap, fp)) return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// As DeSerialize, but reads an old (float) format WeightMatrix for
|
||||
// backward compatability.
|
||||
bool WeightMatrix::DeSerializeOld(bool training, bool swap, TFile* fp) {
|
||||
GENERIC_2D_ARRAY<float> float_array;
|
||||
if (int_mode_) {
|
||||
if (!wi_.DeSerialize(swap, fp)) return false;
|
||||
GenericVector<float> old_scales;
|
||||
if (!old_scales.DeSerialize(swap, fp)) return false;
|
||||
scales_.init_to_size(old_scales.size(), 0.0);
|
||||
for (int i = 0; i < old_scales.size(); ++i) scales_[i] = old_scales[i];
|
||||
} else {
|
||||
if (!float_array.DeSerialize(swap, fp)) return false;
|
||||
FloatToDouble(float_array, &wf_);
|
||||
}
|
||||
if (training) {
|
||||
InitBackward();
|
||||
if (!float_array.DeSerialize(swap, fp)) return false;
|
||||
FloatToDouble(float_array, &updates_);
|
||||
// Errs was only used in int training, which is now dead.
|
||||
if (!float_array.DeSerialize(swap, fp)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Computes matrix.vector v = Wu.
|
||||
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
|
||||
// u is imagined to have an extra element at the end with value 1, to
|
||||
// implement the bias, but it doesn't actually have it.
|
||||
// Asserts that the call matches what we have.
|
||||
void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
MatrixDotVectorInternal(wf_, true, false, u, v);
|
||||
}
|
||||
|
||||
void WeightMatrix::MatrixDotVector(const inT8* u, double* v) const {
|
||||
ASSERT_HOST(int_mode_);
|
||||
int num_out = wi_.dim1();
|
||||
int num_in = wi_.dim2() - 1;
|
||||
for (int i = 0; i < num_out; ++i) {
|
||||
const inT8* Wi = wi_[i];
|
||||
int total = 0;
|
||||
if (detector.IsSSEAvailable()) {
|
||||
total = IntDotProductSSE(u, Wi, num_in);
|
||||
} else {
|
||||
for (int j = 0; j < num_in; ++j) total += Wi[j] * u[j];
|
||||
}
|
||||
// Add in the bias and correct for integer values.
|
||||
v[i] = (static_cast<double>(total) / MAX_INT8 + Wi[num_in]) * scales_[i];
|
||||
}
|
||||
}
|
||||
|
||||
// MatrixDotVector for peep weights, MultiplyAccumulate adds the
|
||||
// component-wise products of *this[0] and v to inout.
|
||||
void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
ASSERT_HOST(wf_.dim1() == 1);
|
||||
int n = wf_.dim2();
|
||||
const double* u = wf_[0];
|
||||
for (int i = 0; i < n; ++i) {
|
||||
inout[i] += u[i] * v[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Computes vector.matrix v = uW.
|
||||
// u is of size W.dim1() and the output v is of size W.dim2() - 1.
|
||||
// The last result is discarded, as v is assumed to have an imaginary
|
||||
// last value of 1, as with MatrixDotVector.
|
||||
void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
MatrixDotVectorInternal(wf_t_, false, true, u, v);
|
||||
}
|
||||
|
||||
// Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements from
|
||||
// u and v. In terms of the neural network, u is the gradients and v is the
|
||||
// inputs.
|
||||
// Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
|
||||
// Runs parallel if requested. Note that u and v must be transposed.
|
||||
void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
|
||||
const TransposedArray& v,
|
||||
bool in_parallel) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
int num_outputs = dw_.dim1();
|
||||
ASSERT_HOST(u.dim1() == num_outputs);
|
||||
ASSERT_HOST(u.dim2() == v.dim2());
|
||||
int num_inputs = dw_.dim2() - 1;
|
||||
int num_samples = u.dim2();
|
||||
// v is missing the last element in dim1.
|
||||
ASSERT_HOST(v.dim1() == num_inputs);
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for num_threads(4) if (in_parallel)
|
||||
#endif
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
double* dwi = dw_[i];
|
||||
const double* ui = u[i];
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
dwi[j] = DotProduct(ui, v[j], num_samples);
|
||||
}
|
||||
// The last element of v is missing, presumed 1.0f.
|
||||
double total = 0.0;
|
||||
for (int k = 0; k < num_samples; ++k) total += ui[k];
|
||||
dwi[num_inputs] = total;
|
||||
}
|
||||
}
|
||||
|
||||
// Updates the weights using the given learning rate and momentum.
|
||||
// num_samples is the quotient to be used in the adagrad computation iff
|
||||
// use_ada_grad_ is true.
|
||||
void WeightMatrix::Update(double learning_rate, double momentum,
|
||||
int num_samples) {
|
||||
ASSERT_HOST(!int_mode_);
|
||||
if (use_ada_grad_ && num_samples > 0) {
|
||||
dw_sq_sum_.SumSquares(dw_);
|
||||
dw_.AdaGradScaling(dw_sq_sum_, num_samples);
|
||||
}
|
||||
dw_ *= learning_rate;
|
||||
updates_ += dw_;
|
||||
if (momentum > 0.0) wf_ += updates_;
|
||||
if (momentum >= 0.0) updates_ *= momentum;
|
||||
wf_t_.Transpose(wf_);
|
||||
}
|
||||
|
||||
// Adds the dw_ in other to the dw_ is *this.
|
||||
void WeightMatrix::AddDeltas(const WeightMatrix& other) {
|
||||
ASSERT_HOST(dw_.dim1() == other.dw_.dim1());
|
||||
ASSERT_HOST(dw_.dim2() == other.dw_.dim2());
|
||||
dw_ += other.dw_;
|
||||
}
|
||||
|
||||
// Sums the products of weight updates in *this and other, splitting into
|
||||
// positive (same direction) in *same and negative (different direction) in
|
||||
// *changed.
|
||||
void WeightMatrix::CountAlternators(const WeightMatrix& other, double* same,
|
||||
double* changed) const {
|
||||
int num_outputs = updates_.dim1();
|
||||
int num_inputs = updates_.dim2();
|
||||
ASSERT_HOST(num_outputs == other.updates_.dim1());
|
||||
ASSERT_HOST(num_inputs == other.updates_.dim2());
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
const double* this_i = updates_[i];
|
||||
const double* other_i = other.updates_[i];
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
double product = this_i[j] * other_i[j];
|
||||
if (product < 0.0)
|
||||
*changed -= product;
|
||||
else
|
||||
*same += product;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper computes an integer histogram bucket for a weight and adds it
|
||||
// to the histogram.
|
||||
const int kHistogramBuckets = 16;
|
||||
static void HistogramWeight(double weight, STATS* histogram) {
|
||||
int bucket = kHistogramBuckets - 1;
|
||||
if (weight != 0.0) {
|
||||
double logval = -log2(fabs(weight));
|
||||
bucket = ClipToRange(IntCastRounded(logval), 0, kHistogramBuckets - 1);
|
||||
}
|
||||
histogram->add(bucket, 1);
|
||||
}
|
||||
|
||||
void WeightMatrix::Debug2D(const char* msg) {
|
||||
STATS histogram(0, kHistogramBuckets);
|
||||
if (int_mode_) {
|
||||
for (int i = 0; i < wi_.dim1(); ++i) {
|
||||
for (int j = 0; j < wi_.dim2(); ++j) {
|
||||
HistogramWeight(wi_[i][j] * scales_[i], &histogram);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < wf_.dim1(); ++i) {
|
||||
for (int j = 0; j < wf_.dim2(); ++j) {
|
||||
HistogramWeight(wf_[i][j], &histogram);
|
||||
}
|
||||
}
|
||||
}
|
||||
tprintf("%s\n", msg);
|
||||
histogram.print();
|
||||
}
|
||||
|
||||
// Computes and returns the dot product of the two n-vectors u and v.
|
||||
/* static */
|
||||
double WeightMatrix::DotProduct(const double* u, const double* v, int n) {
|
||||
// Note: because the order of addition is different among the 3 DotProduct
|
||||
// functions, the results can (and do) vary slightly (although they agree
|
||||
// to within about 4e-15). This produces different results when running
|
||||
// training, despite all random inputs being precisely equal.
|
||||
// To get consistent results, use just one of these DotProduct functions.
|
||||
// On a test multi-layer network, serial is 57% slower than sse, and avx
|
||||
// is about 8% faster than sse. This suggests that the time is memory
|
||||
// bandwidth constrained and could benefit from holding the reused vector
|
||||
// in AVX registers.
|
||||
if (detector.IsAVXAvailable()) return DotProductAVX(u, v, n);
|
||||
if (detector.IsSSEAvailable()) return DotProductSSE(u, v, n);
|
||||
double total = 0.0;
|
||||
for (int k = 0; k < n; ++k) total += u[k] * v[k];
|
||||
return total;
|
||||
}
|
||||
|
||||
// Utility function converts an array of float to the corresponding array
|
||||
// of double.
|
||||
/* static */
|
||||
void WeightMatrix::FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
|
||||
GENERIC_2D_ARRAY<double>* wd) {
|
||||
int dim1 = wf.dim1();
|
||||
int dim2 = wf.dim2();
|
||||
wd->ResizeNoInit(dim1, dim2);
|
||||
for (int i = 0; i < dim1; ++i) {
|
||||
const float* wfi = wf[i];
|
||||
double* wdi = (*wd)[i];
|
||||
for (int j = 0; j < dim2; ++j) wdi[j] = static_cast<double>(wfi[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes matrix.vector v = Wu.
|
||||
// u is of size W.dim2() - add_bias_fwd and the output v is of size
|
||||
// W.dim1() - skip_bias_back.
|
||||
// If add_bias_fwd, u is imagined to have an extra element at the end with value
|
||||
// 1, to implement the bias, weight.
|
||||
// If skip_bias_back, we are actullay performing the backwards product on a
|
||||
// transposed matrix, so we need to drop the v output corresponding to the last
|
||||
// element in dim1.
|
||||
void WeightMatrix::MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
|
||||
bool add_bias_fwd,
|
||||
bool skip_bias_back, const double* u,
|
||||
double* v) {
|
||||
int num_results = w.dim1() - skip_bias_back;
|
||||
int extent = w.dim2() - add_bias_fwd;
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
const double* wi = w[i];
|
||||
double total = DotProduct(wi, u, extent);
|
||||
if (add_bias_fwd) total += wi[extent]; // The bias value.
|
||||
v[i] = total;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#undef NONX86_BUILD
|
183
lstm/weightmatrix.h
Normal file
183
lstm/weightmatrix.h
Normal file
@ -0,0 +1,183 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: weightmatrix.h
|
||||
// Description: Hides distinction between float/int implementations.
|
||||
// Author: Ray Smith
|
||||
// Created: Tue Jun 17 09:05:39 PST 2014
|
||||
//
|
||||
// (C) Copyright 2014, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
|
||||
#define TESSERACT_LSTM_WEIGHTMATRIX_H_
|
||||
|
||||
#include "genericvector.h"
|
||||
#include "matrix.h"
|
||||
#include "tprintf.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// Convenience instantiation of GENERIC_2D_ARRAY<double> with additional
|
||||
// operations to write a strided vector, so the transposed form of the input
|
||||
// is memory-contiguous.
|
||||
class TransposedArray : public GENERIC_2D_ARRAY<double> {
|
||||
public:
|
||||
// Copies the whole input transposed, converted to double, into *this.
|
||||
void Transpose(const GENERIC_2D_ARRAY<double>& input);
|
||||
// Writes a vector of data representing a timestep (gradients or sources).
|
||||
// The data is assumed to be of size1 in size (the strided dimension).
|
||||
void WriteStrided(int t, const float* data) {
|
||||
int size1 = dim1();
|
||||
for (int i = 0; i < size1; ++i) put(i, t, data[i]);
|
||||
}
|
||||
void WriteStrided(int t, const double* data) {
|
||||
int size1 = dim1();
|
||||
for (int i = 0; i < size1; ++i) put(i, t, data[i]);
|
||||
}
|
||||
// Prints the first and last num elements of the un-transposed array.
|
||||
void PrintUnTransposed(int num) {
|
||||
int num_features = dim1();
|
||||
int width = dim2();
|
||||
for (int y = 0; y < num_features; ++y) {
|
||||
for (int t = 0; t < width; ++t) {
|
||||
if (num == 0 || t < num || t + num >= width) {
|
||||
tprintf(" %g", (*this)(y, t));
|
||||
}
|
||||
}
|
||||
tprintf("\n");
|
||||
}
|
||||
}
|
||||
}; // class TransposedArray
|
||||
|
||||
// Generic weight matrix for network layers. Can store the matrix as either
|
||||
// an array of floats or inT8. Provides functions to compute the forward and
|
||||
// backward steps with the matrix and updates to the weights.
|
||||
class WeightMatrix {
|
||||
public:
|
||||
WeightMatrix() : int_mode_(false), use_ada_grad_(false) {}
|
||||
// Sets up the network for training. Initializes weights using weights of
|
||||
// scale `range` picked according to the random number generator `randomizer`.
|
||||
// Note the order is outputs, inputs, as this is the order of indices to
|
||||
// the matrix, so the adjacent elements are multiplied by the input during
|
||||
// a forward operation.
|
||||
int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range,
|
||||
TRand* randomizer);
|
||||
|
||||
// Converts a float network to an int network. Each set of input weights that
|
||||
// corresponds to a single output weight is converted independently:
|
||||
// Compute the max absolute value of the weight set.
|
||||
// Scale so the max absolute value becomes MAX_INT8.
|
||||
// Round to integer.
|
||||
// Store a multiplicative scale factor (as a float) that will reproduce
|
||||
// the original value, subject to rounding errors.
|
||||
void ConvertToInt();
|
||||
|
||||
// Accessors.
|
||||
bool is_int_mode() const {
|
||||
return int_mode_;
|
||||
}
|
||||
int NumOutputs() const { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
|
||||
// Provides one set of weights. Only used by peep weight maxpool.
|
||||
const double* GetWeights(int index) const { return wf_[index]; }
|
||||
// Provides access to the deltas (dw_).
|
||||
double GetDW(int i, int j) const { return dw_(i, j); }
|
||||
|
||||
// Allocates any needed memory for running Backward, and zeroes the deltas,
|
||||
// thus eliminating any existing momentum.
|
||||
void InitBackward();
|
||||
|
||||
// Writes to the given file. Returns false in case of error.
|
||||
bool Serialize(bool training, TFile* fp) const;
|
||||
// Reads from the given file. Returns false in case of error.
|
||||
// If swap is true, assumes a big/little-endian swap is needed.
|
||||
bool DeSerialize(bool training, bool swap, TFile* fp);
|
||||
// As DeSerialize, but reads an old (float) format WeightMatrix for
|
||||
// backward compatability.
|
||||
bool DeSerializeOld(bool training, bool swap, TFile* fp);
|
||||
|
||||
// Computes matrix.vector v = Wu.
|
||||
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
|
||||
// u is imagined to have an extra element at the end with value 1, to
|
||||
// implement the bias, but it doesn't actually have it.
|
||||
// Asserts that the call matches what we have.
|
||||
void MatrixDotVector(const double* u, double* v) const;
|
||||
void MatrixDotVector(const inT8* u, double* v) const;
|
||||
// MatrixDotVector for peep weights, MultiplyAccumulate adds the
|
||||
// component-wise products of *this[0] and v to inout.
|
||||
void MultiplyAccumulate(const double* v, double* inout);
|
||||
// Computes vector.matrix v = uW.
|
||||
// u is of size W.dim1() and the output v is of size W.dim2() - 1.
|
||||
// The last result is discarded, as v is assumed to have an imaginary
|
||||
// last value of 1, as with MatrixDotVector.
|
||||
void VectorDotMatrix(const double* u, double* v) const;
|
||||
// Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
|
||||
// from u and v, starting with u[i][offset] and v[j][offset].
|
||||
// Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
|
||||
// Runs parallel if requested. Note that inputs must be transposed.
|
||||
void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
|
||||
bool parallel);
|
||||
// Updates the weights using the given learning rate and momentum.
|
||||
// num_samples is the quotient to be used in the adagrad computation iff
|
||||
// use_ada_grad_ is true.
|
||||
void Update(double learning_rate, double momentum, int num_samples);
|
||||
// Adds the dw_ in other to the dw_ is *this.
|
||||
void AddDeltas(const WeightMatrix& other);
|
||||
// Sums the products of weight updates in *this and other, splitting into
|
||||
// positive (same direction) in *same and negative (different direction) in
|
||||
// *changed.
|
||||
void CountAlternators(const WeightMatrix& other, double* same,
|
||||
double* changed) const;
|
||||
|
||||
void Debug2D(const char* msg);
|
||||
|
||||
// Computes and returns the dot product of the two n-vectors u and v.
|
||||
static double DotProduct(const double* u, const double* v, int n);
|
||||
// Utility function converts an array of float to the corresponding array
|
||||
// of double.
|
||||
static void FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
|
||||
GENERIC_2D_ARRAY<double>* wd);
|
||||
|
||||
private:
|
||||
// Computes matrix.vector v = Wu.
|
||||
// u is of size starts.back()+extents.back() and the output v is of size
|
||||
// starts.size().
|
||||
// The weight matrix w, is of size starts.size()xMAX(extents)+add_bias_fwd.
|
||||
// If add_bias_fwd, an extra element at the end of w[i] is the bias weight
|
||||
// and is added to v[i].
|
||||
static void MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
|
||||
bool add_bias_fwd, bool skip_bias_back,
|
||||
const double* u, double* v);
|
||||
|
||||
private:
|
||||
// Choice between float and 8 bit int implementations.
|
||||
GENERIC_2D_ARRAY<double> wf_;
|
||||
GENERIC_2D_ARRAY<inT8> wi_;
|
||||
// Transposed copy of wf_, used only for Backward, and set with each Update.
|
||||
TransposedArray wf_t_;
|
||||
// Which of wf_ and wi_ are we actually using.
|
||||
bool int_mode_;
|
||||
// True if we are running adagrad in this weight matrix.
|
||||
bool use_ada_grad_;
|
||||
// If we are using wi_, then scales_ is a factor to restore the row product
|
||||
// with a vector to the correct range.
|
||||
GenericVector<double> scales_;
|
||||
// Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
|
||||
// amount to be added to wf_/wi_.
|
||||
GENERIC_2D_ARRAY<double> dw_;
|
||||
GENERIC_2D_ARRAY<double> updates_;
|
||||
// Iff use_ada_grad_, the sum of squares of dw_. The number of samples is
|
||||
// given to Update(). Serialized iff use_ada_grad_.
|
||||
GENERIC_2D_ARRAY<double> dw_sq_sum_;
|
||||
};
|
||||
|
||||
} // namespace tesseract.
|
||||
|
||||
#endif // TESSERACT_LSTM_WEIGHTMATRIX_H_
|
@ -850,7 +850,8 @@ void BaselineDetect::ComputeBaselineSplinesAndXheights(const ICOORD& page_tr,
|
||||
Pix* pix_spline = pix_debug_ ? pixConvertTo32(pix_debug_) : NULL;
|
||||
for (int i = 0; i < blocks_.size(); ++i) {
|
||||
BaselineBlock* bl_block = blocks_[i];
|
||||
bl_block->PrepareForSplineFitting(page_tr, remove_noise);
|
||||
if (enable_splines)
|
||||
bl_block->PrepareForSplineFitting(page_tr, remove_noise);
|
||||
bl_block->FitBaselineSplines(enable_splines, show_final_rows, textord);
|
||||
if (pix_spline) {
|
||||
bl_block->DrawPixSpline(pix_spline);
|
||||
|
@ -1632,6 +1632,10 @@ TO_BLOCK* ColPartition::MakeBlock(const ICOORD& bleft, const ICOORD& tright,
|
||||
ColPartition_LIST* used_parts) {
|
||||
if (block_parts->empty())
|
||||
return NULL; // Nothing to do.
|
||||
// If the block_parts are not in reading order, then it will make an invalid
|
||||
// block polygon and bounding_box, so sort by bounding box now just to make
|
||||
// sure.
|
||||
block_parts->sort(&ColPartition::SortByBBox);
|
||||
ColPartition_IT it(block_parts);
|
||||
ColPartition* part = it.data();
|
||||
PolyBlockType type = part->type();
|
||||
|
@ -704,6 +704,25 @@ class ColPartition : public ELIST2_LINK {
|
||||
// doing a SideSearch when you want things in the same page column.
|
||||
bool IsInSameColumnAs(const ColPartition& part) const;
|
||||
|
||||
// Sort function to sort by bounding box.
|
||||
static int SortByBBox(const void* p1, const void* p2) {
|
||||
const ColPartition* part1 =
|
||||
*reinterpret_cast<const ColPartition* const*>(p1);
|
||||
const ColPartition* part2 =
|
||||
*reinterpret_cast<const ColPartition* const*>(p2);
|
||||
int mid_y1 = part1->bounding_box_.y_middle();
|
||||
int mid_y2 = part2->bounding_box_.y_middle();
|
||||
if ((part2->bounding_box_.bottom() <= mid_y1 &&
|
||||
mid_y1 <= part2->bounding_box_.top()) ||
|
||||
(part1->bounding_box_.bottom() <= mid_y2 &&
|
||||
mid_y2 <= part1->bounding_box_.top())) {
|
||||
// Sort by increasing x.
|
||||
return part1->bounding_box_.x_middle() - part2->bounding_box_.x_middle();
|
||||
}
|
||||
// Sort by decreasing y.
|
||||
return mid_y2 - mid_y1;
|
||||
}
|
||||
|
||||
// Sets the column bounds. Primarily used in testing.
|
||||
void set_first_column(int column) {
|
||||
first_column_ = column;
|
||||
|
@ -251,6 +251,7 @@ void Textord::filter_blobs(ICOORD page_tr, // top right
|
||||
&block->noise_blobs,
|
||||
&block->small_blobs,
|
||||
&block->large_blobs);
|
||||
if (block->line_size == 0) block->line_size = 1;
|
||||
block->line_spacing = block->line_size *
|
||||
(tesseract::CCStruct::kDescenderFraction +
|
||||
tesseract::CCStruct::kXHeightFraction +
|
||||
@ -769,6 +770,7 @@ void Textord::TransferDiacriticsToBlockGroups(BLOBNBOX_LIST* diacritic_blobs,
|
||||
PointerVector<WordWithBox> word_ptrs;
|
||||
for (int g = 0; g < groups.size(); ++g) {
|
||||
const BlockGroup* group = groups[g];
|
||||
if (group->bounding_box.null_box()) continue;
|
||||
WordGrid word_grid(group->min_xheight, group->bounding_box.botleft(),
|
||||
group->bounding_box.topright());
|
||||
for (int b = 0; b < group->blocks.size(); ++b) {
|
||||
|
@ -1323,9 +1323,10 @@ BOOL8 Textord::make_a_word_break(
|
||||
we may need to set PARTICULAR spaces to fuzzy or not. The values will ONLY
|
||||
be used if the function returns TRUE - ie the word is to be broken.
|
||||
*/
|
||||
blanks = (uinT8) (current_gap / row->space_size);
|
||||
if (blanks < 1)
|
||||
blanks = 1;
|
||||
int num_blanks = current_gap;
|
||||
if (row->space_size > 1.0f)
|
||||
num_blanks = IntCastRounded(current_gap / row->space_size);
|
||||
blanks = static_cast<uinT8>(ClipToRange(num_blanks, 1, MAX_UINT8));
|
||||
fuzzy_sp = FALSE;
|
||||
fuzzy_non = FALSE;
|
||||
/*
|
||||
|
@ -3,6 +3,7 @@ AM_CPPFLAGS += \
|
||||
-DUSE_STD_NAMESPACE -DPANGO_ENABLE_ENGINE\
|
||||
-I$(top_srcdir)/ccmain -I$(top_srcdir)/api \
|
||||
-I$(top_srcdir)/ccutil -I$(top_srcdir)/ccstruct \
|
||||
-I$(top_srcdir)/lstm -I$(top_srcdir)/arch \
|
||||
-I$(top_srcdir)/viewer \
|
||||
-I$(top_srcdir)/textord -I$(top_srcdir)/dict \
|
||||
-I$(top_srcdir)/classify -I$(top_srcdir)/display \
|
||||
@ -45,7 +46,7 @@ libtesseract_tessopt_la_SOURCES = \
|
||||
tessopt.cpp
|
||||
|
||||
bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \
|
||||
dawg2wordlist mftraining set_unicharset_properties shapeclustering \
|
||||
dawg2wordlist lstmtraining mftraining set_unicharset_properties shapeclustering \
|
||||
text2image unicharset_extractor wordlist2dawg
|
||||
|
||||
ambiguous_words_SOURCES = ambiguous_words.cpp
|
||||
@ -58,6 +59,9 @@ ambiguous_words_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -82,6 +86,9 @@ classifier_tester_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -115,6 +122,9 @@ cntraining_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -136,6 +146,9 @@ if USING_MULTIPLELIBS
|
||||
dawg2wordlist_LDADD += \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -150,6 +163,33 @@ dawg2wordlist_LDADD += \
|
||||
../api/libtesseract.la
|
||||
endif
|
||||
|
||||
lstmtraining_SOURCES = lstmtraining.cpp
|
||||
#lstmtraining_LDFLAGS = -static
|
||||
lstmtraining_LDADD = \
|
||||
libtesseract_training.la \
|
||||
libtesseract_tessopt.la \
|
||||
$(libicu)
|
||||
if USING_MULTIPLELIBS
|
||||
lstmtraining_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
../ccmain/libtesseract_main.la \
|
||||
../cube/libtesseract_cube.la \
|
||||
../neural_networks/runtime/libtesseract_neural.la \
|
||||
../wordrec/libtesseract_wordrec.la \
|
||||
../ccutil/libtesseract_ccutil.la
|
||||
else
|
||||
lstmtraining_LDADD += \
|
||||
../api/libtesseract.la
|
||||
endif
|
||||
|
||||
mftraining_SOURCES = mftraining.cpp mergenf.cpp
|
||||
#mftraining_LDFLAGS = -static
|
||||
mftraining_LDADD = \
|
||||
@ -160,6 +200,9 @@ mftraining_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -185,6 +228,9 @@ set_unicharset_properties_LDADD += \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
../ccmain/libtesseract_main.la \
|
||||
@ -207,6 +253,9 @@ shapeclustering_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -230,6 +279,9 @@ text2image_LDADD += \
|
||||
../textord/libtesseract_textord.la \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
@ -266,6 +318,9 @@ if USING_MULTIPLELIBS
|
||||
wordlist2dawg_LDADD += \
|
||||
../classify/libtesseract_classify.la \
|
||||
../dict/libtesseract_dict.la \
|
||||
../arch/libtesseract_avx.la \
|
||||
../arch/libtesseract_sse.la \
|
||||
../lstm/libtesseract_lstm.la \
|
||||
../ccstruct/libtesseract_ccstruct.la \
|
||||
../cutil/libtesseract_cutil.la \
|
||||
../viewer/libtesseract_viewer.la \
|
||||
|
@ -22,10 +22,36 @@
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "allheaders.h" // from leptonica
|
||||
#include "genericvector.h"
|
||||
#include "helpers.h" // For TRand.
|
||||
#include "rect.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
// A randomized perspective distortion can be applied to synthetic input.
|
||||
// The perspective distortion comes from leptonica, which uses 2 sets of 4
|
||||
// corners to determine the distortion. There are random values for each of
|
||||
// the x numbers x0..x3 and y0..y3, except for x2 and x3 which are instead
|
||||
// defined in terms of a single shear value. This reduces the degrees of
|
||||
// freedom enough to make the distortion more realistic than it would otherwise
|
||||
// be if all 8 coordinates could move independently.
|
||||
// One additional factor is used for the color of the pixels that don't exist
|
||||
// in the source image.
|
||||
// Name for each of the randomizing factors.
|
||||
enum FactorNames {
|
||||
FN_INCOLOR,
|
||||
FN_Y0,
|
||||
FN_Y1,
|
||||
FN_Y2,
|
||||
FN_Y3,
|
||||
FN_X0,
|
||||
FN_X1,
|
||||
FN_SHEAR,
|
||||
// x2 = x1 - shear
|
||||
// x3 = x0 + shear
|
||||
FN_NUM_FACTORS
|
||||
};
|
||||
|
||||
// Rotation is +/- kRotationRange radians.
|
||||
const float kRotationRange = 0.02f;
|
||||
// Number of grey levels to shift by for each exposure step.
|
||||
@ -144,4 +170,141 @@ Pix* DegradeImage(Pix* input, int exposure, TRand* randomizer,
|
||||
return input;
|
||||
}
|
||||
|
||||
// Creates and returns a Pix distorted by various means according to the bool
|
||||
// flags. If boxes is not NULL, the boxes are resized/positioned according to
|
||||
// any spatial distortion and also by the integer reduction factor box_scale
|
||||
// so they will match what the network will output.
|
||||
// Returns NULL on error. The returned Pix must be pixDestroyed.
|
||||
Pix* PrepareDistortedPix(const Pix* pix, bool perspective, bool invert,
|
||||
bool white_noise, bool smooth_noise, bool blur,
|
||||
int box_reduction, TRand* randomizer,
|
||||
GenericVector<TBOX>* boxes) {
|
||||
Pix* distorted = pixCopy(NULL, const_cast<Pix*>(pix));
|
||||
// Things to do to synthetic training data.
|
||||
if (invert && randomizer->SignedRand(1.0) < 0)
|
||||
pixInvert(distorted, distorted);
|
||||
if ((white_noise || smooth_noise) && randomizer->SignedRand(1.0) > 0.0) {
|
||||
// TODO(rays) Cook noise in a more thread-safe manner than rand().
|
||||
// Attempt to make the sequences reproducible.
|
||||
srand(randomizer->IntRand());
|
||||
Pix* pixn = pixAddGaussianNoise(distorted, 8.0);
|
||||
pixDestroy(&distorted);
|
||||
if (smooth_noise) {
|
||||
distorted = pixBlockconv(pixn, 1, 1);
|
||||
pixDestroy(&pixn);
|
||||
} else {
|
||||
distorted = pixn;
|
||||
}
|
||||
}
|
||||
if (blur && randomizer->SignedRand(1.0) > 0.0) {
|
||||
Pix* blurred = pixBlockconv(distorted, 1, 1);
|
||||
pixDestroy(&distorted);
|
||||
distorted = blurred;
|
||||
}
|
||||
if (perspective)
|
||||
GeneratePerspectiveDistortion(0, 0, randomizer, &distorted, boxes);
|
||||
if (boxes != NULL) {
|
||||
for (int b = 0; b < boxes->size(); ++b) {
|
||||
(*boxes)[b].scale(1.0f / box_reduction);
|
||||
if ((*boxes)[b].width() <= 0)
|
||||
(*boxes)[b].set_right((*boxes)[b].left() + 1);
|
||||
}
|
||||
}
|
||||
return distorted;
|
||||
}
|
||||
|
||||
// Distorts anything that has a non-null pointer with the same pseudo-random
|
||||
// perspective distortion. Width and height only need to be set if there
|
||||
// is no pix. If there is a pix, then they will be taken from there.
|
||||
void GeneratePerspectiveDistortion(int width, int height, TRand* randomizer,
|
||||
Pix** pix, GenericVector<TBOX>* boxes) {
|
||||
if (pix != NULL && *pix != NULL) {
|
||||
width = pixGetWidth(*pix);
|
||||
height = pixGetHeight(*pix);
|
||||
}
|
||||
float* im_coeffs = NULL;
|
||||
float* box_coeffs = NULL;
|
||||
l_int32 incolor =
|
||||
ProjectiveCoeffs(width, height, randomizer, &im_coeffs, &box_coeffs);
|
||||
if (pix != NULL && *pix != NULL) {
|
||||
// Transform the image.
|
||||
Pix* transformed = pixProjective(*pix, im_coeffs, incolor);
|
||||
if (transformed == NULL) {
|
||||
tprintf("Projective transformation failed!!\n");
|
||||
return;
|
||||
}
|
||||
pixDestroy(pix);
|
||||
*pix = transformed;
|
||||
}
|
||||
if (boxes != NULL) {
|
||||
// Transform the boxes.
|
||||
for (int b = 0; b < boxes->size(); ++b) {
|
||||
int x1, y1, x2, y2;
|
||||
const TBOX& box = (*boxes)[b];
|
||||
projectiveXformSampledPt(box_coeffs, box.left(), height - box.top(), &x1,
|
||||
&y1);
|
||||
projectiveXformSampledPt(box_coeffs, box.right(), height - box.bottom(),
|
||||
&x2, &y2);
|
||||
TBOX new_box1(x1, height - y2, x2, height - y1);
|
||||
projectiveXformSampledPt(box_coeffs, box.left(), height - box.bottom(),
|
||||
&x1, &y1);
|
||||
projectiveXformSampledPt(box_coeffs, box.right(), height - box.top(), &x2,
|
||||
&y2);
|
||||
TBOX new_box2(x1, height - y1, x2, height - y2);
|
||||
(*boxes)[b] = new_box1.bounding_union(new_box2);
|
||||
}
|
||||
}
|
||||
free(im_coeffs);
|
||||
free(box_coeffs);
|
||||
}
|
||||
|
||||
// Computes the coefficients of a randomized projective transformation.
|
||||
// The image transform requires backward transformation coefficient, and the
|
||||
// box transform the forward coefficients.
|
||||
// Returns the incolor arg to pixProjective.
|
||||
int ProjectiveCoeffs(int width, int height, TRand* randomizer,
|
||||
float** im_coeffs, float** box_coeffs) {
|
||||
// Setup "from" points.
|
||||
Pta* src_pts = ptaCreate(4);
|
||||
ptaAddPt(src_pts, 0.0f, 0.0f);
|
||||
ptaAddPt(src_pts, width, 0.0f);
|
||||
ptaAddPt(src_pts, width, height);
|
||||
ptaAddPt(src_pts, 0.0f, height);
|
||||
// Extract factors from pseudo-random sequence.
|
||||
float factors[FN_NUM_FACTORS];
|
||||
float shear = 0.0f; // Shear is signed.
|
||||
for (int i = 0; i < FN_NUM_FACTORS; ++i) {
|
||||
// Everything is squared to make wild values rarer.
|
||||
if (i == FN_SHEAR) {
|
||||
// Shear is signed.
|
||||
shear = randomizer->SignedRand(0.5 / 3.0);
|
||||
shear = shear >= 0.0 ? shear * shear : -shear * shear;
|
||||
// Keep the sheared points within the original rectangle.
|
||||
if (shear < -factors[FN_X0]) shear = -factors[FN_X0];
|
||||
if (shear > factors[FN_X1]) shear = factors[FN_X1];
|
||||
factors[i] = shear;
|
||||
} else if (i != FN_INCOLOR) {
|
||||
factors[i] = fabs(randomizer->SignedRand(1.0));
|
||||
if (i <= FN_Y3)
|
||||
factors[i] *= 5.0 / 8.0;
|
||||
else
|
||||
factors[i] *= 0.5;
|
||||
factors[i] *= factors[i];
|
||||
}
|
||||
}
|
||||
// Setup "to" points.
|
||||
Pta* dest_pts = ptaCreate(4);
|
||||
ptaAddPt(dest_pts, factors[FN_X0] * width, factors[FN_Y0] * height);
|
||||
ptaAddPt(dest_pts, (1.0f - factors[FN_X1]) * width, factors[FN_Y1] * height);
|
||||
ptaAddPt(dest_pts, (1.0f - factors[FN_X1] + shear) * width,
|
||||
(1 - factors[FN_Y2]) * height);
|
||||
ptaAddPt(dest_pts, (factors[FN_X0] + shear) * width,
|
||||
(1 - factors[FN_Y3]) * height);
|
||||
getProjectiveXformCoeffs(dest_pts, src_pts, im_coeffs);
|
||||
getProjectiveXformCoeffs(src_pts, dest_pts, box_coeffs);
|
||||
ptaDestroy(&src_pts);
|
||||
ptaDestroy(&dest_pts);
|
||||
return factors[FN_INCOLOR] > 0.5f ? L_BRING_IN_WHITE : L_BRING_IN_BLACK;
|
||||
}
|
||||
|
||||
} // namespace tesseract
|
||||
|
@ -20,12 +20,13 @@
|
||||
#ifndef TESSERACT_TRAINING_DEGRADEIMAGE_H_
|
||||
#define TESSERACT_TRAINING_DEGRADEIMAGE_H_
|
||||
|
||||
struct Pix;
|
||||
#include "allheaders.h"
|
||||
#include "genericvector.h"
|
||||
#include "helpers.h" // For TRand.
|
||||
#include "rect.h"
|
||||
|
||||
namespace tesseract {
|
||||
|
||||
class TRand;
|
||||
|
||||
// Degrade the pix as if by a print/copy/scan cycle with exposure > 0
|
||||
// corresponding to darkening on the copier and <0 lighter and 0 not copied.
|
||||
// If rotation is not NULL, the clockwise rotation in radians is saved there.
|
||||
@ -34,6 +35,27 @@ class TRand;
|
||||
struct Pix* DegradeImage(struct Pix* input, int exposure, TRand* randomizer,
|
||||
float* rotation);
|
||||
|
||||
// Creates and returns a Pix distorted by various means according to the bool
|
||||
// flags. If boxes is not NULL, the boxes are resized/positioned according to
|
||||
// any spatial distortion and also by the integer reduction factor box_scale
|
||||
// so they will match what the network will output.
|
||||
// Returns NULL on error. The returned Pix must be pixDestroyed.
|
||||
Pix* PrepareDistortedPix(const Pix* pix, bool perspective, bool invert,
|
||||
bool white_noise, bool smooth_noise, bool blur,
|
||||
int box_reduction, TRand* randomizer,
|
||||
GenericVector<TBOX>* boxes);
|
||||
// Distorts anything that has a non-null pointer with the same pseudo-random
|
||||
// perspective distortion. Width and height only need to be set if there
|
||||
// is no pix. If there is a pix, then they will be taken from there.
|
||||
void GeneratePerspectiveDistortion(int width, int height, TRand* randomizer,
|
||||
Pix** pix, GenericVector<TBOX>* boxes);
|
||||
// Computes the coefficients of a randomized projective transformation.
|
||||
// The image transform requires backward transformation coefficient, and the
|
||||
// box transform the forward coefficients.
|
||||
// Returns the incolor arg to pixProjective.
|
||||
int ProjectiveCoeffs(int width, int height, TRand* randomizer,
|
||||
float** im_coeffs, float** box_coeffs);
|
||||
|
||||
} // namespace tesseract
|
||||
|
||||
#endif // TESSERACT_TRAINING_DEGRADEIMAGE_H_
|
||||
|
@ -868,6 +868,9 @@ set_lang_specific_parameters() {
|
||||
AMBIGS_FILTER_DENOMINATOR="100000"
|
||||
LEADING="32"
|
||||
MEAN_COUNT="40" # Default for latin script.
|
||||
# Language to mix with the language for maximum accuracy. Defaults to eng.
|
||||
# If no language is good, set to the base language.
|
||||
MIX_LANG="eng"
|
||||
|
||||
case ${lang} in
|
||||
# Latin languages.
|
||||
@ -959,11 +962,13 @@ set_lang_specific_parameters() {
|
||||
WORD_DAWG_SIZE=1000000
|
||||
test -z "$FONTS" && FONTS=( "${EARLY_LATIN_FONTS[@]}" );;
|
||||
|
||||
# Cyrillic script-based languages.
|
||||
# Cyrillic script-based languages. It is bad to mix Latin with Cyrillic.
|
||||
rus ) test -z "$FONTS" && FONTS=( "${RUSSIAN_FONTS[@]}" )
|
||||
MIX_LANG="rus"
|
||||
NUMBER_DAWG_FACTOR=0.05
|
||||
WORD_DAWG_SIZE=1000000 ;;
|
||||
aze_cyrl | bel | bul | kaz | mkd | srp | tgk | ukr | uzb_cyrl )
|
||||
MIX_LANG="${lang}"
|
||||
test -z "$FONTS" && FONTS=( "${RUSSIAN_FONTS[@]}" ) ;;
|
||||
|
||||
# Special code for performing Cyrillic language-id that is trained on
|
||||
|
185
training/lstmtraining.cpp
Normal file
185
training/lstmtraining.cpp
Normal file
@ -0,0 +1,185 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: lstmtraining.cpp
|
||||
// Description: Training program for LSTM-based networks.
|
||||
// Author: Ray Smith
|
||||
// Created: Fri May 03 11:05:06 PST 2013
|
||||
//
|
||||
// (C) Copyright 2013, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef USE_STD_NAMESPACE
|
||||
#include "base/commandlineflags.h"
|
||||
#endif
|
||||
#include "commontraining.h"
|
||||
#include "lstmtrainer.h"
|
||||
#include "params.h"
|
||||
#include "strngs.h"
|
||||
#include "tprintf.h"
|
||||
#include "unicharset_training_utils.h"
|
||||
|
||||
INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
|
||||
STRING_PARAM_FLAG(net_spec, "[I1,48Lt1,100O]", "Network specification");
|
||||
INT_PARAM_FLAG(train_mode, 64, "Controls gross training behavior.");
|
||||
INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
|
||||
INT_PARAM_FLAG(perfect_sample_delay, 4,
|
||||
"How many imperfect samples between perfect ones.");
|
||||
DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
|
||||
DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
|
||||
DOUBLE_PARAM_FLAG(learning_rate, 1.0e-4, "Weight factor for new deltas.");
|
||||
DOUBLE_PARAM_FLAG(momentum, 0.9, "Decay factor for repeating deltas.");
|
||||
INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
|
||||
STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
|
||||
STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
|
||||
STRING_PARAM_FLAG(script_dir, "",
|
||||
"Required to set unicharset properties or"
|
||||
" use unicharset compression.");
|
||||
BOOL_PARAM_FLAG(stop_training, false,
|
||||
"Just convert the training model to a runtime model.");
|
||||
INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
|
||||
" attach the new network defined by net_spec");
|
||||
BOOL_PARAM_FLAG(debug_network, false,
|
||||
"Get info on distribution of weight values");
|
||||
INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
|
||||
DECLARE_STRING_PARAM_FLAG(U);
|
||||
|
||||
// Number of training images to train between calls to MaintainCheckpoints.
|
||||
const int kNumPagesPerBatch = 100;
|
||||
|
||||
// Apart from command-line flags, input is a collection of lstmf files, that
|
||||
// were previously created using tesseract with the lstm.train config file.
|
||||
// The program iterates over the inputs, feeding the data to the network,
|
||||
// until the error rate reaches a specified target or max_iterations is reached.
|
||||
int main(int argc, char **argv) {
|
||||
ParseArguments(&argc, &argv);
|
||||
// Purify the model name in case it is based on the network string.
|
||||
if (FLAGS_model_output.empty()) {
|
||||
tprintf("Must provide a --model_output!\n");
|
||||
return 1;
|
||||
}
|
||||
STRING model_output = FLAGS_model_output.c_str();
|
||||
for (int i = 0; i < model_output.length(); ++i) {
|
||||
if (model_output[i] == '[' || model_output[i] == ']')
|
||||
model_output[i] = '-';
|
||||
if (model_output[i] == '(' || model_output[i] == ')')
|
||||
model_output[i] = '_';
|
||||
}
|
||||
// Setup the trainer.
|
||||
STRING checkpoint_file = FLAGS_model_output.c_str();
|
||||
checkpoint_file += "_checkpoint";
|
||||
STRING checkpoint_bak = checkpoint_file + ".bak";
|
||||
tesseract::LSTMTrainer trainer(
|
||||
NULL, NULL, NULL, NULL, FLAGS_model_output.c_str(),
|
||||
checkpoint_file.c_str(), FLAGS_debug_interval,
|
||||
static_cast<inT64>(FLAGS_max_image_MB) * 1048576);
|
||||
|
||||
// Reading something from an existing model doesn't require many flags,
|
||||
// so do it now and exit.
|
||||
if (FLAGS_stop_training || FLAGS_debug_network) {
|
||||
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) {
|
||||
tprintf("Failed to read continue from: %s\n",
|
||||
FLAGS_continue_from.c_str());
|
||||
return 1;
|
||||
}
|
||||
if (FLAGS_debug_network) {
|
||||
trainer.DebugNetwork();
|
||||
} else {
|
||||
if (FLAGS_train_mode & tesseract::TF_INT_MODE)
|
||||
trainer.ConvertToInt();
|
||||
GenericVector<char> recognizer_data;
|
||||
trainer.SaveRecognitionDump(&recognizer_data);
|
||||
if (!tesseract::SaveDataToFile(recognizer_data,
|
||||
FLAGS_model_output.c_str())) {
|
||||
tprintf("Failed to write recognition model : %s\n",
|
||||
FLAGS_model_output.c_str());
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Get the list of files to process.
|
||||
GenericVector<STRING> filenames;
|
||||
for (int arg = 1; arg < argc; ++arg) {
|
||||
filenames.push_back(STRING(argv[arg]));
|
||||
}
|
||||
|
||||
UNICHARSET unicharset;
|
||||
// Checkpoints always take priority if they are available.
|
||||
if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) ||
|
||||
trainer.TryLoadingCheckpoint(checkpoint_bak.string())) {
|
||||
tprintf("Successfully restored trainer from %s\n",
|
||||
checkpoint_file.string());
|
||||
} else {
|
||||
if (!FLAGS_continue_from.empty()) {
|
||||
// Load a past model file to improve upon.
|
||||
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) {
|
||||
tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
|
||||
return 1;
|
||||
}
|
||||
tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
|
||||
}
|
||||
if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
|
||||
// We need a unicharset to start from scratch or append.
|
||||
string unicharset_str;
|
||||
// Character coding to be used by the classifier.
|
||||
if (!unicharset.load_from_file(FLAGS_U.c_str())) {
|
||||
tprintf("Error: must provide a -U unicharset!\n");
|
||||
return 1;
|
||||
}
|
||||
tesseract::SetupBasicProperties(true, &unicharset);
|
||||
if (FLAGS_append_index >= 0) {
|
||||
tprintf("Appending a new network to an old one!!");
|
||||
if (FLAGS_continue_from.empty()) {
|
||||
tprintf("Must set --continue_from for appending!\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
// We are initializing from scratch.
|
||||
trainer.InitCharSet(unicharset, FLAGS_script_dir.c_str(),
|
||||
FLAGS_train_mode);
|
||||
if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
|
||||
FLAGS_net_mode, FLAGS_weight_range,
|
||||
FLAGS_learning_rate, FLAGS_momentum)) {
|
||||
tprintf("Failed to create network from spec: %s\n",
|
||||
FLAGS_net_spec.c_str());
|
||||
return 1;
|
||||
}
|
||||
trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
|
||||
}
|
||||
}
|
||||
if (!trainer.LoadAllTrainingData(filenames)) {
|
||||
tprintf("Load of images failed!!\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool best_dumped = true;
|
||||
char* best_model_dump = NULL;
|
||||
size_t best_model_size = 0;
|
||||
STRING best_model_name;
|
||||
do {
|
||||
// Train a few.
|
||||
int iteration = trainer.training_iteration();
|
||||
for (int target_iteration = iteration + kNumPagesPerBatch;
|
||||
iteration < target_iteration;
|
||||
iteration = trainer.training_iteration()) {
|
||||
trainer.TrainOnLine(&trainer, false);
|
||||
}
|
||||
STRING log_str;
|
||||
trainer.MaintainCheckpoints(NULL, &log_str);
|
||||
tprintf("%s\n", log_str.string());
|
||||
} while (trainer.best_error_rate() > FLAGS_target_error_rate &&
|
||||
(trainer.training_iteration() < FLAGS_max_iterations ||
|
||||
FLAGS_max_iterations == 0));
|
||||
tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
|
||||
return 0;
|
||||
} /* main */
|
||||
|
||||
|
52
training/merge_unicharsets.cpp
Normal file
52
training/merge_unicharsets.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// File: merge_unicharsets.cpp
|
||||
// Description: Simple tool to merge two or more unicharsets.
|
||||
// Author: Ray Smith
|
||||
// Created: Wed Sep 30 16:09:01 PDT 2015
|
||||
//
|
||||
// (C) Copyright 2015, Google Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <stdio.h>
|
||||
#include "unicharset.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
// Print usage
|
||||
if (argc < 4) {
|
||||
printf("Usage: %s unicharset-in-1 ... unicharset-in-n unicharset-out\n",
|
||||
argv[0]);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
UNICHARSET input_unicharset, result_unicharset;
|
||||
for (int arg = 1; arg < argc - 1; ++arg) {
|
||||
// Load the input unicharset
|
||||
if (input_unicharset.load_from_file(argv[arg])) {
|
||||
printf("Loaded unicharset of size %d from file %s\n",
|
||||
input_unicharset.size(), argv[arg]);
|
||||
result_unicharset.AppendOtherUnicharset(input_unicharset);
|
||||
} else {
|
||||
printf("Failed to load unicharset from file %s!!\n", argv[arg]);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Save the combined unicharset.
|
||||
if (result_unicharset.save_to_file(argv[argc - 1])) {
|
||||
printf("Wrote unicharset file %s.\n", argv[argc - 1]);
|
||||
} else {
|
||||
printf("Cannot save unicharset file %s.\n", argv[argc - 1]);
|
||||
exit(1);
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -302,6 +302,9 @@ int main (int argc, char **argv) {
|
||||
*shape_table, float_classes,
|
||||
inttemp_file.string(),
|
||||
pffmtable_file.string());
|
||||
for (int c = 0; c < unicharset->size(); ++c) {
|
||||
FreeClassFields(&float_classes[c]);
|
||||
}
|
||||
delete [] float_classes;
|
||||
FreeLabeledClassList(mf_classes);
|
||||
delete trainer;
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user