/********************************************************************** * File: beam_search.cpp * Description: Class to implement Beam Word Search Algorithm * Author: Ahmad Abdulkader * Created: 2007 * * (C) Copyright 2008, Google Inc. ** Licensed under the Apache License, Version 2.0 (the "License"); ** you may not use this file except in compliance with the License. ** You may obtain a copy of the License at ** http://www.apache.org/licenses/LICENSE-2.0 ** Unless required by applicable law or agreed to in writing, software ** distributed under the License is distributed on an "AS IS" BASIS, ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ** See the License for the specific language governing permissions and ** limitations under the License. * **********************************************************************/ #include #include "beam_search.h" #include "tesseractclass.h" namespace tesseract { BeamSearch::BeamSearch(CubeRecoContext *cntxt, bool word_mode) { cntxt_ = cntxt; seg_pt_cnt_ = 0; col_cnt_ = 1; col_ = NULL; word_mode_ = word_mode; } // Cleanup the lattice corresponding to the last search void BeamSearch::Cleanup() { if (col_ != NULL) { for (int col = 0; col < col_cnt_; col++) { if (col_[col]) delete col_[col]; } delete []col_; } col_ = NULL; } BeamSearch::~BeamSearch() { Cleanup(); } // Creates a set of children nodes emerging from a parent node based on // the character alternate list and the language model. 
void BeamSearch::CreateChildren(SearchColumn *out_col, LangModel *lang_mod, SearchNode *parent_node, LangModEdge *lm_parent_edge, CharAltList *char_alt_list, int extra_cost) { // get all the edges from this parent int edge_cnt; LangModEdge **lm_edges = lang_mod->GetEdges(char_alt_list, lm_parent_edge, &edge_cnt); if (lm_edges) { // add them to the ending column with the appropriate parent for (int edge = 0; edge < edge_cnt; edge++) { // add a node to the column if the current column is not the // last one, or if the lang model edge indicates it is valid EOW if (!cntxt_->NoisyInput() && out_col->ColIdx() >= seg_pt_cnt_ && !lm_edges[edge]->IsEOW()) { // free edge since no object is going to own it delete lm_edges[edge]; continue; } // compute the recognition cost of this node int recognition_cost = MIN_PROB_COST; if (char_alt_list && char_alt_list->AltCount() > 0) { recognition_cost = MAX(0, char_alt_list->ClassCost( lm_edges[edge]->ClassID())); // Add the no space cost. This should zero in word mode recognition_cost += extra_cost; } // Note that the edge will be freed inside the column if // AddNode is called if (recognition_cost >= 0) { out_col->AddNode(lm_edges[edge], recognition_cost, parent_node, cntxt_); } else { delete lm_edges[edge]; } } // edge // free edge array delete []lm_edges; } // lm_edges } // Performs a beam seach in the specified search using the specified // language model; returns an alternate list of possible words as a result. 
// srch_obj: supplies segmentation points, per-segment recognition results
//           and space / no-space costs.
// lang_mod: language model to search with; when NULL, cntxt_->LangMod()
//           is used instead.
// Returns NULL on failure. Otherwise returns the list built by
// CreateWordAltList(), which this object does not keep. The lattice (col_)
// is kept until the next Search() or destruction, so Column(), BestNode(),
// Alt() and BackTrack() can inspect it. On failure col_cnt_ never
// describes columns that were not allocated.
WordAltList * BeamSearch::Search(SearchObject *srch_obj, LangModel *lang_mod) {
  // Segment point counts above this are treated as malformed input.
  static const int kMaxSegPtCnt = 128;

  // verifications
  if (!lang_mod)
    lang_mod = cntxt_->LangMod();
  if (!lang_mod) {
    fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
            "LangModel\n");
    return NULL;
  }

  // free existing state
  Cleanup();

  // get seg pt count
  seg_pt_cnt_ = srch_obj->SegPtCnt();
  if (seg_pt_cnt_ < 0) {
    col_cnt_ = 0;  // no lattice: col_ is NULL after Cleanup()
    return NULL;
  }

  // disregard suspicious cases
  if (seg_pt_cnt_ > kMaxSegPtCnt) {
    fprintf(stderr, "Cube ERROR (BeamSearch::Search): segment point count is "
            "suspiciously high; bailing out\n");
    col_cnt_ = 0;  // no lattice: col_ is NULL after Cleanup()
    return NULL;
  }

  // One column per segment end; column i holds nodes ending at column i.
  col_cnt_ = seg_pt_cnt_ + 1;

  // alloc memory for columns
  col_ = new SearchColumn *[col_cnt_];
  if (!col_) {
    fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
            "SearchColumn array\n");
    col_cnt_ = 0;
    return NULL;
  }
  // Zero the array so Cleanup() can run safely if a later step fails.
  memset(col_, 0, col_cnt_ * sizeof(*col_));

  // for all possible segments
  for (int end_seg = 1; end_seg <= (seg_pt_cnt_ + 1); end_seg++) {
    // create a search column
    col_[end_seg - 1] = new SearchColumn(end_seg - 1,
                                         cntxt_->Params()->BeamWidth());
    if (!col_[end_seg - 1]) {
      fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
              "SearchColumn for column %d\n", end_seg - 1);
      return NULL;
    }

    // for all possible start segments
    int init_seg = MAX(0, end_seg - cntxt_->Params()->MaxSegPerChar());
    for (int strt_seg = init_seg; strt_seg < end_seg; strt_seg++) {
      int parent_nodes_cnt;
      SearchNode **parent_nodes;

      // for the root segment, we do not have a parent
      if (strt_seg == 0) {
        parent_nodes_cnt = 1;
        parent_nodes = NULL;
      } else {
        // for all the existing nodes in the starting column
        parent_nodes_cnt = col_[strt_seg - 1]->NodeCount();
        parent_nodes = col_[strt_seg - 1]->Nodes();
      }

      // run the shape recognizer
      CharAltList *char_alt_list = srch_obj->RecognizeSegment(strt_seg - 1,
                                                              end_seg - 1);

      // for all the possible parents
      for (int parent_idx = 0; parent_idx < parent_nodes_cnt; parent_idx++) {
        // point to the parent node
        SearchNode *parent_node = !parent_nodes ?
            NULL : parent_nodes[parent_idx];
        LangModEdge *lm_parent_edge = !parent_node ?
            lang_mod->Root() : parent_node->LangModelEdge();

        // compute the cost of not having spaces within the segment range
        int contig_cost = srch_obj->NoSpaceCost(strt_seg - 1, end_seg - 1);

        // In phrase mode, compute the cost of not having a space before
        // this character
        int no_space_cost = 0;
        if (!word_mode_ && strt_seg > 0) {
          no_space_cost = srch_obj->NoSpaceCost(strt_seg - 1);
        }

        // if the no space cost is low enough
        if ((contig_cost + no_space_cost) < MIN_PROB_COST) {
          // Add the children nodes
          CreateChildren(col_[end_seg - 1], lang_mod, parent_node,
                         lm_parent_edge, char_alt_list,
                         contig_cost + no_space_cost);
        }

        // In phrase mode and if not starting at the root
        if (!word_mode_ && strt_seg > 0) {  // parent_node must be non-NULL
          // consider starting a new word for nodes that are valid EOW
          if (parent_node->LangModelEdge()->IsEOW()) {
            // get the space cost
            int space_cost = srch_obj->SpaceCost(strt_seg - 1);
            // if the space cost is low enough
            if ((contig_cost + space_cost) < MIN_PROB_COST) {
              // Restart the language model and add nodes as children to the
              // space node.
              CreateChildren(col_[end_seg - 1], lang_mod, parent_node, NULL,
                             char_alt_list, contig_cost + space_cost);
            }
          }
        }
      }  // parent
    }  // strt_seg

    // prune the column nodes
    col_[end_seg - 1]->Prune();

    // Free the column hash table. No longer needed
    col_[end_seg - 1]->FreeHashTable();
  }  // end_seg

  WordAltList *alt_list = CreateWordAltList(srch_obj);
  return alt_list;
}

// Creates a Word alternate list from the results in the lattice.
// One alternate is added for each last-column node whose path string can
// be backtracked. Its cost is a weighted sum of size, char-bigram,
// word-unigram and recognition costs, with weights from cntxt_->Params().
// Before sorting, records in best_presorted_node_idx_ the index (into the
// last column's node array) of the lowest-cost alternate.
// Returns NULL when there is no lattice or the last column is empty.
// Otherwise returns a list allocated with new.
WordAltList *BeamSearch::CreateWordAltList(SearchObject *srch_obj) {
  // Save the index of the best-cost node before the alt list is
  // sorted, so that we can retrieve it from the node list when backtracking.
  best_presorted_node_idx_ = 0;
  if (!col_ || col_cnt_ < 1 || !col_[col_cnt_ - 1])
    return NULL;

  // create an alternate list of all the nodes in the last column
  int node_cnt = col_[col_cnt_ - 1]->NodeCount();
  SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
  CharBigrams *bigrams = cntxt_->Bigrams();
  WordUnigrams *word_unigrams = cntxt_->WordUnigramsObj();
  if (node_cnt <= 0)
    return NULL;

  // start creating the word alternate list
  WordAltList *alt_list = new WordAltList(node_cnt + 1);
  bool have_best = false;
  int best_cost = 0;
  for (int node_idx = 0; node_idx < node_cnt; node_idx++) {
    SearchNode *srch_node = srch_nodes[node_idx];
    // recognition cost
    int recognition_cost = srch_node->BestCost();
    // compute the size cost of the alternate; also yields the path string
    char_32 *ch_buff = NULL;
    int size_cost = SizeCost(srch_obj, srch_node, &ch_buff);
    if (!ch_buff)
      continue;  // no string for this path; skip it

    // char bigram cost
    int bigram_cost = !bigrams ? 0 :
        bigrams->Cost(ch_buff, cntxt_->CharacterSet());
    // word unigram cost
    int unigram_cost = !word_unigrams ? 0 :
        word_unigrams->Cost(ch_buff, cntxt_->LangMod(),
                            cntxt_->CharacterSet());
    // overall cost
    int cost = static_cast<int>(
        (size_cost * cntxt_->Params()->SizeWgt()) +
        (bigram_cost * cntxt_->Params()->CharBigramWgt()) +
        (unigram_cost * cntxt_->Params()->WordUnigramWgt()) +
        (recognition_cost * cntxt_->Params()->RecoWgt()));

    // insert into word alt list
    alt_list->Insert(ch_buff, cost, static_cast<void *>(srch_node));
    // Note that strict < is necessary because WordAltList::Sort()
    // uses it in a bubble sort to swap entries.
    if (!have_best || cost < best_cost) {
      best_presorted_node_idx_ = node_idx;
      best_cost = cost;
      have_best = true;
    }
    delete []ch_buff;
  }

  // sort the alternates based on cost
  alt_list->Sort();
  return alt_list;
}

// Returns the lattice column corresponding to the specified column index.
SearchColumn *BeamSearch::Column(int col) const { if (col < 0 || col >= col_cnt_ || !col_) return NULL; return col_[col]; } // Returns the best node in the last column of last performed search. SearchNode *BeamSearch::BestNode() const { if (col_cnt_ < 1 || !col_ || !col_[col_cnt_ - 1]) return NULL; int node_cnt = col_[col_cnt_ - 1]->NodeCount(); SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes(); if (node_cnt < 1 || !srch_nodes || !srch_nodes[0]) return NULL; return srch_nodes[0]; } // Returns the string corresponding to the specified alt. char_32 *BeamSearch::Alt(int alt) const { // get the last column of the lattice if (col_cnt_ <= 0) return NULL; SearchColumn *srch_col = col_[col_cnt_ - 1]; if (!srch_col) return NULL; // point to the last node in the selected path if (alt >= srch_col->NodeCount() || srch_col->Nodes() == NULL) { return NULL; } SearchNode *srch_node = srch_col->Nodes()[alt]; if (!srch_node) return NULL; // get string char_32 *str32 = srch_node->PathString(); if (!str32) return NULL; return str32; } // Backtracks from the specified node index and returns the corresponding // character mapped segments and character count. Optional return // arguments are the char_32 result string and character bounding // boxes, if non-NULL values are passed in. CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, int node_index, int *char_cnt, char_32 **str32, Boxa **char_boxes) const { // get the last column of the lattice if (col_cnt_ <= 0) return NULL; SearchColumn *srch_col = col_[col_cnt_ - 1]; if (!srch_col) return NULL; // point to the last node in the selected path if (node_index >= srch_col->NodeCount() || !srch_col->Nodes()) return NULL; SearchNode *srch_node = srch_col->Nodes()[node_index]; if (!srch_node) return NULL; return BackTrack(srch_obj, srch_node, char_cnt, str32, char_boxes); } // Backtracks from the specified node index and returns the corresponding // character mapped segments and character count. 
Optional return // arguments are the char_32 result string and character bounding // boxes, if non-NULL values are passed in. CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, SearchNode *srch_node, int *char_cnt, char_32 **str32, Boxa **char_boxes) const { if (!srch_node) return NULL; if (str32) { if (*str32) delete [](*str32); // clear existing value *str32 = srch_node->PathString(); if (!*str32) return NULL; } if (char_boxes && *char_boxes) { boxaDestroy(char_boxes); // clear existing value } CharSamp **chars; chars = SplitByNode(srch_obj, srch_node, char_cnt, char_boxes); if (!chars && str32) delete []*str32; return chars; } // Backtracks from the given lattice node and return the corresponding // char mapped segments and character count. The character bounding // boxes are optional return arguments, if non-NULL values are passed in. CharSamp **BeamSearch::SplitByNode(SearchObject *srch_obj, SearchNode *srch_node, int *char_cnt, Boxa **char_boxes) const { // Count the characters (could be less than the path length when in // phrase mode) *char_cnt = 0; SearchNode *node = srch_node; while (node) { node = node->ParentNode(); (*char_cnt)++; } if (*char_cnt == 0) return NULL; // Allocate box array if (char_boxes) { if (*char_boxes) boxaDestroy(char_boxes); // clear existing value *char_boxes = boxaCreate(*char_cnt); if (*char_boxes == NULL) return NULL; } // Allocate memory for CharSamp array. CharSamp **chars = new CharSamp *[*char_cnt]; if (!chars) { if (char_boxes) boxaDestroy(char_boxes); return NULL; } int ch_idx = *char_cnt - 1; int seg_pt_cnt = srch_obj->SegPtCnt(); bool success=true; while (srch_node && ch_idx >= 0) { // Parent node (could be null) SearchNode *parent_node = srch_node->ParentNode(); // Get the seg pts corresponding to the search node int st_col = !parent_node ? 0 : parent_node->ColIdx() + 1; int st_seg_pt = st_col <= 0 ? -1 : st_col - 1; int end_col = srch_node->ColIdx(); int end_seg_pt = end_col >= seg_pt_cnt ? 
seg_pt_cnt : end_col; // Get a char sample corresponding to the segmentation points CharSamp *samp = srch_obj->CharSample(st_seg_pt, end_seg_pt); if (!samp) { success = false; break; } samp->SetLabel(srch_node->NodeString()); chars[ch_idx] = samp; if (char_boxes) { // Create the corresponding character bounding box Box *char_box = boxCreate(samp->Left(), samp->Top(), samp->Width(), samp->Height()); if (!char_box) { success = false; break; } boxaAddBox(*char_boxes, char_box, L_INSERT); } srch_node = parent_node; ch_idx--; } if (!success) { delete []chars; if (char_boxes) boxaDestroy(char_boxes); return NULL; } // Reverse the order of boxes. if (char_boxes) { int char_boxa_size = boxaGetCount(*char_boxes); int limit = char_boxa_size / 2; for (int i = 0; i < limit; ++i) { int box1_idx = i; int box2_idx = char_boxa_size - 1 - i; Box *box1 = boxaGetBox(*char_boxes, box1_idx, L_CLONE); Box *box2 = boxaGetBox(*char_boxes, box2_idx, L_CLONE); boxaReplaceBox(*char_boxes, box2_idx, box1); boxaReplaceBox(*char_boxes, box1_idx, box2); } } return chars; } // Returns the size cost of a string for a lattice path that // ends at the specified lattice node. int BeamSearch::SizeCost(SearchObject *srch_obj, SearchNode *node, char_32 **str32) const { CharSamp **chars = NULL; int char_cnt = 0; if (!node) return 0; // Backtrack to get string and character segmentation chars = BackTrack(srch_obj, node, &char_cnt, str32, NULL); if (!chars) return WORST_COST; int size_cost = (cntxt_->SizeModel() == NULL) ? 0 : cntxt_->SizeModel()->Cost(chars, char_cnt); delete []chars; return size_cost; } } // namespace tesesract