/////////////////////////////////////////////////////////////////////// // File: segsearch.cpp // Description: Segmentation search functions. // Author: Daria Antonova // Created: Mon Jun 23 11:26:43 PDT 2008 // // (C) Copyright 2009, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // /////////////////////////////////////////////////////////////////////// #include "wordrec.h" #include "associate.h" #include "language_model.h" #include "matrix.h" #include "params.h" #include "lm_pain_points.h" #include "ratngs.h" namespace tesseract { void Wordrec::DoSegSearch(WERD_RES* word_res) { BestChoiceBundle best_choice_bundle(word_res->ratings->dimension()); // Run Segmentation Search. SegSearch(word_res, &best_choice_bundle, nullptr); } void Wordrec::SegSearch(WERD_RES* word_res, BestChoiceBundle* best_choice_bundle, BlamerBundle* blamer_bundle) { LMPainPoints pain_points(segsearch_max_pain_points, segsearch_max_char_wh_ratio, assume_fixed_pitch_char_segment, &getDict(), segsearch_debug_level); // Compute scaling factor that will help us recover blob outline length // from classifier rating and certainty for the blob. float rating_cert_scale = -1.0 * getDict().certainty_scale / rating_scale; GenericVector pending; InitialSegSearch(word_res, &pain_points, &pending, best_choice_bundle, blamer_bundle); if (!SegSearchDone(0)) { // find a better choice if (chop_enable && word_res->chopped_word != nullptr) { improve_by_chopping(rating_cert_scale, word_res, best_choice_bundle, blamer_bundle, &pain_points, &pending); } if (chop_debug) SEAM::PrintSeams("Final seam list:", word_res->seam_array); if (blamer_bundle != nullptr && !blamer_bundle->ChoiceIsCorrect(word_res->best_choice)) { blamer_bundle->SetChopperBlame(word_res, wordrec_debug_blamer); } } // Keep trying to find a better path by fixing the "pain points". MATRIX_COORD pain_point; float pain_point_priority; int num_futile_classifications = 0; STRING blamer_debug; while (wordrec_enable_assoc && (!SegSearchDone(num_futile_classifications) || (blamer_bundle != nullptr && blamer_bundle->GuidedSegsearchStillGoing()))) { // Get the next valid "pain point". bool found_nothing = true; LMPainPointsType pp_type; while ((pp_type = pain_points.Deque(&pain_point, &pain_point_priority)) != LM_PPTYPE_NUM) { if (!pain_point.Valid(*word_res->ratings)) { word_res->ratings->IncreaseBandSize( pain_point.row - pain_point.col + 1); } if (pain_point.Valid(*word_res->ratings) && !word_res->ratings->Classified(pain_point.col, pain_point.row, getDict().WildcardID())) { found_nothing = false; break; } } if (found_nothing) { if (segsearch_debug_level > 0) tprintf("Pain points queue is empty\n"); break; } ProcessSegSearchPainPoint(pain_point_priority, pain_point, LMPainPoints::PainPointDescription(pp_type), &pending, word_res, &pain_points, blamer_bundle); UpdateSegSearchNodes(rating_cert_scale, pain_point.col, &pending, word_res, &pain_points, best_choice_bundle, blamer_bundle); if (!best_choice_bundle->updated) ++num_futile_classifications; if (segsearch_debug_level > 0) { tprintf("num_futile_classifications %d\n", num_futile_classifications); } best_choice_bundle->updated = false; // reset updated // See if it's time to terminate SegSearch or time for starting a guided // search for the true path to find the blame for the incorrect best_choice. if (SegSearchDone(num_futile_classifications) && blamer_bundle != nullptr && blamer_bundle->GuidedSegsearchNeeded(word_res->best_choice)) { InitBlamerForSegSearch(word_res, &pain_points, blamer_bundle, &blamer_debug); } } // end while loop exploring alternative paths if (blamer_bundle != nullptr) { blamer_bundle->FinishSegSearch(word_res->best_choice, wordrec_debug_blamer, &blamer_debug); } if (segsearch_debug_level > 0) { tprintf("Done with SegSearch (AcceptableChoiceFound: %d)\n", language_model_->AcceptableChoiceFound()); } } // Setup and run just the initial segsearch on an established matrix, // without doing any additional chopping or joining. // (Internal factored version that can be used as part of the main SegSearch.) void Wordrec::InitialSegSearch(WERD_RES* word_res, LMPainPoints* pain_points, GenericVector* pending, BestChoiceBundle* best_choice_bundle, BlamerBundle* blamer_bundle) { if (segsearch_debug_level > 0) { tprintf("Starting SegSearch on ratings matrix%s:\n", wordrec_enable_assoc ? " (with assoc)" : ""); word_res->ratings->print(getDict().getUnicharset()); } pain_points->GenerateInitial(word_res); // Compute scaling factor that will help us recover blob outline length // from classifier rating and certainty for the blob. float rating_cert_scale = -1.0 * getDict().certainty_scale / rating_scale; language_model_->InitForWord(prev_word_best_choice_, assume_fixed_pitch_char_segment, segsearch_max_char_wh_ratio, rating_cert_scale); // Initialize blamer-related information: map character boxes recorded in // blamer_bundle->norm_truth_word to the corresponding i,j indices in the // ratings matrix. We expect this step to succeed, since when running the // chopper we checked that the correct chops are present. if (blamer_bundle != nullptr) { blamer_bundle->SetupCorrectSegmentation(word_res->chopped_word, wordrec_debug_blamer); } // pending[col] tells whether there is update work to do to combine // best_choice_bundle->beam[col - 1] with some BLOB_CHOICEs in matrix[col, *]. // As the language model state is updated, pending entries are modified to // minimize duplication of work. It is important that during the update the // children are considered in the non-decreasing order of their column, since // this guarantees that all the parents would be up to date before an update // of a child is done. pending->init_to_size(word_res->ratings->dimension(), SegSearchPending()); // Search the ratings matrix for the initial best path. (*pending)[0].SetColumnClassified(); UpdateSegSearchNodes(rating_cert_scale, 0, pending, word_res, pain_points, best_choice_bundle, blamer_bundle); } void Wordrec::UpdateSegSearchNodes( float rating_cert_scale, int starting_col, GenericVector* pending, WERD_RES *word_res, LMPainPoints *pain_points, BestChoiceBundle *best_choice_bundle, BlamerBundle *blamer_bundle) { MATRIX *ratings = word_res->ratings; ASSERT_HOST(ratings->dimension() == pending->size()); ASSERT_HOST(ratings->dimension() == best_choice_bundle->beam.size()); for (int col = starting_col; col < ratings->dimension(); ++col) { if (!(*pending)[col].WorkToDo()) continue; int first_row = col; int last_row = MIN(ratings->dimension() - 1, col + ratings->bandwidth() - 1); if ((*pending)[col].SingleRow() >= 0) { first_row = last_row = (*pending)[col].SingleRow(); } if (segsearch_debug_level > 0) { tprintf("\n\nUpdateSegSearchNodes: col=%d, rows=[%d,%d], alljust=%d\n", col, first_row, last_row, (*pending)[col].IsRowJustClassified(INT32_MAX)); } // Iterate over the pending list for this column. for (int row = first_row; row <= last_row; ++row) { // Update language model state of this child+parent pair. BLOB_CHOICE_LIST *current_node = ratings->get(col, row); LanguageModelState *parent_node = col == 0 ? nullptr : best_choice_bundle->beam[col - 1]; if (current_node != nullptr && language_model_->UpdateState((*pending)[col].IsRowJustClassified(row), col, row, current_node, parent_node, pain_points, word_res, best_choice_bundle, blamer_bundle) && row + 1 < ratings->dimension()) { // Since the language model state of this entry changed, process all // the child column. (*pending)[row + 1].RevisitWholeColumn(); if (segsearch_debug_level > 0) { tprintf("Added child col=%d to pending\n", row + 1); } } // end if UpdateState. } // end for row. } // end for col. if (best_choice_bundle->best_vse != nullptr) { ASSERT_HOST(word_res->StatesAllValid()); if (best_choice_bundle->best_vse->updated) { pain_points->GenerateFromPath(rating_cert_scale, best_choice_bundle->best_vse, word_res); if (!best_choice_bundle->fixpt.empty()) { pain_points->GenerateFromAmbigs(best_choice_bundle->fixpt, best_choice_bundle->best_vse, word_res); } } } // The segsearch is completed. Reset all updated flags on all VSEs and reset // all pendings. for (int col = 0; col < pending->size(); ++col) { (*pending)[col].Clear(); ViterbiStateEntry_IT vse_it(&best_choice_bundle->beam[col]->viterbi_state_entries); for (vse_it.mark_cycle_pt(); !vse_it.cycled_list(); vse_it.forward()) { vse_it.data()->updated = false; } } } void Wordrec::ProcessSegSearchPainPoint( float pain_point_priority, const MATRIX_COORD &pain_point, const char* pain_point_type, GenericVector* pending, WERD_RES *word_res, LMPainPoints *pain_points, BlamerBundle *blamer_bundle) { if (segsearch_debug_level > 0) { tprintf("Classifying pain point %s priority=%.4f, col=%d, row=%d\n", pain_point_type, pain_point_priority, pain_point.col, pain_point.row); } ASSERT_HOST(pain_points != nullptr); MATRIX *ratings = word_res->ratings; // Classify blob [pain_point.col pain_point.row] if (!pain_point.Valid(*ratings)) { ratings->IncreaseBandSize(pain_point.row + 1 - pain_point.col); } ASSERT_HOST(pain_point.Valid(*ratings)); BLOB_CHOICE_LIST *classified = classify_piece(word_res->seam_array, pain_point.col, pain_point.row, pain_point_type, word_res->chopped_word, blamer_bundle); BLOB_CHOICE_LIST *lst = ratings->get(pain_point.col, pain_point.row); if (lst == nullptr) { ratings->put(pain_point.col, pain_point.row, classified); } else { // We can not delete old BLOB_CHOICEs, since they might contain // ViterbiStateEntries that are parents of other "active" entries. // Thus if the matrix cell already contains classifications we add // the new ones to the beginning of the list. BLOB_CHOICE_IT it(lst); it.add_list_before(classified); delete classified; // safe to delete, since empty after add_list_before() classified = nullptr; } if (segsearch_debug_level > 0) { print_ratings_list("Updated ratings matrix with a new entry:", ratings->get(pain_point.col, pain_point.row), getDict().getUnicharset()); ratings->print(getDict().getUnicharset()); } // Insert initial "pain points" to join the newly classified blob // with its left and right neighbors. if (classified != nullptr && !classified->empty()) { if (pain_point.col > 0) { pain_points->GeneratePainPoint( pain_point.col - 1, pain_point.row, LM_PPTYPE_SHAPE, 0.0, true, segsearch_max_char_wh_ratio, word_res); } if (pain_point.row + 1 < ratings->dimension()) { pain_points->GeneratePainPoint( pain_point.col, pain_point.row + 1, LM_PPTYPE_SHAPE, 0.0, true, segsearch_max_char_wh_ratio, word_res); } } (*pending)[pain_point.col].SetBlobClassified(pain_point.row); } // Resets enough of the results so that the Viterbi search is re-run. // Needed when the n-gram model is enabled, as the multi-length comparison // implementation will re-value existing paths to worse values. void Wordrec::ResetNGramSearch(WERD_RES* word_res, BestChoiceBundle* best_choice_bundle, GenericVector* pending) { // TODO(rays) More refactoring required here. // Delete existing viterbi states. for (int col = 0; col < best_choice_bundle->beam.size(); ++col) { best_choice_bundle->beam[col]->Clear(); } // Reset best_choice_bundle. word_res->ClearWordChoices(); best_choice_bundle->best_vse = nullptr; // Clear out all existing pendings and add a new one for the first column. (*pending)[0].SetColumnClassified(); for (int i = 1; i < pending->size(); ++i) (*pending)[i].Clear(); } void Wordrec::InitBlamerForSegSearch(WERD_RES *word_res, LMPainPoints *pain_points, BlamerBundle *blamer_bundle, STRING *blamer_debug) { pain_points->Clear(); // Clear pain points heap. TessResultCallback2* pp_cb = NewPermanentTessCallback( pain_points, &LMPainPoints::GenerateForBlamer, static_cast(segsearch_max_char_wh_ratio), word_res); blamer_bundle->InitForSegSearch(word_res->best_choice, word_res->ratings, getDict().WildcardID(), wordrec_debug_blamer, blamer_debug, pp_cb); delete pp_cb; } } // namespace tesseract