From 4d3455e1de212a6d70ff86bd486c9511b586a328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zdenko=20Podobn=C3=BD?= Date: Sun, 16 Jun 2019 19:22:19 +0200 Subject: [PATCH] Integrated Timesteps per symbol into ChoiceIterator Signed-off-by: Noah Metzger # Conflicts: # src/ccmain/ltrresultiterator.cpp --- src/api/hocrrenderer.cpp | 37 +++++++++++++- src/ccmain/ltrresultiterator.cpp | 88 ++++++++++++++++++++++++++++---- src/ccmain/ltrresultiterator.h | 17 ++++++ src/ccmain/resultiterator.cpp | 19 ++++++- src/ccmain/resultiterator.h | 4 ++ src/ccmain/tesseractclass.cpp | 4 +- src/ccmain/tesseractclass.h | 4 +- src/ccstruct/pageres.h | 7 ++- src/cutil/oldlist.cpp | 32 ------------ src/cutil/oldlist.h | 10 ---- src/lstm/recodebeam.cpp | 54 ++++++++++++++------ src/lstm/recodebeam.h | 3 +- 12 files changed, 204 insertions(+), 75 deletions(-) diff --git a/src/api/hocrrenderer.cpp b/src/api/hocrrenderer.cpp index f845d0b59..fe88164a7 100644 --- a/src/api/hocrrenderer.cpp +++ b/src/api/hocrrenderer.cpp @@ -229,11 +229,17 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) { } // Now, process the word... + std::vector>>* rawTimestepMap = + nullptr; std::vector>>* choiceMap = nullptr; + std::vector>>>* + symbolMap = nullptr; if (tesseract_->lstm_choice_mode) { choiceMap = res_it->GetBestLSTMSymbolChoices(); + symbolMap = res_it->GetSegmentedLSTMTimesteps(); + rawTimestepMap = res_it->GetRawLSTMTimesteps(); } hocr_str << "\n lstm_choice_mode == 3 && symbolMap != nullptr) { + for (auto timesteps : *symbolMap) { + hocr_str << "\n "; + for (auto timestep : timesteps) { + hocr_str << "\n "; + for (std::pair conf : timestep) { + hocr_str << "" + << conf.first << ""; + gcnt++; + } + hocr_str << ""; + tcnt++; + } + hocr_str << ""; + scnt++; + } } hocr_str << ""; tcnt = 1; diff --git a/src/ccmain/ltrresultiterator.cpp b/src/ccmain/ltrresultiterator.cpp index 84aafed2f..46f4e4282 100644 --- a/src/ccmain/ltrresultiterator.cpp +++ b/src/ccmain/ltrresultiterator.cpp @@ -357,7 +357,17 @@ bool LTRResultIterator::SymbolIsDropcap() const { ChoiceIterator::ChoiceIterator(const LTRResultIterator& result_it) { ASSERT_HOST(result_it.it_->word() != nullptr); word_res_ = result_it.it_->word(); + oemLSTM_ = word_res_->tesseract->AnyLSTMLang(); + oemLegacy_ = word_res_->tesseract->AnyTessLang(); BLOB_CHOICE_LIST* choices = nullptr; + tstep_index_ = &result_it.blob_index_; + if (oemLSTM_ && !oemLegacy_ && !word_res_->accumulated_timesteps.empty()) { + if (word_res_->leadingSpace) + LSTM_choices_ = &word_res_->accumulated_timesteps[(*tstep_index_) + 1]; + else + LSTM_choices_ = &word_res_->accumulated_timesteps[*tstep_index_]; + filterSpaces(); + } if (word_res_->ratings != nullptr) choices = word_res_->GetBlobChoices(result_it.blob_index_); if (choices != nullptr && !choices->empty()) { @@ -366,23 +376,42 @@ ChoiceIterator::ChoiceIterator(const LTRResultIterator& result_it) { } else { choice_it_ = nullptr; } + if (LSTM_choices_ != nullptr && !LSTM_choices_->empty()) { + LSTM_mode_ = true; + LSTM_choice_it_ = LSTM_choices_->begin(); + } } ChoiceIterator::~ChoiceIterator() { delete choice_it_; } // Moves to the next choice for the symbol and returns false if there // are none left. bool ChoiceIterator::Next() { - if (choice_it_ == nullptr) return false; - choice_it_->forward(); - return !choice_it_->cycled_list(); + if (LSTM_mode_) { + if (LSTM_choice_it_ != LSTM_choices_->end() && + next(LSTM_choice_it_) == LSTM_choices_->end()) { + return false; + } else { + ++LSTM_choice_it_; + return true; + } + } else { + if (choice_it_ == nullptr) return false; + choice_it_->forward(); + return !choice_it_->cycled_list(); + } } // Returns the null terminated UTF-8 encoded text string for the current // choice. Do NOT use delete [] to free after use. const char* ChoiceIterator::GetUTF8Text() const { - if (choice_it_ == nullptr) return nullptr; - UNICHAR_ID id = choice_it_->data()->unichar_id(); - return word_res_->uch_set->id_to_unichar_ext(id); + if (LSTM_mode_) { + std::pair choice = *LSTM_choice_it_; + return choice.first; + } else { + if (choice_it_ == nullptr) return nullptr; + UNICHAR_ID id = choice_it_->data()->unichar_id(); + return word_res_->uch_set->id_to_unichar_ext(id); + } } // Returns the confidence of the current choice depending on the used language @@ -392,10 +421,47 @@ const char* ChoiceIterator::GetUTF8Text() const { // interpreted as a percent probability. (0.0f-100.0f) In this case probabilities // won't add up to 100. Each one stands on its own. float ChoiceIterator::Confidence() const { - if (choice_it_ == nullptr) return 0.0f; - float confidence = 100 + 5 * choice_it_->data()->certainty(); - if (confidence < 0.0f) confidence = 0.0f; - if (confidence > 100.0f) confidence = 100.0f; - return confidence; + if (LSTM_mode_) { + std::pair choice = *LSTM_choice_it_; + return choice.second; + } else { + if (choice_it_ == nullptr) return 0.0f; + float confidence = 100 + 5 * choice_it_->data()->certainty(); + if (confidence < 0.0f) confidence = 0.0f; + if (confidence > 100.0f) confidence = 100.0f; + return confidence; + } +} + +// Returns the set of timesteps which belong to the current symbol +std::vector>>* +ChoiceIterator::Timesteps() const { + if (word_res_->symbol_steps.empty() || !LSTM_mode_) return nullptr; + if (word_res_->leadingSpace) { + return &word_res_->symbol_steps[*(tstep_index_) + 1]; + } else { + return &word_res_->symbol_steps[*tstep_index_]; + } +} + +void ChoiceIterator::filterSpaces() { + if (LSTM_choices_->empty()) return; + std::vector>::iterator it; + bool found_space = false; + float sum = 0; + for (it = LSTM_choices_->begin(); it != LSTM_choices_->end();) { + if (!strcmp(it->first, " ")) { + it = LSTM_choices_->erase(it); + found_space = true; + } else { + sum += it->second; + ++it; + } + } + if (found_space) { + for (it = LSTM_choices_->begin(); it != LSTM_choices_->end(); ++it) { + it->second /= sum; + } + } } } // namespace tesseract. diff --git a/src/ccmain/ltrresultiterator.h b/src/ccmain/ltrresultiterator.h index 3764b1662..65550a437 100644 --- a/src/ccmain/ltrresultiterator.h +++ b/src/ccmain/ltrresultiterator.h @@ -216,11 +216,28 @@ class ChoiceIterator { // probabilities won't add up to 100. Each one stands on its own. float Confidence() const; + // Returns a vector containing all timesteps, which belong to the currently + // selected symbol. A timestep is a vector containing pairs of symbols and + // floating point numbers. The number states the probability for the + // corresponding symbol. + std::vector>>* Timesteps() const; + private: + //clears the remaining spaces out of the results and adapt the probabilities + void filterSpaces(); // Pointer to the WERD_RES object owned by the API. WERD_RES* word_res_; // Iterator over the blob choices. BLOB_CHOICE_IT* choice_it_; + std::vector>* LSTM_choices_ = nullptr; + std::vector>::iterator LSTM_choice_it_; + + const int* tstep_index_; + bool LSTM_mode_ = false; + //true when there is lstm engine related trained data + bool oemLSTM_; + // true when there is legacy engine related trained data + bool oemLegacy_; }; } // namespace tesseract. diff --git a/src/ccmain/resultiterator.cpp b/src/ccmain/resultiterator.cpp index a9e08fac3..c4cdf86b7 100644 --- a/src/ccmain/resultiterator.cpp +++ b/src/ccmain/resultiterator.cpp @@ -604,11 +604,28 @@ char* ResultIterator::GetUTF8Text(PageIteratorLevel level) const { strncpy(result, text.string(), length); return result; } +std::vector>>* +ResultIterator::GetRawLSTMTimesteps() const { + if (it_->word() != nullptr) { + return &it_->word()->raw_timesteps; + } else { + return nullptr; + } +} std::vector>>* ResultIterator::GetBestLSTMSymbolChoices() const { if (it_->word() != nullptr) { - return &it_->word()->timesteps; + return &it_->word()->accumulated_timesteps; + } else { + return nullptr; + } +} + +std::vector>>>* + ResultIterator::GetSegmentedLSTMTimesteps() const { + if (it_->word() != nullptr) { + return &it_->word()->symbol_steps; } else { return nullptr; } diff --git a/src/ccmain/resultiterator.h b/src/ccmain/resultiterator.h index 4333012b6..0be62dd6d 100644 --- a/src/ccmain/resultiterator.h +++ b/src/ccmain/resultiterator.h @@ -100,8 +100,12 @@ class TESS_API ResultIterator : public LTRResultIterator { /** * Returns the LSTM choices for every LSTM timestep for the current word. */ + virtual std::vector>>* + GetRawLSTMTimesteps() const; virtual std::vector>>* GetBestLSTMSymbolChoices() const; + virtual std::vector>>>* + GetSegmentedLSTMTimesteps() const; /** * Return whether the current paragraph's dominant reading direction diff --git a/src/ccmain/tesseractclass.cpp b/src/ccmain/tesseractclass.cpp index ca3f2fa0e..809e111e4 100644 --- a/src/ccmain/tesseractclass.cpp +++ b/src/ccmain/tesseractclass.cpp @@ -526,7 +526,9 @@ Tesseract::Tesseract() "Valid input values are 0, 1, 2 and 3. 0 is the default value. " "With 1 the alternative symbol choices per timestep are included. " "With 2 the alternative symbol choices are accumulated per " - "character. ", + "character. " + "With 3 the alternative symbol choices per timestep are included " + "and separated by the suggested segmentation of Tesseract", this->params()), backup_config_file_(nullptr), diff --git a/src/ccmain/tesseractclass.h b/src/ccmain/tesseractclass.h index db4ed43be..dd723cc8a 100644 --- a/src/ccmain/tesseractclass.h +++ b/src/ccmain/tesseractclass.h @@ -1091,7 +1091,9 @@ class Tesseract : public Wordrec { "Valid input values are 0, 1, 2 and 3. 0 is the default value. " "With 1 the alternative symbol choices per timestep are included. " "With 2 the alternative symbol choices are accumulated per " - "character. "); + "character. " + "With 3 the alternative symbol choices per timestep are included " + "and separated by the suggested segmentation of Tesseract"); //// ambigsrecog.cpp ///////////////////////////////////////////////////////// FILE* init_recog_training(const STRING& fname); diff --git a/src/ccstruct/pageres.h b/src/ccstruct/pageres.h index f1f1e4197..23c8c5799 100644 --- a/src/ccstruct/pageres.h +++ b/src/ccstruct/pageres.h @@ -220,7 +220,12 @@ class WERD_RES : public ELIST_LINK { // blob i and blob i+1. GenericVector blob_gaps; // Stores the lstm choices of every timestep - std::vector>> timesteps; + std::vector>> raw_timesteps; + std::vector>> accumulated_timesteps; + std::vector>>> + symbol_steps; + //Stores if the timestep vector starts with a space + bool leadingSpace = false; // Ratings matrix contains classifier choices for each classified combination // of blobs. The dimension is the same as the number of blobs in chopped_word // and the leading diagonal corresponds to classifier results of the blobs diff --git a/src/cutil/oldlist.cpp b/src/cutil/oldlist.cpp index 9882d2f34..27c25e558 100644 --- a/src/cutil/oldlist.cpp +++ b/src/cutil/oldlist.cpp @@ -151,25 +151,6 @@ void destroy_nodes(LIST list, void_dest destructor) { } } -/********************************************************************** - * i n s e r t - * - * Create a list element and rearrange the pointers so that the first - * element in the list is the second argument. - **********************************************************************/ -void insert(LIST list, void *node) { - LIST element; - - if (list != NIL_LIST) { - element = push(NIL_LIST, node); - set_rest(element, list_rest(list)); - set_rest(list, element); - node = first_node(list); - list->node = first_node(list_rest(list)); - list->next->node = (LIST)node; - } -} - /********************************************************************** * l a s t * @@ -228,19 +209,6 @@ LIST push_last(LIST list, void *item) { return (push(NIL_LIST, item)); } -/********************************************************************** - * r e v e r s e - * - * Create a new list with the elements reversed. The old list is not - * destroyed. - **********************************************************************/ -LIST reverse(LIST list) { - LIST newlist = NIL_LIST; - - iterate(list) copy_first(list, newlist); - return (newlist); -} - /********************************************************************** * s e a r c h * diff --git a/src/cutil/oldlist.h b/src/cutil/oldlist.h index 48e60fce3..a11498991 100644 --- a/src/cutil/oldlist.h +++ b/src/cutil/oldlist.h @@ -38,10 +38,6 @@ * ----------------- * iterate - Macro to create a for loop to visit each cell. * - * COPYING: - * ----------------- - * reverse - (Deprecated) Creates a backwards copy of the input list. - * * LIST CELL COUNTS: * ----------------- * count - Returns the number of list cells in the list. @@ -50,8 +46,6 @@ * TRANSFORMS: (Note: These functions all modify the input list.) * ---------- * delete_d - Removes the requested elements from the list. - * insert - (Deprecated) Add a new element into this spot in a list. - (not NIL_LIST) * push_last - Add a new element onto the end of a list. * * SETS: @@ -130,8 +124,6 @@ LIST destroy(LIST list); void destroy_nodes(LIST list, void_dest destructor); -void insert(LIST list, void *node); - LIST last(LIST var_list); LIST pop(LIST list); @@ -140,8 +132,6 @@ LIST push(LIST list, void* element); LIST push_last(LIST list, void* item); -LIST reverse(LIST list); - LIST search(LIST list, void* key, int_compare is_equal); #endif diff --git a/src/lstm/recodebeam.cpp b/src/lstm/recodebeam.cpp index d8871bee8..61f130ab2 100644 --- a/src/lstm/recodebeam.cpp +++ b/src/lstm/recodebeam.cpp @@ -187,6 +187,7 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, GenericVector best_nodes; GenericVector second_nodes; std::deque> best_choices; + std::deque> best_choices_acc; ExtractBestPaths(&best_nodes, &second_nodes); if (debug) { DebugPath(unicharset, best_nodes); @@ -196,15 +197,18 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings, xcoords); } - int timestepEnd= 0; + int timestepEndRaw = 0; + int timestepEnd = 0; + int timestepEnd_acc = 0; //if lstm choice mode is required in granularity level 2 it stores the x //Coordinates of every chosen character to match the alternative choices to it - if (lstm_choice_mode == 2) { + if (lstm_choice_mode) { ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, - &xcoords, &best_choices); + &xcoords, &best_choices, &best_choices_acc); if (best_choices.size() > 0) { timestepEnd = std::get<1>(best_choices.front()); - best_choices.pop_front(); + timestepEnd_acc = std::get<1>(best_choices_acc.front()); + best_choices_acc.pop_front(); } } else { ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, @@ -240,23 +244,22 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, WERD_RES* word_res = InitializeWord( leading_space, line_box, word_start, word_end, std::min(space_cert, prev_space_cert), unicharset, xcoords, scale_factor); - if (lstm_choice_mode == 1) { - for (size_t i = timestepEnd; i < xcoords[word_end]; i++) { - word_res->timesteps.push_back(timesteps[i]); + if (lstm_choice_mode) { + for (size_t i = timestepEndRaw; i < xcoords[word_end]; i++) { + word_res->raw_timesteps.push_back(timesteps[i]); } - timestepEnd = xcoords[word_end]; - } else if (lstm_choice_mode == 2){ + timestepEndRaw = xcoords[word_end]; // Accumulated Timesteps (choice mode 2 processing) float sum = 0; std::vector> choice_pairs; - for (size_t i = timestepEnd; i < xcoords[word_end]; i++) { + for (size_t i = timestepEnd_acc; i < xcoords[word_end]; i++) { for (std::pair choice : timesteps[i]) { if (std::strcmp(choice.first, "")) { sum += choice.second; choice_pairs.push_back(choice); } } - if ((best_choices.size() > 0 && i == std::get<1>(best_choices.front()) - 1) + if ((best_choices_acc.size() > 0 && i == std::get<1>(best_choices_acc.front()) - 1) || i == xcoords[word_end]-1) { std::map summed_propabilities; for (auto & choice_pair : choice_pairs) { @@ -275,14 +278,32 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, std::pair(summed_propability.first, summed_propability.second)); } - if (best_choices.size() > 0) { - best_choices.pop_front(); + if (best_choices_acc.size() > 0) { + best_choices_acc.pop_front(); } choice_pairs.clear(); - word_res->timesteps.push_back(accumulated_timestep); + word_res->accumulated_timesteps.push_back(accumulated_timestep); sum = 0; } } + timestepEnd_acc = xcoords[word_end]; + //Symbol Step (choice mode 3 processing) + std::vector>> currentSymbol; + for (size_t i = timestepEnd; i < xcoords[word_end]; i++) { + if (i == std::get<1>(best_choices.front())) { + if (currentSymbol.size() > 0) { + word_res->symbol_steps.push_back(currentSymbol); + currentSymbol.clear(); + } + const char* leadCharacter = + unicharset->id_to_unichar_ext(std::get<0>(best_choices.front())); + if (!strcmp(leadCharacter, " ")) + word_res->leadingSpace = true; + if(best_choices.size()>1) best_choices.pop_front(); + } + currentSymbol.push_back(timesteps[i]); + } + word_res->symbol_steps.push_back(currentSymbol); timestepEnd = xcoords[word_end]; } for (int i = word_start; i < word_end; ++i) { @@ -358,7 +379,8 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds( const GenericVector& best_nodes, GenericVector* unichar_ids, GenericVector* certs, GenericVector* ratings, GenericVector* xcoords, - std::deque>* best_choices) { + std::deque>* best_choices, + std::deque>* best_choices_acc) { unichar_ids->truncate(0); certs->truncate(0); ratings->truncate(0); @@ -412,6 +434,8 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds( if (best_choices != nullptr) { best_choices->push_back( std::tuple(id, tposition)); + best_choices_acc->push_back( + std::tuple(id, tposition)); } } xcoords->push_back(width); diff --git a/src/lstm/recodebeam.h b/src/lstm/recodebeam.h index fc78842a1..697878653 100644 --- a/src/lstm/recodebeam.h +++ b/src/lstm/recodebeam.h @@ -282,7 +282,8 @@ class RecodeBeamSearch { const GenericVector& best_nodes, GenericVector* unichar_ids, GenericVector* certs, GenericVector* ratings, GenericVector* xcoords, - std::deque>* best_choices = nullptr); + std::deque>* best_choices = nullptr, + std::deque>* best_choices_acc = nullptr); // Sets up a word with the ratings matrix and fake blobs with boxes in the // right places.