Merge pull request #2416 from noahmetzger/4.1

Readded parts of the lstm_choice_mode functionality
This commit is contained in:
zdenop 2019-04-30 21:43:46 +02:00 committed by GitHub
commit 0963daad19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 218 additions and 21 deletions

View File

@ -133,7 +133,7 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
if (tesseract_ == nullptr || (page_res_ == nullptr && Recognize(monitor) < 0))
return nullptr;
int lcnt = 1, bcnt = 1, pcnt = 1, wcnt = 1, scnt = 1, gcnt = 1;
int lcnt = 1, bcnt = 1, pcnt = 1, wcnt = 1, scnt = 1, tcnt = 1, gcnt = 1;
int page_id = page_number + 1; // hOCR uses 1-based page numbers.
bool para_is_ltr = true; // Default direction is LTR
const char* paragraph_lang = nullptr;
@ -216,6 +216,12 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
}
// Now, process the word...
std::vector<std::vector<std::pair<const char*, float>>>* choiceMap =
nullptr;
if (tesseract_->lstm_choice_mode) {
choiceMap = res_it->GetBestLSTMSymbolChoices();
}
hocr_str << "\n <span class='ocrx_word'"
<< " id='"
<< "word_" << page_id << "_" << wcnt << "'";
@ -278,8 +284,48 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
res_it->Next(RIL_SYMBOL);
} while (!res_it->Empty(RIL_BLOCK) && !res_it->IsAtBeginningOf(RIL_WORD));
if (italic) hocr_str << "</em>";
if (bold) hocr_str << "</strong>";
if (bold) hocr_str << "</strong>";
// If the lstm choice mode is required it is added here
if (tesseract_->lstm_choice_mode == 1 && choiceMap != nullptr) {
for (auto timestep : *choiceMap) {
hocr_str << "\n <span class='ocrx_cinfo'"
<< " id='"
<< "timestep_" << page_id << "_" << wcnt << "_" << tcnt << "'"
<< ">";
for (std::pair<const char*, float> conf : timestep) {
hocr_str << "<span class='ocr_glyph'"
<< " id='"
<< "choice_" << page_id << "_" << wcnt << "_" << gcnt << "'"
<< " title='x_confs " << int(conf.second * 100) << "'>"
<< conf.first << "</span>";
gcnt++;
}
hocr_str << "</span>";
tcnt++;
}
} else if (tesseract_->lstm_choice_mode == 2 && choiceMap != nullptr) {
for (auto timestep : *choiceMap) {
if (timestep.size() > 0) {
hocr_str << "\n <span class='ocrx_cinfo'"
<< " id='"
<< "lstm_choices_" << page_id << "_" << wcnt << "_" << tcnt
<< "'>";
for (auto & j : timestep) {
hocr_str << "<span class='ocr_glyph'"
<< " id='"
<< "choice_" << page_id << "_" << wcnt << "_" << gcnt
<< "'"
<< " title='x_confs " << int(j.second * 100)
<< "'>" << j.first << "</span>";
gcnt++;
}
hocr_str << "</span>";
tcnt++;
}
}
}
hocr_str << "</span>";
tcnt = 1;
gcnt = 1;
wcnt++;
// Close any ending block/paragraph/textline.

View File

@ -239,7 +239,7 @@ void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word,
if (im_data == nullptr) return;
lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0,
kWorstDictCertainty / kCertaintyScale,
word_box, words);
word_box, words, lstm_choice_mode);
delete im_data;
SearchWords(words);
}

View File

@ -216,12 +216,6 @@ class ChoiceIterator {
// probabilities won't add up to 100. Each one stands on its own.
float Confidence() const;
// Returns a vector containing all timesteps, which belong to the currently
// selected symbol. A timestep is a vector containing pairs of symbols and
// floating point numbers. The number states the probability for the
// corresponding symbol.
std::vector<std::vector<std::pair<const char*, float>>>* Timesteps() const;
private:
// Pointer to the WERD_RES object owned by the API.
WERD_RES* word_res_;

View File

@ -27,6 +27,8 @@
#include "tesseractclass.h"
#include "unicharset.h"
#include "unicodes.h"
#include <set>
#include <vector>
namespace tesseract {
@ -603,6 +605,15 @@ char* ResultIterator::GetUTF8Text(PageIteratorLevel level) const {
return result;
}
std::vector<std::vector<std::pair<const char*, float>>>*
ResultIterator::GetBestLSTMSymbolChoices() const {
if (it_->word() != nullptr) {
return &it_->word()->timesteps;
} else {
return nullptr;
}
}
void ResultIterator::AppendUTF8WordText(STRING *text) const {
if (!it_->word()) return;
ASSERT_HOST(it_->word()->best_choice != nullptr);

View File

@ -22,6 +22,8 @@
#ifndef TESSERACT_CCMAIN_RESULT_ITERATOR_H_
#define TESSERACT_CCMAIN_RESULT_ITERATOR_H_
#include <set> // for std::pair
#include <vector> // for std::vector
#include "ltrresultiterator.h" // for LTRResultIterator
#include "platform.h" // for TESS_API, TESS_LOCAL
#include "publictypes.h" // for PageIteratorLevel
@ -95,6 +97,12 @@ class TESS_API ResultIterator : public LTRResultIterator {
*/
virtual char* GetUTF8Text(PageIteratorLevel level) const;
/**
* Returns the LSTM choices for every LSTM timestep for the current word.
*/
virtual std::vector<std::vector<std::pair<const char*, float>>>*
GetBestLSTMSymbolChoices() const;
/**
* Return whether the current paragraph's dominant reading direction
* is left-to-right (as opposed to right-to-left).

View File

@ -521,6 +521,13 @@ Tesseract::Tesseract()
STRING_MEMBER(page_separator, "\f",
"Page separator (default is form feed control character)",
this->params()),
INT_MEMBER(lstm_choice_mode, 0,
"Allows to include alternative symbols choices in the hOCR output. "
"Valid input values are 0, 1, 2 and 3. 0 is the default value. "
"With 1 the alternative symbol choices per timestep are included. "
"With 2 the alternative symbol choices are accumulated per "
"character. ",
this->params()),
backup_config_file_(nullptr),
pix_binary_(nullptr),

View File

@ -1085,6 +1085,13 @@ class Tesseract : public Wordrec {
"Preserve multiple interword spaces");
STRING_VAR_H(page_separator, "\f",
"Page separator (default is form feed control character)");
INT_VAR_H(lstm_choice_mode, 0,
"Allows to include alternative symbols choices in the hOCR "
"output. "
"Valid input values are 0, 1, 2 and 3. 0 is the default value. "
"With 1 the alternative symbol choices per timestep are included. "
"With 2 the alternative symbol choices are accumulated per "
"character. ");
//// ambigsrecog.cpp /////////////////////////////////////////////////////////
FILE* init_recog_training(const STRING& fname);

View File

@ -20,6 +20,8 @@
#define PAGERES_H
#include <cstdint> // for int32_t, int16_t
#include <set> // for std::pair
#include <vector> // for std::vector
#include <sys/types.h> // for int8_t
#include "blamer.h" // for BlamerBundle (ptr only), IRR_NUM_REASONS
#include "clst.h" // for CLIST_ITERATOR, CLISTIZEH
@ -217,6 +219,8 @@ class WERD_RES : public ELIST_LINK {
// Gaps between blobs in chopped_word. blob_gaps[i] is the gap between
// blob i and blob i+1.
GenericVector<int> blob_gaps;
// Stores the lstm choices of every timestep
std::vector<std::vector<std::pair<const char*, float>>> timesteps;
// Ratings matrix contains classifier choices for each classified combination
// of blobs. The dimension is the same as the number of blobs in chopped_word
// and the leading diagonal corresponds to classifier results of the blobs

View File

@ -179,7 +179,8 @@ bool LSTMRecognizer::LoadDictionary(const ParamsVectors* params,
void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
bool debug, double worst_dict_cert,
const TBOX& line_box,
PointerVector<WERD_RES>* words) {
PointerVector<WERD_RES>* words,
int lstm_choice_mode) {
NetworkIO outputs;
float scale_factor;
NetworkIO inputs;
@ -191,9 +192,9 @@ void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
}
search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert,
&GetUnicharset());
&GetUnicharset(), lstm_choice_mode);
search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
&GetUnicharset(), words);
&GetUnicharset(), words, lstm_choice_mode);
}
// Helper computes min and mean best results in the output.

View File

@ -175,7 +175,7 @@ class LSTMRecognizer {
// will be used in a dictionary word.
void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
double worst_dict_cert, const TBOX& line_box,
PointerVector<WERD_RES>* words);
PointerVector<WERD_RES>* words, int lstm_choice_mode = 0);
// Helper computes min and mean best results in the output.
void OutputStats(const NetworkIO& outputs, float* min_output,

View File

@ -21,6 +21,12 @@
#include "networkio.h"
#include "pageres.h"
#include "unicharcompress.h"
#include <deque>
#include <map>
#include <set>
#include <tuple>
#include <vector>
#include <algorithm>
namespace tesseract {
@ -75,13 +81,18 @@ RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress& recoder,
// Decodes the set of network outputs, storing the lattice internally.
void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
double cert_offset, double worst_dict_cert,
const UNICHARSET* charset) {
const UNICHARSET* charset, int lstm_choice_mode) {
beam_size_ = 0;
int width = output.Width();
if (lstm_choice_mode)
timesteps.clear();
for (int t = 0; t < width; ++t) {
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
charset);
if (lstm_choice_mode) {
SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
}
}
}
void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
@ -96,6 +107,34 @@ void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
}
}
void RecodeBeamSearch::SaveMostCertainChoices(const float* outputs,
int num_outputs,
const UNICHARSET* charset,
int xCoord) {
std::vector<std::pair<const char*, float>> choices;
for (int i = 0; i < num_outputs; ++i) {
if (outputs[i] >= 0.01f) {
const char* character;
if (i + 2 >= num_outputs) {
character = "";
} else if (i > 0) {
character = charset->id_to_unichar_ext(i + 2);
} else {
character = charset->id_to_unichar_ext(i);
}
size_t pos = 0;
//order the possible choices within one timestep
//beginning with the most likely
while (choices.size() > pos && choices[pos].second > outputs[i]) {
pos++;
}
choices.insert(choices.begin() + pos,
std::pair<const char*, float>(character, outputs[i]));
}
}
timesteps.push_back(choices);
}
// Returns the best path as labels/scores/xcoords similar to simple CTC.
void RecodeBeamSearch::ExtractBestPathAsLabels(
GenericVector<int>* labels, GenericVector<int>* xcoords) const {
@ -138,7 +177,8 @@ void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
float scale_factor, bool debug,
const UNICHARSET* unicharset,
PointerVector<WERD_RES>* words) {
PointerVector<WERD_RES>* words,
int lstm_choice_mode) {
words->truncate(0);
GenericVector<int> unichar_ids;
GenericVector<float> certs;
@ -146,6 +186,7 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
GenericVector<int> xcoords;
GenericVector<const RecodeNode*> best_nodes;
GenericVector<const RecodeNode*> second_nodes;
std::deque<std::tuple<int, int>> best_choices;
ExtractBestPaths(&best_nodes, &second_nodes);
if (debug) {
DebugPath(unicharset, best_nodes);
@ -155,10 +196,20 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
xcoords);
}
int timestepEnd= 0;
//if lstm choice mode is required in granularity level 2 it stores the x
//Coordinates of every chosen character to match the alternative choices to it
ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
&xcoords);
if (lstm_choice_mode == 2) {
ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
&xcoords, &best_choices);
if (best_choices.size() > 0) {
timestepEnd = std::get<1>(best_choices.front());
best_choices.pop_front();
}
} else {
ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
&xcoords);
}
int num_ids = unichar_ids.size();
if (debug) {
DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
@ -189,6 +240,51 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
WERD_RES* word_res = InitializeWord(
leading_space, line_box, word_start, word_end,
std::min(space_cert, prev_space_cert), unicharset, xcoords, scale_factor);
if (lstm_choice_mode == 1) {
for (size_t i = timestepEnd; i < xcoords[word_end]; i++) {
word_res->timesteps.push_back(timesteps[i]);
}
timestepEnd = xcoords[word_end];
} else if (lstm_choice_mode == 2){
// Accumulated Timesteps (choice mode 2 processing)
float sum = 0;
std::vector<std::pair<const char*, float>> choice_pairs;
for (size_t i = timestepEnd; i < xcoords[word_end]; i++) {
for (std::pair<const char*, float> choice : timesteps[i]) {
if (std::strcmp(choice.first, "")) {
sum += choice.second;
choice_pairs.push_back(choice);
}
}
if ((best_choices.size() > 0 && i == std::get<1>(best_choices.front()) - 1)
|| i == xcoords[word_end]-1) {
std::map<const char*, float> summed_propabilities;
for (auto & choice_pair : choice_pairs) {
summed_propabilities[choice_pair.first] += choice_pair.second;
}
std::vector<std::pair<const char*, float>> accumulated_timestep;
for (auto& summed_propability : summed_propabilities) {
if(sum == 0) break;
summed_propability.second/=sum;
size_t pos = 0;
while (accumulated_timestep.size() > pos
&& accumulated_timestep[pos].second > summed_propability.second) {
pos++;
}
accumulated_timestep.insert(accumulated_timestep.begin() + pos,
std::pair<const char*,float>(summed_propability.first,
summed_propability.second));
}
if (best_choices.size() > 0) {
best_choices.pop_front();
}
choice_pairs.clear();
word_res->timesteps.push_back(accumulated_timestep);
sum = 0;
}
}
timestepEnd = xcoords[word_end];
}
for (int i = word_start; i < word_end; ++i) {
auto* choices = new BLOB_CHOICE_LIST;
BLOB_CHOICE_IT bc_it(choices);
@ -261,7 +357,8 @@ void RecodeBeamSearch::DebugBeamPos(const UNICHARSET& unicharset,
void RecodeBeamSearch::ExtractPathAsUnicharIds(
const GenericVector<const RecodeNode*>& best_nodes,
GenericVector<int>* unichar_ids, GenericVector<float>* certs,
GenericVector<float>* ratings, GenericVector<int>* xcoords) {
GenericVector<float>* ratings, GenericVector<int>* xcoords,
std::deque<std::tuple<int, int>>* best_choices) {
unichar_ids->truncate(0);
certs->truncate(0);
ratings->truncate(0);
@ -292,6 +389,10 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(
}
unichar_ids->push_back(unichar_id);
xcoords->push_back(t);
if (best_choices != nullptr) {
tposition = t;
id = unichar_id;
}
do {
double cert = best_nodes[t++]->certainty;
// Special-case NO-PERM space to forget the certainty of the previous
@ -308,6 +409,10 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(
if (certainty < certs->back()) certs->back() = certainty;
ratings->back() += rating;
}
if (best_choices != nullptr) {
best_choices->push_back(
std::tuple<int, int>(id, tposition));
}
}
xcoords->push_back(width);
}

View File

@ -27,6 +27,10 @@
#include "networkio.h"
#include "ratngs.h"
#include "unicharcompress.h"
#include <deque>
#include <set>
#include <tuple>
#include <vector>
namespace tesseract {
@ -181,7 +185,8 @@ class RecodeBeamSearch {
// Decodes the set of network outputs, storing the lattice internally.
// If charset is not null, it enables detailed debugging of the beam search.
void Decode(const NetworkIO& output, double dict_ratio, double cert_offset,
double worst_dict_cert, const UNICHARSET* charset);
double worst_dict_cert, const UNICHARSET* charset,
int lstm_choice_mode = 0);
void Decode(const GENERIC_2D_ARRAY<float>& output, double dict_ratio,
double cert_offset, double worst_dict_cert,
const UNICHARSET* charset);
@ -200,11 +205,16 @@ class RecodeBeamSearch {
// Returns the best path as a set of WERD_RES.
void ExtractBestPathAsWords(const TBOX& line_box, float scale_factor,
bool debug, const UNICHARSET* unicharset,
PointerVector<WERD_RES>* words);
PointerVector<WERD_RES>* words,
int lstm_choice_mode = 0);
// Generates debug output of the content of the beams after a Decode.
void DebugBeams(const UNICHARSET& unicharset) const;
// Stores the alternative characters of every timestep together with their
// probability.
std::vector< std::vector<std::pair<const char*, float>>> timesteps;
// Clipping value for certainty inside Tesseract. Reflects the minimum value
// of certainty that will be returned by ExtractBestPathAsUnicharIds.
// Supposedly on a uniform scale that can be compared across languages and
@ -271,7 +281,8 @@ class RecodeBeamSearch {
static void ExtractPathAsUnicharIds(
const GenericVector<const RecodeNode*>& best_nodes,
GenericVector<int>* unichar_ids, GenericVector<float>* certs,
GenericVector<float>* ratings, GenericVector<int>* xcoords);
GenericVector<float>* ratings, GenericVector<int>* xcoords,
std::deque<std::tuple<int, int>>* best_choices = nullptr);
// Sets up a word with the ratings matrix and fake blobs with boxes in the
// right places.
@ -292,6 +303,9 @@ class RecodeBeamSearch {
double cert_offset, double worst_dict_cert,
const UNICHARSET* charset, bool debug = false);
//Saves the most certain choices for the current time-step
void SaveMostCertainChoices(const float* outputs, int num_outputs, const UNICHARSET* charset, int xCoord);
// Adds to the appropriate beams the legal (according to recoder)
// continuations of context prev, which is from the given index to beams_,
// using the given network outputs to provide scores to the choices. Uses only