mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-01-07 10:17:50 +08:00
4523ce9f7d
git-svn-id: https://tesseract-ocr.googlecode.com/svn/trunk@526 d0cd1f9f-072b-0410-8dd7-cf729c803f20
1738 lines
74 KiB
C++
1738 lines
74 KiB
C++
///////////////////////////////////////////////////////////////////////
|
|
// File: language_model.cpp
|
|
// Description: Functions that utilize the knowledge about the properties,
|
|
// structure and statistics of the language to help recognition.
|
|
// Author: Daria Antonova
|
|
// Created: Mon Nov 11 11:26:43 PST 2009
|
|
//
|
|
// (C) Copyright 2009, Google Inc.
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
#include <math.h>
|
|
|
|
#include "language_model.h"
|
|
|
|
#include "dawg.h"
|
|
#include "matrix.h"
|
|
#include "params.h"
|
|
|
|
namespace tesseract {
|
|
|
|
ELISTIZE(ViterbiStateEntry);
|
|
|
|
const float LanguageModel::kInitialPainPointPriorityAdjustment = 5.0f;
|
|
const float LanguageModel::kDefaultPainPointPriorityAdjustment = 2.0f;
|
|
const float LanguageModel::kBestChoicePainPointPriorityAdjustment = 0.5f;
|
|
const float LanguageModel::kCriticalPainPointPriorityAdjustment = 0.1f;
|
|
const float LanguageModel::kMaxAvgNgramCost = 25.0f;
|
|
const int LanguageModel::kMinFixedLengthDawgLength = 2;
|
|
const float LanguageModel::kLooseMaxCharWhRatio = 2.5f;
|
|
|
|
LanguageModel::LanguageModel(Dict *dict,
|
|
WERD_CHOICE **prev_word_best_choice_ptr)
|
|
: INT_MEMBER(language_model_debug_level, 0, "Language model debug level",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
BOOL_INIT_MEMBER(language_model_ngram_on, false,
|
|
"Turn on/off the use of character ngram model",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
INT_INIT_MEMBER(language_model_ngram_order, 8,
|
|
"Maximum order of the character ngram model",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
INT_INIT_MEMBER(language_model_max_viterbi_list_size, 10,
|
|
"Maximum size of viterbi lists recorded in BLOB_CHOICEs"
|
|
"(excluding entries that represent dictionary word paths)",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_ngram_small_prob, 0.000001,
|
|
"To avoid overly small denominators use this as the "
|
|
"floor of the probability returned by the ngram model.",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_ngram_nonmatch_score, -40.0,
|
|
"Average classifier score of a non-matching unichar.",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
BOOL_INIT_MEMBER(language_model_ngram_use_only_first_uft8_step, false,
|
|
"Use only the first UTF8 step of the given string"
|
|
" when computing log probabilities.",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_ngram_scale_factor, 0.03,
|
|
"Strength of the character ngram model relative to the"
|
|
" character classifier ",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
INT_INIT_MEMBER(language_model_min_compound_length, 3,
|
|
"Minimum length of compound words",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
INT_INIT_MEMBER(language_model_fixed_length_choices_depth, 3,
|
|
"Depth of blob choice lists to explore"
|
|
" when fixed length dawgs are on",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_non_freq_dict_word, 0.1,
|
|
"Penalty for words not in the frequent word dictionary",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_non_dict_word, 0.15,
|
|
"Penalty for non-dictionary words",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_punc, 0.2,
|
|
"Penalty for inconsistent punctuation",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_case, 0.1,
|
|
"Penalty for inconsistent case",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_script, 0.5,
|
|
"Penalty for inconsistent script",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_chartype, 0.3,
|
|
"Penalty for inconsistent character type",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
double_INIT_MEMBER(language_model_penalty_increment, 0.01,
|
|
"Penalty increment",
|
|
dict->getImage()->getCCUtil()->params()),
|
|
dict_(dict), denorm_(NULL), fixed_pitch_(false),
|
|
max_char_wh_ratio_(0.0), acceptable_choice_found_(false) {
|
|
ASSERT_HOST(dict_ != NULL);
|
|
dawg_args_ = new DawgArgs(NULL, NULL, new DawgInfoVector(),
|
|
new DawgInfoVector(),
|
|
0.0, NO_PERM, kAnyWordLength, -1);
|
|
beginning_active_dawgs_ = new DawgInfoVector();
|
|
beginning_constraints_ = new DawgInfoVector();
|
|
fixed_length_beginning_active_dawgs_ = new DawgInfoVector();
|
|
empty_dawg_info_vec_ = new DawgInfoVector();
|
|
}
|
|
|
|
LanguageModel::~LanguageModel() {
|
|
delete beginning_active_dawgs_;
|
|
delete beginning_constraints_;
|
|
delete fixed_length_beginning_active_dawgs_;
|
|
delete empty_dawg_info_vec_;
|
|
delete dawg_args_->updated_active_dawgs;
|
|
delete dawg_args_->updated_constraints;
|
|
delete dawg_args_;
|
|
}
|
|
|
|
void LanguageModel::InitForWord(
|
|
const WERD_CHOICE *prev_word, const DENORM *denorm,
|
|
bool fixed_pitch, float best_choice_cert, float max_char_wh_ratio,
|
|
HEAP *pain_points, CHUNKS_RECORD *chunks_record) {
|
|
denorm_ = denorm;
|
|
fixed_pitch_ = fixed_pitch;
|
|
max_char_wh_ratio_ = max_char_wh_ratio;
|
|
acceptable_choice_found_ = false;
|
|
|
|
// For each cell, generate a "pain point" if the cell is not classified
|
|
// and has a left or right neighbor that was classified.
|
|
MATRIX *ratings = chunks_record->ratings;
|
|
for (int col = 0; col < ratings->dimension(); ++col) {
|
|
for (int row = col+1; row < ratings->dimension(); ++row) {
|
|
if ((row > 0 && ratings->get(col, row-1) != NOT_CLASSIFIED) ||
|
|
(col+1 < ratings->dimension() &&
|
|
ratings->get(col+1, row) != NOT_CLASSIFIED)) {
|
|
float worst_piece_cert;
|
|
bool fragmented;
|
|
GetWorstPieceCertainty(col, row, chunks_record->ratings,
|
|
&worst_piece_cert, &fragmented);
|
|
GeneratePainPoint(col, row, true, kInitialPainPointPriorityAdjustment,
|
|
worst_piece_cert, fragmented, best_choice_cert,
|
|
max_char_wh_ratio_, NULL, NULL,
|
|
chunks_record, pain_points);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Initialize vectors with beginning DawgInfos.
|
|
beginning_active_dawgs_->clear();
|
|
dict_->init_active_dawgs(kAnyWordLength, beginning_active_dawgs_, false);
|
|
beginning_constraints_->clear();
|
|
dict_->init_constraints(beginning_constraints_);
|
|
if (dict_->GetMaxFixedLengthDawgIndex() >= 0) {
|
|
fixed_length_beginning_active_dawgs_->clear();
|
|
for (int i = 0; i < beginning_active_dawgs_->size(); ++i) {
|
|
int dawg_index = (*beginning_active_dawgs_)[i].dawg_index;
|
|
if (dawg_index <= dict_->GetMaxFixedLengthDawgIndex() &&
|
|
dawg_index >= kMinFixedLengthDawgLength) {
|
|
*fixed_length_beginning_active_dawgs_ += (*beginning_active_dawgs_)[i];
|
|
}
|
|
}
|
|
}
|
|
|
|
max_penalty_adjust_ = (dict_->segment_penalty_dict_nonword -
|
|
dict_->segment_penalty_dict_case_ok);
|
|
|
|
// The rest of the function contains ngram-model specific initialization.
|
|
if (!language_model_ngram_on) return;
|
|
|
|
// Fill prev_word_str_ with the last language_model_ngram_order
|
|
// unichars from prev_word.
|
|
// Assume that populate_unichars() has been called on a valid
|
|
// prev_word_best_choice_, which is the case, since it points
|
|
// to the final result of the classification of the previous word.
|
|
if (prev_word != NULL && prev_word->unichar_string() != NULL) {
|
|
prev_word_str_ = prev_word->unichar_string();
|
|
}
|
|
prev_word_str_ += ' ';
|
|
const char *str_ptr = prev_word_str_.string();
|
|
const char *str_end = str_ptr + prev_word_str_.length();
|
|
int step;
|
|
prev_word_unichar_step_len_ = 0;
|
|
while (str_ptr != str_end && (step = UNICHAR::utf8_step(str_ptr))) {
|
|
str_ptr += step;
|
|
++prev_word_unichar_step_len_;
|
|
}
|
|
ASSERT_HOST(str_ptr == str_end);
|
|
}
|
|
|
|
void LanguageModel::CleanUp() {
|
|
for (int i = 0; i < updated_flags_.size(); ++i) *(updated_flags_[i]) = false;
|
|
updated_flags_.clear();
|
|
}
|
|
|
|
void LanguageModel::DeleteState(BLOB_CHOICE_LIST *choices) {
|
|
BLOB_CHOICE_IT b_it(choices);
|
|
for (b_it.mark_cycle_pt(); !b_it.cycled_list(); b_it.forward()) {
|
|
if (b_it.data()->language_model_state() != NULL) {
|
|
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(
|
|
b_it.data()->language_model_state());
|
|
delete state;
|
|
b_it.data()->set_language_model_state(NULL);
|
|
}
|
|
}
|
|
}
|
|
|
|
LanguageModelFlagsType LanguageModel::UpdateState(
|
|
LanguageModelFlagsType changed,
|
|
int curr_col, int curr_row,
|
|
BLOB_CHOICE_LIST *curr_list,
|
|
BLOB_CHOICE_LIST *parent_list,
|
|
HEAP *pain_points,
|
|
BestPathByColumn *best_path_by_column[],
|
|
CHUNKS_RECORD *chunks_record,
|
|
BestChoiceBundle *best_choice_bundle) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("\nUpdateState: col=%d row=%d (changed=0x%x parent=%p)\n",
|
|
curr_col, curr_row, changed, parent_list);
|
|
}
|
|
// Initialize helper variables.
|
|
bool word_end = (curr_row+1 >= chunks_record->ratings->dimension());
|
|
bool just_classified = (changed & kJustClassifiedFlag);
|
|
LanguageModelFlagsType new_changed = 0x0;
|
|
float denom = (language_model_ngram_on) ? ComputeDenom(curr_list) : 1.0f;
|
|
|
|
// Call AddViterbiStateEntry() for each parent+child ViterbiStateEntry.
|
|
ViterbiStateEntry_IT vit;
|
|
BLOB_CHOICE_IT c_it(curr_list);
|
|
int c_it_counter = 0;
|
|
bool first_iteration = true;
|
|
BLOB_CHOICE *first_lower = NULL;
|
|
BLOB_CHOICE *first_upper = NULL;
|
|
GetTopChoiceLowerUpper(changed, curr_list, &first_lower, &first_upper);
|
|
for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
|
|
if (dict_->GetMaxFixedLengthDawgIndex() >= 0 &&
|
|
c_it_counter++ >= language_model_fixed_length_choices_depth) {
|
|
break;
|
|
}
|
|
// Skip NULL unichars unless it is the only choice.
|
|
if (!curr_list->singleton() && c_it.data()->unichar_id() == 0) continue;
|
|
if (dict_->getUnicharset().get_fragment(c_it.data()->unichar_id())) {
|
|
continue; // skip fragments
|
|
}
|
|
// Set top choice flags.
|
|
LanguageModelFlagsType top_choice_flags = 0x0;
|
|
if (first_iteration && (changed | kSmallestRatingFlag)) {
|
|
top_choice_flags |= kSmallestRatingFlag;
|
|
}
|
|
if (first_lower == c_it.data()) top_choice_flags |= kLowerCaseFlag;
|
|
if (first_upper == c_it.data()) top_choice_flags |= kUpperCaseFlag;
|
|
|
|
if (parent_list == NULL) { // process the beginning of a word
|
|
new_changed |= AddViterbiStateEntry(
|
|
top_choice_flags, denom, word_end, curr_col, curr_row, c_it.data(),
|
|
NULL, NULL, pain_points, best_path_by_column,
|
|
chunks_record, best_choice_bundle);
|
|
} else { // get viterbi entries from each of the parent BLOB_CHOICEs
|
|
BLOB_CHOICE_IT p_it(parent_list);
|
|
for (p_it.mark_cycle_pt(); !p_it.cycled_list(); p_it.forward()) {
|
|
LanguageModelState *parent_lms =
|
|
reinterpret_cast<LanguageModelState *>(
|
|
p_it.data()->language_model_state());
|
|
if (parent_lms == NULL || parent_lms->viterbi_state_entries.empty()) {
|
|
continue;
|
|
}
|
|
vit.set_to_list(&(parent_lms->viterbi_state_entries));
|
|
int vit_counter = 0;
|
|
for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
|
|
// Skip pruned entries and do not look at prunable entries if we have
|
|
// already examined language_model_max_viterbi_list_size of them.
|
|
if (PrunablePath(vit.data()->top_choice_flags,
|
|
vit.data()->dawg_info) &&
|
|
(++vit_counter > language_model_max_viterbi_list_size ||
|
|
(language_model_ngram_on && vit.data()->ngram_info->pruned))) {
|
|
continue;
|
|
}
|
|
// Only consider the parent if it has been updated or
|
|
// if the current ratings cell has just been classified.
|
|
if (!just_classified && !vit.data()->updated) continue;
|
|
// Create a new ViterbiStateEntry if BLOB_CHOICE in c_it.data()
|
|
// looks good according to the Dawgs or character ngram model.
|
|
new_changed |= AddViterbiStateEntry(
|
|
top_choice_flags, denom, word_end, curr_col, curr_row,
|
|
c_it.data(), p_it.data(), vit.data(), pain_points,
|
|
best_path_by_column, chunks_record, best_choice_bundle);
|
|
}
|
|
} // done looking at parents for this c_it.data()
|
|
}
|
|
first_iteration = false;
|
|
}
|
|
return new_changed;
|
|
}
|
|
|
|
bool LanguageModel::ProblematicPath(const ViterbiStateEntry &vse,
|
|
UNICHAR_ID unichar_id, bool word_end) {
|
|
// The path is problematic if it is inconsistent and has a parent that
|
|
// is consistent (or a NULL parent).
|
|
if (!vse.Consistent() && (vse.parent_vse == NULL ||
|
|
vse.parent_vse->Consistent())) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("ProblematicPath: inconsistent\n");
|
|
}
|
|
return true;
|
|
}
|
|
// The path is problematic if it does not represent a dicionary word,
|
|
// while its parent does.
|
|
if (vse.dawg_info == NULL &&
|
|
(vse.parent_vse == NULL || vse.parent_vse->dawg_info != NULL)) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("ProblematicPath: dict word terminated\n");
|
|
}
|
|
return true;
|
|
}
|
|
// The path is problematic if ngram info indicates that this path is
|
|
// an unlikely sequence of character, while its parent is does not have
|
|
// extreme dips in ngram probabilities.
|
|
if (vse.ngram_info != NULL && vse.ngram_info->pruned &&
|
|
(vse.parent_vse == NULL || !vse.parent_vse->ngram_info->pruned)) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("ProblematicPath: small ngram prob\n");
|
|
}
|
|
return true;
|
|
}
|
|
// The path is problematic if there is a non-alpha character in the
|
|
// middle of the path (unless it is a digit in the middle of a path
|
|
// that represents a number).
|
|
if ((vse.parent_vse != NULL) && !word_end && // is middle
|
|
!(dict_->getUnicharset().get_isalpha(unichar_id) || // alpha
|
|
(dict_->getUnicharset().get_isdigit(unichar_id) && // ok digit
|
|
vse.dawg_info != NULL && vse.dawg_info->permuter == NUMBER_PERM))) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("ProblematicPath: non-alpha middle\n");
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void LanguageModel::GetTopChoiceLowerUpper(LanguageModelFlagsType changed,
|
|
BLOB_CHOICE_LIST *curr_list,
|
|
BLOB_CHOICE **first_lower,
|
|
BLOB_CHOICE **first_upper) {
|
|
if (!(changed & kLowerCaseFlag || changed & kUpperCaseFlag)) return;
|
|
BLOB_CHOICE_IT c_it(curr_list);
|
|
const UNICHARSET &unicharset = dict_->getUnicharset();
|
|
BLOB_CHOICE *first_unichar = NULL;
|
|
for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
|
|
UNICHAR_ID unichar_id = c_it.data()->unichar_id();
|
|
if (unicharset.get_fragment(unichar_id)) continue; // skip fragments
|
|
if (first_unichar == NULL) first_unichar = c_it.data();
|
|
if (*first_lower == NULL && unicharset.get_islower(unichar_id)) {
|
|
*first_lower = c_it.data();
|
|
}
|
|
if (*first_upper == NULL && unicharset.get_isupper(unichar_id)) {
|
|
*first_upper = c_it.data();
|
|
}
|
|
}
|
|
ASSERT_HOST(first_unichar != NULL);
|
|
if (*first_lower == NULL) *first_lower = first_unichar;
|
|
if (*first_upper == NULL) *first_upper = first_unichar;
|
|
}
|
|
|
|
LanguageModelFlagsType LanguageModel::AddViterbiStateEntry(
|
|
LanguageModelFlagsType top_choice_flags,
|
|
float denom,
|
|
bool word_end,
|
|
int curr_col, int curr_row,
|
|
BLOB_CHOICE *b,
|
|
BLOB_CHOICE *parent_b,
|
|
ViterbiStateEntry *parent_vse,
|
|
HEAP *pain_points,
|
|
BestPathByColumn *best_path_by_column[],
|
|
CHUNKS_RECORD *chunks_record,
|
|
BestChoiceBundle *best_choice_bundle) {
|
|
ViterbiStateEntry_IT vit;
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("\nAddViterbiStateEntry for unichar %s rating=%.4f"
|
|
" certainty=%.4f top_choice_flags=0x%x parent_vse=%p\n",
|
|
dict_->getUnicharset().id_to_unichar(b->unichar_id()),
|
|
b->rating(), b->certainty(), top_choice_flags, parent_vse);
|
|
if (language_model_debug_level > 3 && b->language_model_state() != NULL) {
|
|
tprintf("Existing viterbi list:\n");
|
|
vit.set_to_list(&(reinterpret_cast<LanguageModelState *>(
|
|
b->language_model_state())->viterbi_state_entries));
|
|
for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
|
|
PrintViterbiStateEntry("", vit.data(), b, chunks_record);
|
|
}
|
|
}
|
|
}
|
|
LanguageModelFlagsType changed = 0x0;
|
|
float optimistic_cost = 0.0f;
|
|
if (!language_model_ngram_on) optimistic_cost += b->rating();
|
|
if (parent_vse != NULL) optimistic_cost += parent_vse->cost;
|
|
// Discard this entry if it will not beat best choice.
|
|
if (optimistic_cost >= best_choice_bundle->best_choice->rating()) {
|
|
if (language_model_debug_level > 1) {
|
|
tprintf("Discarded ViterbiEntry with high cost %.4f"
|
|
" best_choice->rating()=%.4f\n", optimistic_cost,
|
|
best_choice_bundle->best_choice->rating());
|
|
}
|
|
return 0x0;
|
|
}
|
|
|
|
// Check consistency of the path and set the relevant consistency_info.
|
|
LanguageModelConsistencyInfo consistency_info;
|
|
FillConsistencyInfo(word_end, b->unichar_id(), parent_vse, parent_b,
|
|
&consistency_info);
|
|
|
|
// Invoke Dawg language model component.
|
|
LanguageModelDawgInfo *dawg_info =
|
|
GenerateDawgInfo(word_end, consistency_info.script_id,
|
|
curr_col, curr_row, *b, parent_vse, &changed);
|
|
|
|
// Invoke TopChoice language model component
|
|
float ratings_sum = b->rating();
|
|
if (parent_vse != NULL) ratings_sum += parent_vse->ratings_sum;
|
|
GenerateTopChoiceInfo(ratings_sum, dawg_info, consistency_info,
|
|
parent_vse, b, &top_choice_flags, &changed);
|
|
|
|
// Invoke Ngram language model component.
|
|
LanguageModelNgramInfo *ngram_info = NULL;
|
|
if (language_model_ngram_on) {
|
|
ngram_info = GenerateNgramInfo(
|
|
dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->certainty(),
|
|
denom, curr_col, curr_row, parent_vse, parent_b, &changed);
|
|
ASSERT_HOST(ngram_info != NULL);
|
|
}
|
|
|
|
// Prune non-top choice paths with incosistent scripts.
|
|
if (consistency_info.inconsistent_script) {
|
|
if (!(top_choice_flags & kSmallestRatingFlag)) changed = 0x0;
|
|
if (ngram_info != NULL) ngram_info->pruned = true;
|
|
}
|
|
|
|
// If language model components did not like this unichar - return
|
|
if (!changed) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Language model components did not like this entry\n");
|
|
}
|
|
delete dawg_info;
|
|
delete ngram_info;
|
|
return 0x0;
|
|
}
|
|
|
|
// Compute cost of associating the blobs that represent the current unichar.
|
|
AssociateStats associate_stats;
|
|
ComputeAssociateStats(curr_col, curr_row, max_char_wh_ratio_,
|
|
parent_vse, chunks_record, &associate_stats);
|
|
if (parent_vse != NULL) {
|
|
associate_stats.shape_cost += parent_vse->associate_stats.shape_cost;
|
|
associate_stats.bad_shape |= parent_vse->associate_stats.bad_shape;
|
|
}
|
|
|
|
// Compute the aggregate cost (adjusted ratings sum).
|
|
float cost = ComputeAdjustedPathCost(
|
|
ratings_sum,
|
|
(parent_vse == NULL) ? 1 : (parent_vse->length+1),
|
|
(dawg_info == NULL) ? 0.0f : 1.0f,
|
|
dawg_info, ngram_info, consistency_info, associate_stats, parent_vse);
|
|
|
|
if (b->language_model_state() == NULL) {
|
|
b->set_language_model_state(new LanguageModelState(curr_col, curr_row));
|
|
}
|
|
LanguageModelState *lms =
|
|
reinterpret_cast<LanguageModelState *>(b->language_model_state());
|
|
|
|
// Discard this entry if it represents a prunable path and
|
|
// language_model_max_viterbi_list_size such entries with a lower
|
|
// cost have already been recorded.
|
|
if (PrunablePath(top_choice_flags, dawg_info) &&
|
|
(lms->viterbi_state_entries_prunable_length >=
|
|
language_model_max_viterbi_list_size) &&
|
|
cost >= lms->viterbi_state_entries_prunable_max_cost) {
|
|
if (language_model_debug_level > 1) {
|
|
tprintf("Discarded ViterbiEntry with high cost %g max cost %g\n",
|
|
cost, lms->viterbi_state_entries_prunable_max_cost);
|
|
}
|
|
delete dawg_info;
|
|
delete ngram_info;
|
|
return 0x0;
|
|
}
|
|
|
|
// Create the new ViterbiStateEntry and add it to lms->viterbi_state_entries
|
|
ViterbiStateEntry *new_vse = new ViterbiStateEntry(
|
|
parent_b, parent_vse, b, cost, consistency_info,
|
|
associate_stats, top_choice_flags, dawg_info, ngram_info);
|
|
updated_flags_.push_back(&(new_vse->updated));
|
|
lms->viterbi_state_entries.add_sorted(ViterbiStateEntry::Compare,
|
|
false, new_vse);
|
|
if (PrunablePath(top_choice_flags, dawg_info)) {
|
|
lms->viterbi_state_entries_prunable_length++;
|
|
}
|
|
|
|
// Update lms->viterbi_state_entries_prunable_max_cost and clear
|
|
// top_choice_flags of enties with ratings_sum than new_vse->ratings_sum.
|
|
if ((lms->viterbi_state_entries_prunable_length >=
|
|
language_model_max_viterbi_list_size) || top_choice_flags) {
|
|
ASSERT_HOST(!lms->viterbi_state_entries.empty());
|
|
int prunable_counter = language_model_max_viterbi_list_size;
|
|
vit.set_to_list(&(lms->viterbi_state_entries));
|
|
for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
|
|
ViterbiStateEntry *curr_vse = vit.data();
|
|
// Clear the appropriate top choice flags of the entries in the
|
|
// list that have ratings_sum higher thank new_entry->ratings_sum
|
|
// (since they will not be top choices any more).
|
|
if (curr_vse->top_choice_flags && curr_vse != new_vse &&
|
|
ComputeConsistencyAdjustedRatingsSum(
|
|
curr_vse->ratings_sum, curr_vse->dawg_info,
|
|
curr_vse->consistency_info) >
|
|
ComputeConsistencyAdjustedRatingsSum(
|
|
new_vse->ratings_sum, new_vse->dawg_info,
|
|
new_vse->consistency_info)) {
|
|
curr_vse->top_choice_flags &= ~(top_choice_flags);
|
|
}
|
|
if (prunable_counter > 0 &&
|
|
PrunablePath(curr_vse->top_choice_flags, curr_vse->dawg_info)) {
|
|
--prunable_counter;
|
|
}
|
|
// Update lms->viterbi_state_entries_prunable_max_cost.
|
|
if (prunable_counter == 0) {
|
|
lms->viterbi_state_entries_prunable_max_cost = vit.data()->cost;
|
|
if (language_model_debug_level > 1) {
|
|
tprintf("Set viterbi_state_entries_prunable_max_cost to %.4f\n",
|
|
lms->viterbi_state_entries_prunable_max_cost);
|
|
}
|
|
prunable_counter = -1; // stop counting
|
|
}
|
|
}
|
|
}
|
|
|
|
// Print the newly created ViterbiStateEntry.
|
|
if (language_model_debug_level > 2) {
|
|
PrintViterbiStateEntry("New", new_vse, b, chunks_record);
|
|
if (language_model_debug_level > 3) {
|
|
tprintf("Updated viterbi list (max cost %g):\n",
|
|
lms->viterbi_state_entries_prunable_max_cost);
|
|
vit.set_to_list(&(lms->viterbi_state_entries));
|
|
for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
|
|
PrintViterbiStateEntry("", vit.data(), b, chunks_record);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Update best choice is needed.
|
|
if (word_end) {
|
|
UpdateBestChoice(b, new_vse, pain_points,
|
|
chunks_record, best_choice_bundle);
|
|
}
|
|
|
|
// Update stats in best_path_by_column.
|
|
if (new_vse->Consistent() || new_vse->dawg_info != NULL ||
|
|
(new_vse->ngram_info != NULL && !new_vse->ngram_info->pruned)) {
|
|
float avg_cost = new_vse->cost / static_cast<float>(curr_row+1);
|
|
for (int c = curr_col; c <= curr_row; ++c) {
|
|
if (avg_cost < (*best_path_by_column)[c].avg_cost) {
|
|
(*best_path_by_column)[c].avg_cost = avg_cost;
|
|
(*best_path_by_column)[c].best_vse = new_vse;
|
|
(*best_path_by_column)[c].best_b = b;
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Set best_path_by_column[%d]=(%g %p)\n",
|
|
c, avg_cost, new_vse);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return changed;
|
|
}
|
|
|
|
void LanguageModel::PrintViterbiStateEntry(
|
|
const char *msg, ViterbiStateEntry *vse,
|
|
BLOB_CHOICE *b, CHUNKS_RECORD *chunks_record) {
|
|
tprintf("%s ViterbiStateEntry %p with ratings_sum=%.4f length=%d cost=%.4f",
|
|
msg, vse, vse->ratings_sum, vse->length, vse->cost);
|
|
if (vse->top_choice_flags) {
|
|
tprintf(" top_choice_flags=0x%x", vse->top_choice_flags);
|
|
}
|
|
if (!vse->Consistent()) {
|
|
tprintf(" inconsistent=(punc %d case %d chartype %d script %d)\n",
|
|
vse->consistency_info.NumInconsistentPunc(),
|
|
vse->consistency_info.NumInconsistentCase(),
|
|
vse->consistency_info.NumInconsistentChartype(),
|
|
vse->consistency_info.inconsistent_script);
|
|
}
|
|
if (vse->dawg_info) tprintf(" permuter=%d", vse->dawg_info->permuter);
|
|
if (vse->ngram_info) {
|
|
tprintf(" ngram_cost=%g context=%s ngram pruned=%d",
|
|
vse->ngram_info->ngram_cost,
|
|
vse->ngram_info->context.string(),
|
|
vse->ngram_info->pruned);
|
|
}
|
|
if (vse->associate_stats.shape_cost > 0.0f) {
|
|
tprintf(" shape_cost=%g", vse->associate_stats.shape_cost);
|
|
}
|
|
if (language_model_debug_level > 3) {
|
|
STRING wd_str;
|
|
WERD_CHOICE *wd = ConstructWord(b, vse, chunks_record,
|
|
NULL, NULL, NULL, NULL);
|
|
wd->string_and_lengths(dict_->getUnicharset(), &wd_str, NULL);
|
|
delete wd;
|
|
tprintf(" str=%s", wd_str.string());
|
|
}
|
|
tprintf("\n");
|
|
}
|
|
|
|
void LanguageModel::GenerateTopChoiceInfo(
|
|
float ratings_sum,
|
|
const LanguageModelDawgInfo *dawg_info,
|
|
const LanguageModelConsistencyInfo &consistency_info,
|
|
const ViterbiStateEntry *parent_vse,
|
|
BLOB_CHOICE *b,
|
|
LanguageModelFlagsType *top_choice_flags,
|
|
LanguageModelFlagsType *changed) {
|
|
ratings_sum = ComputeConsistencyAdjustedRatingsSum(
|
|
ratings_sum, dawg_info, consistency_info);
|
|
// Clear flags that do not agree with parent_vse->top_choice_flags.
|
|
if (parent_vse != NULL) *top_choice_flags &= parent_vse->top_choice_flags;
|
|
if (consistency_info.Consistent()) *top_choice_flags |= kConsistentFlag;
|
|
if (*top_choice_flags == 0x0) return;
|
|
LanguageModelState *lms =
|
|
reinterpret_cast<LanguageModelState *>(b->language_model_state());
|
|
if (lms != NULL && !lms->viterbi_state_entries.empty()) {
|
|
ViterbiStateEntry_IT vit(&(lms->viterbi_state_entries));
|
|
for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
|
|
if (ratings_sum >= ComputeConsistencyAdjustedRatingsSum(
|
|
vit.data()->ratings_sum, vit.data()->dawg_info,
|
|
vit.data()->consistency_info)) {
|
|
// Clear the appropriate flags if the list already contains
|
|
// a top choice entry with a lower cost.
|
|
*top_choice_flags &= ~(vit.data()->top_choice_flags);
|
|
}
|
|
}
|
|
}
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("GenerateTopChoiceInfo: top_choice_flags=0x%x\n",
|
|
*top_choice_flags);
|
|
}
|
|
*changed |= *top_choice_flags;
|
|
}
|
|
|
|
LanguageModelDawgInfo *LanguageModel::GenerateDawgInfo(
|
|
bool word_end, int script_id,
|
|
int curr_col, int curr_row,
|
|
const BLOB_CHOICE &b,
|
|
const ViterbiStateEntry *parent_vse,
|
|
LanguageModelFlagsType *changed) {
|
|
// Initialize active_dawgs and constraints from parent_vse if it is not NULL,
|
|
// otherwise use beginning_active_dawgs_ and beginning_constraints_.
|
|
if (parent_vse == NULL) {
|
|
dawg_args_->active_dawgs = beginning_active_dawgs_;
|
|
dawg_args_->constraints = beginning_constraints_;
|
|
dawg_args_->permuter = NO_PERM;
|
|
} else {
|
|
if (parent_vse->dawg_info == NULL) return NULL; // not a dict word path
|
|
dawg_args_->active_dawgs = parent_vse->dawg_info->active_dawgs;
|
|
dawg_args_->constraints = parent_vse->dawg_info->constraints;
|
|
dawg_args_->permuter = parent_vse->dawg_info->permuter;
|
|
}
|
|
bool use_fixed_length_dawgs = UseFixedLengthDawgs(script_id);
|
|
|
|
// Deal with hyphenated words.
|
|
if (!use_fixed_length_dawgs && word_end &&
|
|
dict_->has_hyphen_end(b.unichar_id(), curr_col == 0)) {
|
|
if (language_model_debug_level > 0) tprintf("Hyphenated word found\n");
|
|
*changed |= kDawgFlag;
|
|
return new LanguageModelDawgInfo(dawg_args_->active_dawgs,
|
|
dawg_args_->constraints,
|
|
COMPOUND_PERM);
|
|
}
|
|
|
|
// Deal with compound words.
|
|
if (!use_fixed_length_dawgs && dict_->compound_marker(b.unichar_id()) &&
|
|
(parent_vse == NULL || parent_vse->dawg_info->permuter != NUMBER_PERM)) {
|
|
if (language_model_debug_level > 0) tprintf("Found compound marker");
|
|
// Do not allow compound operators at the beginning and end of the word.
|
|
// Do not allow more than one compound operator per word.
|
|
// Do not allow compounding of words with lengths shorter than
|
|
// language_model_min_compound_length
|
|
if (parent_vse == NULL || word_end ||
|
|
dawg_args_->permuter == COMPOUND_PERM ||
|
|
parent_vse->length < language_model_min_compound_length) return NULL;
|
|
|
|
int i;
|
|
// Check a that the path terminated before the current character is a word.
|
|
bool has_word_ending = false;
|
|
for (i = 0; i < parent_vse->dawg_info->active_dawgs->size(); ++i) {
|
|
const DawgInfo &info = (*parent_vse->dawg_info->active_dawgs)[i];
|
|
const Dawg *pdawg = dict_->GetDawg(info.dawg_index);
|
|
assert(pdawg != NULL);
|
|
if (pdawg->type() == DAWG_TYPE_WORD && info.ref != NO_EDGE &&
|
|
pdawg->end_of_word(info.ref)) {
|
|
has_word_ending = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!has_word_ending) return NULL;
|
|
|
|
// Return LanguageModelDawgInfo with active_dawgs set to
|
|
// the beginning dawgs of type DAWG_TYPE_WORD.
|
|
if (language_model_debug_level > 0) tprintf("Compound word found\n");
|
|
DawgInfoVector beginning_word_dawgs;
|
|
for (i = 0; i < beginning_active_dawgs_->size(); ++i) {
|
|
const Dawg *bdawg =
|
|
dict_->GetDawg((*beginning_active_dawgs_)[i].dawg_index);
|
|
if (bdawg->type() == DAWG_TYPE_WORD) {
|
|
beginning_word_dawgs += (*beginning_active_dawgs_)[i];
|
|
}
|
|
}
|
|
*changed |= kDawgFlag;
|
|
return new LanguageModelDawgInfo(&(beginning_word_dawgs),
|
|
dawg_args_->constraints,
|
|
COMPOUND_PERM);
|
|
} // done dealing with compound words
|
|
|
|
LanguageModelDawgInfo *dawg_info = NULL;
|
|
|
|
// Call LetterIsOkay().
|
|
dict_->LetterIsOkay(dawg_args_, b.unichar_id(), word_end);
|
|
if (dawg_args_->permuter != NO_PERM) {
|
|
*changed |= kDawgFlag;
|
|
dawg_info = new LanguageModelDawgInfo(dawg_args_->updated_active_dawgs,
|
|
dawg_args_->updated_constraints,
|
|
dawg_args_->permuter);
|
|
}
|
|
|
|
// For non-space delimited languages: since every letter could be
|
|
// a valid word, a new word could start at every unichar. Thus append
|
|
// fixed_length_beginning_active_dawgs_ to dawg_info->active_dawgs.
|
|
if (use_fixed_length_dawgs) {
|
|
if (dawg_info == NULL) {
|
|
*changed |= kDawgFlag;
|
|
dawg_info = new LanguageModelDawgInfo(
|
|
fixed_length_beginning_active_dawgs_,
|
|
empty_dawg_info_vec_, SYSTEM_DAWG_PERM);
|
|
} else {
|
|
*(dawg_info->active_dawgs) += *(fixed_length_beginning_active_dawgs_);
|
|
}
|
|
} // done dealing with fixed-length dawgs
|
|
|
|
return dawg_info;
|
|
}
|
|
|
|
LanguageModelNgramInfo *LanguageModel::GenerateNgramInfo(
|
|
const char *unichar, float certainty, float denom,
|
|
int curr_col, int curr_row,
|
|
const ViterbiStateEntry *parent_vse,
|
|
BLOB_CHOICE *parent_b,
|
|
LanguageModelFlagsType *changed) {
|
|
// Initialize parent context.
|
|
const char *pcontext_ptr = "";
|
|
int pcontext_unichar_step_len = 0;
|
|
if (parent_vse == NULL) {
|
|
pcontext_ptr = prev_word_str_.string();
|
|
pcontext_unichar_step_len = prev_word_unichar_step_len_;
|
|
} else {
|
|
pcontext_ptr = parent_vse->ngram_info->context.string();
|
|
pcontext_unichar_step_len =
|
|
parent_vse->ngram_info->context_unichar_step_len;
|
|
}
|
|
// Compute p(unichar | parent context).
|
|
int unichar_step_len = 0;
|
|
bool pruned = false;
|
|
float ngram_cost = ComputeNgramCost(unichar, certainty, denom,
|
|
pcontext_ptr, &unichar_step_len,
|
|
&pruned);
|
|
// First attempt to normalize ngram_cost for strings of different
|
|
// lengths - we multiply ngram_cost by P(char | context) as many times
|
|
// as the number of chunks occupied by char. This makes the ngram costs
|
|
// of all the paths ending at the current BLOB_CHOICE comparable.
|
|
// TODO(daria): it would be worth looking at different ways of normalization.
|
|
if (curr_row > curr_col) ngram_cost += (curr_row - curr_col) * ngram_cost;
|
|
// Add the ngram_cost of the parent.
|
|
if (parent_vse != NULL) ngram_cost += parent_vse->ngram_info->ngram_cost;
|
|
|
|
// Shorten parent context string by unichar_step_len unichars.
|
|
int num_remove = (unichar_step_len + pcontext_unichar_step_len -
|
|
language_model_ngram_order);
|
|
if (num_remove > 0) pcontext_unichar_step_len -= num_remove;
|
|
while (num_remove > 0 && *pcontext_ptr != '\0') {
|
|
pcontext_ptr += UNICHAR::utf8_step(pcontext_ptr);
|
|
--num_remove;
|
|
}
|
|
|
|
// Decide whether to prune this ngram path and update changed accordingly.
|
|
if (parent_vse != NULL && parent_vse->ngram_info->pruned) pruned = true;
|
|
if (!pruned) *changed |= kNgramFlag;
|
|
|
|
// Construct and return the new LanguageModelNgramInfo.
|
|
LanguageModelNgramInfo *ngram_info = new LanguageModelNgramInfo(
|
|
pcontext_ptr, pcontext_unichar_step_len, pruned, ngram_cost);
|
|
ngram_info->context += unichar;
|
|
ngram_info->context_unichar_step_len += unichar_step_len;
|
|
assert(ngram_info->context_unichar_step_len <= language_model_ngram_order);
|
|
return ngram_info;
|
|
}
|
|
|
|
float LanguageModel::ComputeNgramCost(const char *unichar,
|
|
float certainty,
|
|
float denom,
|
|
const char *context,
|
|
int *unichar_step_len,
|
|
bool *found_small_prob) {
|
|
const char *context_ptr = context;
|
|
char *modified_context = NULL;
|
|
char *modified_context_end = NULL;
|
|
const char *unichar_ptr = unichar;
|
|
const char *unichar_end = unichar_ptr + strlen(unichar_ptr);
|
|
float prob = 0.0f;
|
|
int step = 0;
|
|
while (unichar_ptr < unichar_end &&
|
|
(step = UNICHAR::utf8_step(unichar_ptr)) > 0) {
|
|
if (language_model_debug_level > 1) {
|
|
tprintf("prob(%s | %s)=%g\n", unichar_ptr, context_ptr,
|
|
dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step));
|
|
}
|
|
prob += dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step);
|
|
++(*unichar_step_len);
|
|
if (language_model_ngram_use_only_first_uft8_step) break;
|
|
unichar_ptr += step;
|
|
// If there are multiple UTF8 characters present in unichar, context is
|
|
// updated to include the previously examined characters from str,
|
|
// unless use_only_first_uft8_step is true.
|
|
if (unichar_ptr < unichar_end) {
|
|
if (modified_context == NULL) {
|
|
int context_len = strlen(context);
|
|
modified_context =
|
|
new char[context_len + strlen(unichar_ptr) + step + 1];
|
|
strncpy(modified_context, context, context_len);
|
|
modified_context_end = modified_context + context_len;
|
|
context_ptr = modified_context;
|
|
}
|
|
strncpy(modified_context_end, unichar_ptr - step, step);
|
|
modified_context_end += step;
|
|
*modified_context_end = '\0';
|
|
}
|
|
}
|
|
if (prob < language_model_ngram_small_prob) {
|
|
if (language_model_debug_level > 0) tprintf("Found small prob %g\n", prob);
|
|
*found_small_prob = true;
|
|
if (prob < MIN_FLOAT32) prob = MIN_FLOAT32;
|
|
}
|
|
float cost = -(log(CertaintyScore(certainty) / denom) +
|
|
language_model_ngram_scale_factor *
|
|
log(prob / static_cast<float>(*unichar_step_len)));
|
|
if (language_model_debug_level > 1) {
|
|
tprintf("-log [ p(%s) * p(%s | %s) ] = -log(%g*%g) = %g\n", unichar,
|
|
unichar, context_ptr, CertaintyScore(certainty)/denom,
|
|
prob/static_cast<float>(*unichar_step_len), cost);
|
|
}
|
|
if (modified_context != NULL) delete[] modified_context;
|
|
return cost;
|
|
}
|
|
|
|
float LanguageModel::ComputeDenom(BLOB_CHOICE_LIST *curr_list) {
|
|
if (curr_list->empty()) return 1.0f;
|
|
float denom = 0.0f;
|
|
int len = 0;
|
|
BLOB_CHOICE_IT c_it(curr_list);
|
|
for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
|
|
ASSERT_HOST(c_it.data() != NULL);
|
|
++len;
|
|
denom += CertaintyScore(c_it.data()->certainty());
|
|
}
|
|
assert(len != 0);
|
|
// The ideal situation would be to have the classifier scores for
|
|
// classifying each position as each of the characters in the unicharset.
|
|
// Since we can not do this because of speed, we add a very crude estimate
|
|
// of what these scores for the "missing" classifications would sum up to.
|
|
denom += (dict_->getUnicharset().size() - len) *
|
|
CertaintyScore(language_model_ngram_nonmatch_score);
|
|
|
|
return denom;
|
|
}
|
|
|
|
void LanguageModel::FillConsistencyInfo(
|
|
bool word_end,
|
|
UNICHAR_ID unichar_id,
|
|
ViterbiStateEntry *parent_vse,
|
|
BLOB_CHOICE *parent_b,
|
|
LanguageModelConsistencyInfo *consistency_info) {
|
|
const UNICHARSET &unicharset = dict_->getUnicharset();
|
|
if (parent_vse != NULL) *consistency_info = parent_vse->consistency_info;
|
|
|
|
// Check punctuation validity.
|
|
if (unicharset.get_ispunctuation(unichar_id)) consistency_info->num_punc++;
|
|
if (dict_->GetPuncDawg() != NULL && !consistency_info->invalid_punc) {
|
|
if (dict_->compound_marker(unichar_id) && parent_b != NULL &&
|
|
(unicharset.get_isalpha(parent_b->unichar_id()) ||
|
|
unicharset.get_isdigit(parent_b->unichar_id()))) {
|
|
// reset punc_ref for compound words
|
|
consistency_info->punc_ref = NO_EDGE;
|
|
} else {
|
|
UNICHAR_ID pattern_unichar_id =
|
|
(unicharset.get_isalpha(unichar_id) ||
|
|
unicharset.get_isdigit(unichar_id)) ?
|
|
Dawg::kPatternUnicharID : unichar_id;
|
|
if (consistency_info->punc_ref == NO_EDGE ||
|
|
pattern_unichar_id != Dawg::kPatternUnicharID ||
|
|
dict_->GetPuncDawg()->edge_letter(consistency_info->punc_ref) !=
|
|
Dawg::kPatternUnicharID) {
|
|
NODE_REF node = Dict::GetStartingNode(dict_->GetPuncDawg(),
|
|
consistency_info->punc_ref);
|
|
consistency_info->punc_ref =
|
|
(node != NO_EDGE) ? dict_->GetPuncDawg()->edge_char_of(
|
|
node, pattern_unichar_id, word_end) : NO_EDGE;
|
|
if (consistency_info->punc_ref == NO_EDGE) {
|
|
consistency_info->invalid_punc = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Update case related counters.
|
|
if (parent_vse != NULL && !word_end && dict_->compound_marker(unichar_id)) {
|
|
// Reset counters if we are dealing with a compound word.
|
|
consistency_info->num_lower = 0;
|
|
consistency_info->num_non_first_upper = 0;
|
|
}
|
|
else if (unicharset.get_islower(unichar_id)) {
|
|
consistency_info->num_lower++;
|
|
} else if ((parent_b != NULL) && unicharset.get_isupper(unichar_id)) {
|
|
if (unicharset.get_isupper(parent_b->unichar_id()) ||
|
|
consistency_info->num_lower > 0 ||
|
|
consistency_info->num_non_first_upper > 0) {
|
|
consistency_info->num_non_first_upper++;
|
|
}
|
|
}
|
|
|
|
// Initialize consistency_info->script_id (use script of unichar_id
|
|
// if it is not Common, use script id recorded by the parent otherwise).
|
|
// Set inconsistent_script to true if the script of the current unichar
|
|
// is not consistent with that of the parent.
|
|
consistency_info->script_id = unicharset.get_script(unichar_id);
|
|
// Hiragana and Katakana can mix with Han.
|
|
if (dict_->getUnicharset().han_sid() != dict_->getUnicharset().null_sid()) {
|
|
if ((unicharset.hiragana_sid() != unicharset.null_sid() &&
|
|
consistency_info->script_id == unicharset.hiragana_sid()) ||
|
|
(unicharset.katakana_sid() != unicharset.null_sid() &&
|
|
consistency_info->script_id == unicharset.katakana_sid())) {
|
|
consistency_info->script_id = dict_->getUnicharset().han_sid();
|
|
}
|
|
}
|
|
|
|
if (parent_vse != NULL &&
|
|
(parent_vse->consistency_info.script_id !=
|
|
dict_->getUnicharset().common_sid())) {
|
|
int parent_script_id = parent_vse->consistency_info.script_id;
|
|
// If we are dealing with Han script, unichars from Common
|
|
// script are only ok at the beginning and the end of words.
|
|
bool han_inconsistent =
|
|
(IsHan(parent_script_id) && !word_end &&
|
|
(consistency_info->script_id == dict_->getUnicharset().common_sid()));
|
|
// If script_id is Common, use script id of the parent instead.
|
|
if (consistency_info->script_id == dict_->getUnicharset().common_sid()) {
|
|
consistency_info->script_id = parent_script_id;
|
|
}
|
|
if (consistency_info->script_id != parent_script_id || han_inconsistent) {
|
|
consistency_info->inconsistent_script = true;
|
|
}
|
|
}
|
|
|
|
// Update chartype related counters.
|
|
if (unicharset.get_isalpha(unichar_id)) {
|
|
consistency_info->num_alphas++;
|
|
} else if (unicharset.get_isdigit(unichar_id)) {
|
|
consistency_info->num_digits++;
|
|
} else if (!unicharset.get_ispunctuation(unichar_id)) {
|
|
consistency_info->num_other++;
|
|
}
|
|
}
|
|
|
|
float LanguageModel::ComputeAdjustedPathCost(
|
|
float ratings_sum, int length, float dawg_score,
|
|
const LanguageModelDawgInfo *dawg_info,
|
|
const LanguageModelNgramInfo *ngram_info,
|
|
const LanguageModelConsistencyInfo &consistency_info,
|
|
const AssociateStats &associate_stats,
|
|
ViterbiStateEntry *parent_vse) {
|
|
float adjustment = 1.0f;
|
|
if (dawg_info == NULL || dawg_info->permuter != FREQ_DAWG_PERM) {
|
|
adjustment += language_model_penalty_non_freq_dict_word;
|
|
}
|
|
if (dawg_score == 0.0f) {
|
|
adjustment += language_model_penalty_non_dict_word;
|
|
if (length > language_model_min_compound_length) {
|
|
adjustment += ((length - language_model_min_compound_length) *
|
|
language_model_penalty_increment);
|
|
}
|
|
} else if (dawg_score < 1.0f) {
|
|
adjustment += (1.0f - dawg_score) * language_model_penalty_non_dict_word;
|
|
}
|
|
if (associate_stats.shape_cost > 0) {
|
|
adjustment += associate_stats.shape_cost / static_cast<float>(length);
|
|
}
|
|
if (language_model_ngram_on) {
|
|
ASSERT_HOST(ngram_info != NULL);
|
|
return ngram_info->ngram_cost * adjustment;
|
|
} else {
|
|
adjustment += ComputeConsistencyAdjustment(dawg_info, consistency_info);
|
|
return ratings_sum * adjustment;
|
|
}
|
|
}
|
|
|
|
void LanguageModel::UpdateBestChoice(
|
|
BLOB_CHOICE *b,
|
|
ViterbiStateEntry *vse,
|
|
HEAP *pain_points,
|
|
CHUNKS_RECORD *chunks_record,
|
|
BestChoiceBundle *best_choice_bundle) {
|
|
int i;
|
|
BLOB_CHOICE_LIST_VECTOR temp_best_char_choices(vse->length);
|
|
for (i = 0; i < vse->length; ++i) {
|
|
temp_best_char_choices.push_back(NULL);
|
|
}
|
|
float *certainties = new float[vse->length];
|
|
STATE temp_state;
|
|
// The fraction of letters in the path that are "covered" by dawgs.
|
|
// For space delimited languages this number will be 0.0 for nonwords
|
|
// and 1.0 for dictionary words. For non-space delimited languages
|
|
// this number will be in the [0.0, 1.0] range.
|
|
float dawg_score;
|
|
WERD_CHOICE *word = ConstructWord(b, vse, chunks_record,
|
|
&temp_best_char_choices, certainties,
|
|
&dawg_score, &temp_state);
|
|
ASSERT_HOST(word != NULL);
|
|
|
|
// Log new segmentation (for dict_->LogNewChoice()).
|
|
PIECES_STATE pieces_widths;
|
|
bin_to_pieces(&temp_state, chunks_record->ratings->dimension() - 1,
|
|
pieces_widths);
|
|
dict_->LogNewSegmentation(pieces_widths);
|
|
|
|
if (language_model_debug_level > 0) {
|
|
STRING word_str;
|
|
word->string_and_lengths(dict_->getUnicharset(), &word_str, NULL);
|
|
tprintf("UpdateBestChoice() constructed word %s\n", word_str.string());
|
|
if (language_model_debug_level > 2) word->print();
|
|
};
|
|
|
|
// Update raw_choice if needed.
|
|
if ((vse->top_choice_flags & kSmallestRatingFlag) &&
|
|
word->rating() < best_choice_bundle->raw_choice->rating()) {
|
|
dict_->LogNewChoice(1.0, certainties, true, word);
|
|
*(best_choice_bundle->raw_choice) = *word;
|
|
best_choice_bundle->raw_choice->set_permuter(TOP_CHOICE_PERM);
|
|
best_choice_bundle->raw_choice->populate_unichars(dict_->getUnicharset());
|
|
if (language_model_debug_level > 0) tprintf("Updated raw choice\n");
|
|
}
|
|
|
|
// When working with non-space delimited languages we re-adjust the cost
|
|
// taking into account the final dawg_score and a more precise shape cost.
|
|
// While constructing the paths we assume that they are all dictionary words
|
|
// (since a single character would be a valid dictionary word). At the end
|
|
// we compute dawg_score which reflects how many characters on the path are
|
|
// "covered" by dictionary words of length > 1.
|
|
if (vse->associate_stats.full_wh_ratio_var != 0.0f ||
|
|
(dict_->GetMaxFixedLengthDawgIndex() >= 0 && dawg_score < 1.0f)) {
|
|
vse->cost = ComputeAdjustedPathCost(
|
|
vse->ratings_sum, vse->length, dawg_score, vse->dawg_info,
|
|
vse->ngram_info, vse->consistency_info, vse->associate_stats,
|
|
vse->parent_vse);
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Updated vse cost to %g (dawg_score %g full_wh_ratio_var %g)\n",
|
|
vse->cost, dawg_score, vse->associate_stats.full_wh_ratio_var);
|
|
}
|
|
}
|
|
|
|
// Update best choice and best char choices if needed.
|
|
// TODO(daria): re-write AcceptableChoice() and NoDangerousAmbig()
|
|
// to fit better into the new segmentation search.
|
|
word->set_rating(vse->cost);
|
|
if (word->rating() < best_choice_bundle->best_choice->rating()) {
|
|
dict_->LogNewChoice(vse->cost / (language_model_ngram_on ?
|
|
vse->ngram_info->ngram_cost :
|
|
vse->ratings_sum),
|
|
certainties, false, word);
|
|
// Since the rating of the word could have been modified by
|
|
// Dict::LogNewChoice() - check again.
|
|
if (word->rating() < best_choice_bundle->best_choice->rating()) {
|
|
bool modified_blobs; // not used
|
|
DANGERR fixpt;
|
|
if (dict_->AcceptableChoice(&temp_best_char_choices, word, &fixpt,
|
|
ASSOCIATOR_CALLER, &modified_blobs) &&
|
|
AcceptablePath(*vse)) {
|
|
acceptable_choice_found_ = true;
|
|
}
|
|
// Update best_choice_bundle.
|
|
*(best_choice_bundle->best_choice) = *word;
|
|
best_choice_bundle->best_choice->populate_unichars(
|
|
dict_->getUnicharset());
|
|
best_choice_bundle->updated = true;
|
|
best_choice_bundle->best_char_choices->delete_data_pointers();
|
|
best_choice_bundle->best_char_choices->clear();
|
|
for (i = 0; i < temp_best_char_choices.size(); ++i) {
|
|
BLOB_CHOICE_LIST *cc_list = new BLOB_CHOICE_LIST();
|
|
cc_list->deep_copy(temp_best_char_choices[i], &BLOB_CHOICE::deep_copy);
|
|
best_choice_bundle->best_char_choices->push_back(cc_list);
|
|
}
|
|
best_choice_bundle->best_state->part2 = temp_state.part2;
|
|
best_choice_bundle->best_state->part1 = temp_state.part1;
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Updated best choice\n");
|
|
print_state("New state ", best_choice_bundle->best_state,
|
|
chunks_record->ratings->dimension()-1);
|
|
}
|
|
// Update hyphen state if we are dealing with a dictionary word.
|
|
if (vse->dawg_info != NULL && dict_->GetMaxFixedLengthDawgIndex() < 0) {
|
|
if (dict_->has_hyphen_end(*word)) {
|
|
dict_->set_hyphen_word(*word, *(dawg_args_->active_dawgs),
|
|
*(dawg_args_->constraints));
|
|
} else {
|
|
dict_->reset_hyphen_vars(true);
|
|
}
|
|
}
|
|
best_choice_bundle->best_vse = vse;
|
|
best_choice_bundle->best_b = b;
|
|
best_choice_bundle->fixpt = fixpt;
|
|
}
|
|
}
|
|
|
|
// Clean up.
|
|
delete[] certainties;
|
|
delete word;
|
|
}
|
|
|
|
WERD_CHOICE *LanguageModel::ConstructWord(
|
|
BLOB_CHOICE *b,
|
|
ViterbiStateEntry *vse,
|
|
CHUNKS_RECORD *chunks_record,
|
|
BLOB_CHOICE_LIST_VECTOR *best_char_choices,
|
|
float certainties[],
|
|
float *dawg_score,
|
|
STATE *state) {
|
|
uinT64 state_uint = 0x0;
|
|
// Construct a WERD_CHOICE by tracing parent pointers.
|
|
WERD_CHOICE *word = new WERD_CHOICE(vse->length);
|
|
word->set_length(vse->length);
|
|
const uinT64 kHighestBitOn = 0x8000000000000000LL;
|
|
BLOB_CHOICE *curr_b = b;
|
|
LanguageModelState *curr_lms =
|
|
reinterpret_cast<LanguageModelState *>(b->language_model_state());
|
|
ViterbiStateEntry *curr_vse = vse;
|
|
|
|
int i;
|
|
bool compound = dict_->hyphenated(); // treat hyphenated words as compound
|
|
bool dawg_score_done = true;
|
|
if (dawg_score != NULL) {
|
|
*dawg_score = 0.0f;
|
|
// For space-delimited languages the presence of dawg_info in the last
|
|
// ViterbyStateEntry on the path means that the whole path represents
|
|
// a valid dictionary word.
|
|
if (dict_->GetMaxFixedLengthDawgIndex() < 0) {
|
|
if (vse->dawg_info != NULL) *dawg_score = 1.0f;
|
|
} else if (vse->length == 1) {
|
|
*dawg_score = 1.0f; // each one-letter word is legal
|
|
dawg_score_done = true; // in non-space delimited languages
|
|
} else {
|
|
dawg_score_done = false; // do more work to compute dawg_score
|
|
}
|
|
}
|
|
// For non-space delimited languages we compute the fraction of letters
|
|
// "covered" by fixed length dawgs (i.e. words of length > 1 on the path).
|
|
int covered_by_fixed_length_dawgs = 0;
|
|
// The number of unichars remaining that should be skipped because
|
|
// they are covered by the previous word from fixed length dawgs.
|
|
int fixed_length_num_unichars_to_skip = 0;
|
|
|
|
// Re-compute the variance of the width-to-hight ratios (since we now
|
|
// can compute the mean over the whole word).
|
|
float full_wh_ratio_mean = 0.0f;
|
|
if (vse->associate_stats.full_wh_ratio_var != 0.0f) {
|
|
vse->associate_stats.shape_cost -= vse->associate_stats.full_wh_ratio_var;
|
|
full_wh_ratio_mean = (vse->associate_stats.full_wh_ratio_total /
|
|
static_cast<float>(vse->length));
|
|
vse->associate_stats.full_wh_ratio_var = 0.0f;
|
|
}
|
|
|
|
for (i = (vse->length-1); i >= 0; --i) {
|
|
word->set_unichar_id(curr_b->unichar_id(), i);
|
|
word->set_fragment_length(1, i);
|
|
if (certainties != NULL) certainties[i] = curr_b->certainty();
|
|
if (best_char_choices != NULL) {
|
|
best_char_choices->set(chunks_record->ratings->get(
|
|
curr_lms->contained_in_col, curr_lms->contained_in_row), i);
|
|
}
|
|
if (state != NULL) {
|
|
// Record row minus col zeroes in the reverse state to mark the number
|
|
// of joins done by using a blob from this cell in the ratings matrix.
|
|
state_uint >>= (curr_lms->contained_in_row - curr_lms->contained_in_col);
|
|
// Record a one in the reverse state to indicate the split before
|
|
// the blob from the next cell in the ratings matrix (unless we are
|
|
// at the first cell, in which case there is no next blob).
|
|
if (i != 0) {
|
|
state_uint >>= 1;
|
|
state_uint |= kHighestBitOn;
|
|
}
|
|
}
|
|
// For non-space delimited languages: find word endings recorded while
|
|
// trying to separate the current path into words (for words found in
|
|
// fixed length dawgs.
|
|
if (!dawg_score_done && curr_vse->dawg_info != NULL) {
|
|
UpdateCoveredByFixedLengthDawgs(*(curr_vse->dawg_info->active_dawgs),
|
|
i, vse->length,
|
|
&fixed_length_num_unichars_to_skip,
|
|
&covered_by_fixed_length_dawgs,
|
|
dawg_score, &dawg_score_done);
|
|
}
|
|
// Update the width-to-height ratio variance. Useful non-space delimited
|
|
// languages to ensure that the blobs are of uniform width.
|
|
// Skip leading and trailing punctuation when computing the variance.
|
|
if ((full_wh_ratio_mean != 0.0f &&
|
|
((curr_vse != vse && curr_vse->parent_vse != NULL) ||
|
|
!dict_->getUnicharset().get_ispunctuation(curr_b->unichar_id())))) {
|
|
vse->associate_stats.full_wh_ratio_var +=
|
|
pow(full_wh_ratio_mean - curr_vse->associate_stats.full_wh_ratio, 2);
|
|
if (language_model_debug_level > 2) {
|
|
tprintf("full_wh_ratio_var += (%g-%g)^2\n",
|
|
full_wh_ratio_mean, curr_vse->associate_stats.full_wh_ratio);
|
|
}
|
|
}
|
|
|
|
// Mark the word as compound if compound permuter was set for any of
|
|
// the unichars on the path (usually this will happen for unichars
|
|
// that are compounding operators, like "-" and "/").
|
|
if (!compound && curr_vse->dawg_info &&
|
|
curr_vse->dawg_info->permuter == COMPOUND_PERM) compound = true;
|
|
|
|
// Update curr_* pointers.
|
|
if (curr_vse->parent_b == NULL) break;
|
|
curr_b = curr_vse->parent_b;
|
|
curr_lms =
|
|
reinterpret_cast<LanguageModelState *>(curr_b->language_model_state());
|
|
curr_vse = curr_vse->parent_vse;
|
|
}
|
|
ASSERT_HOST(i == 0); // check that we recorded all the unichar ids
|
|
// Re-adjust shape cost to include the updated width-to-hight variance.
|
|
if (full_wh_ratio_mean != 0.0f) {
|
|
vse->associate_stats.shape_cost += vse->associate_stats.full_wh_ratio_var;
|
|
}
|
|
|
|
if (state != NULL) {
|
|
state_uint >>= (64 - (chunks_record->ratings->dimension()-1));
|
|
state->part2 = state_uint;
|
|
state_uint >>= 32;
|
|
state->part1 = state_uint;
|
|
}
|
|
word->set_rating(vse->ratings_sum);
|
|
word->set_certainty(vse->min_certainty);
|
|
if (vse->dawg_info != NULL && dict_->GetMaxFixedLengthDawgIndex() < 0) {
|
|
word->set_permuter(compound ? COMPOUND_PERM : vse->dawg_info->permuter);
|
|
} else if (language_model_ngram_on && !vse->ngram_info->pruned) {
|
|
word->set_permuter(NGRAM_PERM);
|
|
} else if (vse->top_choice_flags) {
|
|
word->set_permuter(TOP_CHOICE_PERM);
|
|
} else {
|
|
word->set_permuter(NO_PERM);
|
|
}
|
|
return word;
|
|
}
|
|
|
|
void LanguageModel::UpdateCoveredByFixedLengthDawgs(
|
|
const DawgInfoVector &active_dawgs, int word_index, int word_length,
|
|
int *skip, int *covered, float *dawg_score, bool *dawg_score_done) {
|
|
if (language_model_debug_level > 3) {
|
|
tprintf("UpdateCoveredByFixedLengthDawgs for index %d skip=%d\n",
|
|
word_index, *skip, word_length);
|
|
}
|
|
|
|
if (*skip > 0) {
|
|
--(*skip);
|
|
} else {
|
|
int best_index = -1;
|
|
for (int d = 0; d < active_dawgs.size(); ++d) {
|
|
int dawg_index = (active_dawgs[d]).dawg_index;
|
|
if (dawg_index > dict_->GetMaxFixedLengthDawgIndex()) {
|
|
// If active_dawgs of the last ViterbiStateEntry on the path
|
|
// contain a non-fixed length dawg, this means that the whole
|
|
// path represents a word from a non-fixed length word dawg.
|
|
if (word_index == (word_length - 1)) {
|
|
*dawg_score = 1.0f;
|
|
*dawg_score_done = true;
|
|
return;
|
|
}
|
|
} else if (dawg_index >= kMinFixedLengthDawgLength) {
|
|
const Dawg *curr_dawg = dict_->GetDawg(dawg_index);
|
|
ASSERT_HOST(curr_dawg != NULL);
|
|
if ((active_dawgs[d]).ref != NO_EDGE &&
|
|
curr_dawg->end_of_word((active_dawgs[d]).ref) &&
|
|
dawg_index > best_index) {
|
|
best_index = dawg_index;
|
|
}
|
|
|
|
if (language_model_debug_level > 3) {
|
|
tprintf("dawg_index %d, ref %d, eow %d\n", dawg_index,
|
|
(active_dawgs[d]).ref,
|
|
((active_dawgs[d]).ref != NO_EDGE &&
|
|
curr_dawg->end_of_word((active_dawgs[d]).ref)));
|
|
}
|
|
}
|
|
} // end for
|
|
if (best_index != -1) {
|
|
*skip = best_index - 1;
|
|
*covered += best_index;
|
|
}
|
|
} // end if/else skip
|
|
|
|
if (word_index == 0) {
|
|
ASSERT_HOST(*covered <= word_length);
|
|
*dawg_score = (static_cast<float>(*covered) /
|
|
static_cast<float>(word_length));
|
|
*dawg_score_done = true;
|
|
}
|
|
}
|
|
|
|
void LanguageModel::GeneratePainPointsFromColumn(
|
|
int col,
|
|
const GenericVector<int> &non_empty_rows,
|
|
float best_choice_cert,
|
|
HEAP *pain_points,
|
|
BestPathByColumn *best_path_by_column[],
|
|
CHUNKS_RECORD *chunks_record) {
|
|
for (int i = 0; i < non_empty_rows.length(); ++i) {
|
|
int row = non_empty_rows[i];
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("\nLooking for pain points in col=%d row=%d\n", col, row);
|
|
}
|
|
if (language_model_ngram_on) {
|
|
GenerateNgramModelPainPointsFromColumn(
|
|
col, row, pain_points, chunks_record);
|
|
} else {
|
|
GenerateProblematicPathPainPointsFromColumn(
|
|
col, row, best_choice_cert, pain_points,
|
|
best_path_by_column, chunks_record);
|
|
}
|
|
}
|
|
}
|
|
|
|
void LanguageModel::GenerateNgramModelPainPointsFromColumn(
|
|
int col, int row, HEAP *pain_points, CHUNKS_RECORD *chunks_record) {
|
|
// Find the first top choice path recorded for this cell.
|
|
// If this path is prunned - generate a pain point.
|
|
ASSERT_HOST(chunks_record->ratings->get(col, row) != NULL);
|
|
BLOB_CHOICE_IT bit(chunks_record->ratings->get(col, row));
|
|
bool fragmented = false;
|
|
for (bit.mark_cycle_pt(); !bit.cycled_list(); bit.forward()) {
|
|
if (dict_->getUnicharset().get_fragment(bit.data()->unichar_id())) {
|
|
fragmented = true;
|
|
continue;
|
|
}
|
|
LanguageModelState *lms = reinterpret_cast<LanguageModelState *>(
|
|
bit.data()->language_model_state());
|
|
if (lms == NULL) continue;
|
|
ViterbiStateEntry_IT vit(&(lms->viterbi_state_entries));
|
|
for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
|
|
const ViterbiStateEntry *vse = vit.data();
|
|
if (!vse->top_choice_flags) continue;
|
|
ASSERT_HOST(vse->ngram_info != NULL);
|
|
if (vse->ngram_info->pruned && (vse->parent_vse == NULL ||
|
|
!vse->parent_vse->ngram_info->pruned)) {
|
|
if (vse->parent_vse != NULL) {
|
|
LanguageModelState *pp_lms = reinterpret_cast<LanguageModelState *>(
|
|
vse->parent_b->language_model_state());
|
|
GeneratePainPoint(pp_lms->contained_in_col, row, false,
|
|
kDefaultPainPointPriorityAdjustment,
|
|
-1.0f, fragmented, -1.0f,
|
|
max_char_wh_ratio_,
|
|
vse->parent_vse->parent_b,
|
|
vse->parent_vse->parent_vse,
|
|
chunks_record, pain_points);
|
|
}
|
|
if (vse->parent_vse != NULL &&
|
|
vse->parent_vse->parent_vse != NULL &&
|
|
dict_->getUnicharset().get_ispunctuation(
|
|
vse->parent_b->unichar_id())) {
|
|
// If the dip in the ngram probability is due to punctuation in the
|
|
// middle of the word - go two unichars back to combine this
|
|
// puntuation mark with the previous character on the path.
|
|
LanguageModelState *pp_lms = reinterpret_cast<LanguageModelState *>(
|
|
vse->parent_vse->parent_b->language_model_state());
|
|
GeneratePainPoint(pp_lms->contained_in_col, col-1, false,
|
|
kDefaultPainPointPriorityAdjustment,
|
|
-1.0f, fragmented, -1.0f,
|
|
max_char_wh_ratio_,
|
|
vse->parent_vse->parent_vse->parent_b,
|
|
vse->parent_vse->parent_vse->parent_vse,
|
|
chunks_record, pain_points);
|
|
} else if (row+1 < chunks_record->ratings->dimension()) {
|
|
GeneratePainPoint(col, row+1, true,
|
|
kDefaultPainPointPriorityAdjustment,
|
|
-1.0f, fragmented, -1.0f,
|
|
max_char_wh_ratio_,
|
|
vse->parent_b,
|
|
vse->parent_vse,
|
|
chunks_record, pain_points);
|
|
}
|
|
}
|
|
return; // examined the lowest cost top choice path
|
|
}
|
|
}
|
|
}
|
|
|
|
void LanguageModel::GenerateProblematicPathPainPointsFromColumn(
|
|
int col, int row, float best_choice_cert,
|
|
HEAP *pain_points, BestPathByColumn *best_path_by_column[],
|
|
CHUNKS_RECORD *chunks_record) {
|
|
MATRIX *ratings = chunks_record->ratings;
|
|
|
|
// Get the best path from this matrix cell.
|
|
BLOB_CHOICE_LIST *blist = ratings->get(col, row);
|
|
ASSERT_HOST(blist != NULL);
|
|
if (blist->empty()) return;
|
|
BLOB_CHOICE_IT bit(blist);
|
|
bool fragment = false;
|
|
while (dict_->getUnicharset().get_fragment(bit.data()->unichar_id()) &&
|
|
!bit.at_last()) { // skip fragments
|
|
fragment = true;
|
|
bit.forward();
|
|
}
|
|
if (bit.data()->language_model_state() == NULL) return;
|
|
ViterbiStateEntry_IT vit(&(reinterpret_cast<LanguageModelState *>(
|
|
bit.data()->language_model_state())->viterbi_state_entries));
|
|
if (vit.empty()) return;
|
|
ViterbiStateEntry *vse = vit.data();
|
|
// Check whether this path is promising.
|
|
bool path_is_promising = true;
|
|
if (vse->parent_vse != NULL) {
|
|
float potential_avg_cost =
|
|
((vse->parent_vse->cost + bit.data()->rating()*0.5f) /
|
|
static_cast<float>(row+1));
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("potential_avg_cost %g best cost %g\n",
|
|
potential_avg_cost, (*best_path_by_column)[col].avg_cost);
|
|
}
|
|
if (potential_avg_cost >= (*best_path_by_column)[col].avg_cost) {
|
|
path_is_promising = false;
|
|
}
|
|
}
|
|
// Set best_parent_vse to the best parent recorded in best_path_by_column.
|
|
ViterbiStateEntry *best_parent_vse = vse->parent_vse;
|
|
BLOB_CHOICE *best_parent_b = vse->parent_b;
|
|
if (col > 0 && (*best_path_by_column)[col-1].best_vse != NULL) {
|
|
ASSERT_HOST((*best_path_by_column)[col-1].best_b != NULL);
|
|
LanguageModelState *best_lms = reinterpret_cast<LanguageModelState *>(
|
|
((*best_path_by_column)[col-1].best_b)->language_model_state());
|
|
if (best_lms->contained_in_row == col-1) {
|
|
best_parent_vse = (*best_path_by_column)[col-1].best_vse;
|
|
best_parent_b = (*best_path_by_column)[col-1].best_b;
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Setting best_parent_vse to %p\n", best_parent_vse);
|
|
}
|
|
}
|
|
}
|
|
// Check whether this entry terminates the best parent path
|
|
// recorded in best_path_by_column.
|
|
bool best_not_prolonged = (best_parent_vse != vse->parent_vse);
|
|
|
|
// If this path is problematic because of the last unichar - generate
|
|
// a pain point to combine it with its left and right neighbor.
|
|
BLOB_CHOICE_IT tmp_bit;
|
|
if (best_not_prolonged ||
|
|
(path_is_promising &&
|
|
ProblematicPath(*vse, bit.data()->unichar_id(),
|
|
row+1 == ratings->dimension()))) {
|
|
float worst_piece_cert;
|
|
bool fragmented;
|
|
if (col-1 > 0) {
|
|
GetWorstPieceCertainty(col-1, row, chunks_record->ratings,
|
|
&worst_piece_cert, &fragmented);
|
|
GeneratePainPoint(col-1, row, false,
|
|
kDefaultPainPointPriorityAdjustment,
|
|
worst_piece_cert, fragmented, best_choice_cert,
|
|
max_char_wh_ratio_, best_parent_b, best_parent_vse,
|
|
chunks_record, pain_points);
|
|
}
|
|
if (row+1 < ratings->dimension()) {
|
|
GetWorstPieceCertainty(col, row+1, chunks_record->ratings,
|
|
&worst_piece_cert, &fragmented);
|
|
GeneratePainPoint(col, row+1, true, kDefaultPainPointPriorityAdjustment,
|
|
worst_piece_cert, fragmented, best_choice_cert,
|
|
max_char_wh_ratio_, best_parent_b, best_parent_vse,
|
|
chunks_record, pain_points);
|
|
}
|
|
} // for ProblematicPath
|
|
}
|
|
|
|
void LanguageModel::GeneratePainPointsFromBestChoice(
|
|
HEAP *pain_points,
|
|
CHUNKS_RECORD *chunks_record,
|
|
BestChoiceBundle *best_choice_bundle) {
|
|
// Variables to backtrack best_vse path;
|
|
ViterbiStateEntry *curr_vse = best_choice_bundle->best_vse;
|
|
BLOB_CHOICE *curr_b = best_choice_bundle->best_b;
|
|
|
|
// Begins and ends in DANGERR vector record the positions in the blob choice
|
|
// list of the best choice. We need to translate these endpoints into the
|
|
// beginning column and ending row for the pain points. We maintain
|
|
// danger_begin and danger_end arrays indexed by the position in
|
|
// best_choice_bundle->best_char_choices (which is equal to the position
|
|
// on the best_choice_bundle->best_vse path).
|
|
// danger_end[d] stores the DANGERR_INFO structs with end==d and is
|
|
// initialized at the beginning of this function.
|
|
// danger_begin[d] stores the DANGERR_INFO struct with begin==d and
|
|
// has end set to the row of the end of this ambiguity.
|
|
// The translation from end in terms of the best choice index to the end row
|
|
// is done while following the parents of best_choice_bundle->best_vse.
|
|
assert(best_choice_bundle->best_char_choices->length() ==
|
|
best_choice_bundle->best_vse->length);
|
|
DANGERR *danger_begin = NULL;
|
|
DANGERR *danger_end = NULL;
|
|
int d;
|
|
if (!best_choice_bundle->fixpt.empty()) {
|
|
danger_begin = new DANGERR[best_choice_bundle->best_vse->length];
|
|
danger_end = new DANGERR[best_choice_bundle->best_vse->length];
|
|
for (d = 0; d < best_choice_bundle->fixpt.size(); ++d) {
|
|
const DANGERR_INFO &danger = best_choice_bundle->fixpt[d];
|
|
// Only use n->1 ambiguities.
|
|
if (danger.end > danger.begin && !danger.correct_is_ngram &&
|
|
(!language_model_ngram_on || danger.dangerous)) {
|
|
danger_end[danger.end].push_back(danger);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Variables to keep track of punctuation/number streaks.
|
|
int punc_streak_end_row = -1;
|
|
int punc_streak_length = 0;
|
|
float punc_streak_min_cert = 0.0f;
|
|
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("\nGenerating pain points for best path=%p\n", curr_vse);
|
|
}
|
|
|
|
int word_index = best_choice_bundle->best_vse->length;
|
|
while (curr_vse != NULL) {
|
|
word_index--;
|
|
ASSERT_HOST(word_index >= 0);
|
|
ASSERT_HOST(curr_b != NULL);
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Looking at unichar %s\n",
|
|
dict_->getUnicharset().id_to_unichar(curr_b->unichar_id()));
|
|
}
|
|
|
|
int pp_col = reinterpret_cast<LanguageModelState *>(
|
|
curr_b->language_model_state())->contained_in_col;
|
|
int pp_row = reinterpret_cast<LanguageModelState *>(
|
|
curr_b->language_model_state())->contained_in_row;
|
|
|
|
// Generate pain points for ambiguities found by NoDangerousAmbig().
|
|
if (danger_end != NULL) {
|
|
// Translate end index of an ambiguity to an end row.
|
|
for (d = 0; d < danger_end[word_index].size(); ++d) {
|
|
danger_end[word_index][d].end = pp_row;
|
|
danger_begin[danger_end[word_index][d].begin].push_back(
|
|
danger_end[word_index][d]);
|
|
}
|
|
// Generate a pain point to combine unchars in the "wrong" part
|
|
// of the ambiguity.
|
|
for (d = 0; d < danger_begin[word_index].size(); ++d) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Generating pain point from %sambiguity\n",
|
|
danger_begin[word_index][d].dangerous ? "dangerous " : "");
|
|
}
|
|
GeneratePainPoint(pp_col, danger_begin[word_index][d].end, false,
|
|
danger_begin[word_index][d].dangerous ?
|
|
kCriticalPainPointPriorityAdjustment :
|
|
kBestChoicePainPointPriorityAdjustment,
|
|
best_choice_bundle->best_choice->certainty(), true,
|
|
best_choice_bundle->best_choice->certainty(),
|
|
kLooseMaxCharWhRatio,
|
|
curr_vse->parent_b, curr_vse->parent_vse,
|
|
chunks_record, pain_points);
|
|
}
|
|
}
|
|
|
|
if (!language_model_ngram_on) { // no need to use further heuristics if we
|
|
// are guided by the character ngram model
|
|
// Generate pain points for problematic sub-paths.
|
|
if (ProblematicPath(*curr_vse, curr_b->unichar_id(),
|
|
pp_row+1 == chunks_record->ratings->dimension())) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Generating pain point from a problematic sub-path\n");
|
|
}
|
|
float worst_piece_cert;
|
|
bool fragment;
|
|
if (pp_col > 0) {
|
|
GetWorstPieceCertainty(pp_col-1, pp_row, chunks_record->ratings,
|
|
&worst_piece_cert, &fragment);
|
|
GeneratePainPoint(pp_col-1, pp_row, false,
|
|
kBestChoicePainPointPriorityAdjustment,
|
|
worst_piece_cert, true,
|
|
best_choice_bundle->best_choice->certainty(),
|
|
max_char_wh_ratio_, NULL, NULL,
|
|
chunks_record, pain_points);
|
|
}
|
|
if (pp_row+1 < chunks_record->ratings->dimension()) {
|
|
GetWorstPieceCertainty(pp_col, pp_row+1, chunks_record->ratings,
|
|
&worst_piece_cert, &fragment);
|
|
GeneratePainPoint(pp_col, pp_row+1, true,
|
|
kBestChoicePainPointPriorityAdjustment,
|
|
worst_piece_cert, true,
|
|
best_choice_bundle->best_choice->certainty(),
|
|
max_char_wh_ratio_, NULL, NULL,
|
|
chunks_record, pain_points);
|
|
}
|
|
}
|
|
|
|
// Generate a pain point if we encountered a punctuation/number streak to
|
|
// combine all punctuation marks into one blob and try to classify it.
|
|
bool is_alpha = dict_->getUnicharset().get_isalpha(curr_b->unichar_id());
|
|
if (!is_alpha) {
|
|
if (punc_streak_end_row == -1) punc_streak_end_row = pp_row;
|
|
punc_streak_length++;
|
|
if (curr_b->certainty() < punc_streak_min_cert)
|
|
punc_streak_min_cert = curr_b->certainty();
|
|
}
|
|
if (is_alpha || curr_vse->parent_vse == NULL) {
|
|
if (punc_streak_end_row != -1 && punc_streak_length > 1) {
|
|
if (language_model_debug_level > 0) {
|
|
tprintf("Generating pain point from a punctuation streak\n");
|
|
}
|
|
if (is_alpha ||
|
|
(curr_vse->parent_vse == NULL && punc_streak_length > 2)) {
|
|
GeneratePainPoint(pp_row+1, punc_streak_end_row, false,
|
|
kBestChoicePainPointPriorityAdjustment,
|
|
punc_streak_min_cert, true,
|
|
best_choice_bundle->best_choice->certainty(),
|
|
max_char_wh_ratio_, curr_b, curr_vse,
|
|
chunks_record, pain_points);
|
|
}
|
|
// Handle punctuation/number streaks starting from the first unichar.
|
|
if (curr_vse->parent_vse == NULL) {
|
|
GeneratePainPoint(0, punc_streak_end_row, false,
|
|
kBestChoicePainPointPriorityAdjustment,
|
|
punc_streak_min_cert, true,
|
|
best_choice_bundle->best_choice->certainty(),
|
|
max_char_wh_ratio_, NULL, NULL,
|
|
chunks_record, pain_points);
|
|
}
|
|
}
|
|
punc_streak_end_row = -1;
|
|
punc_streak_length = 0;
|
|
punc_streak_min_cert = 0.0f;
|
|
} // end handling punctuation streaks
|
|
}
|
|
|
|
curr_b = curr_vse->parent_b;
|
|
curr_vse = curr_vse->parent_vse;
|
|
} // end looking at best choice subpaths
|
|
|
|
// Clean up.
|
|
if (danger_end != NULL) {
|
|
delete[] danger_begin;
|
|
delete[] danger_end;
|
|
}
|
|
}
|
|
|
|
void LanguageModel::GeneratePainPoint(
|
|
int col, int row, bool ok_to_extend, float priority,
|
|
float worst_piece_cert, bool fragmented, float best_choice_cert,
|
|
float max_char_wh_ratio,
|
|
BLOB_CHOICE *parent_b, ViterbiStateEntry *parent_vse,
|
|
CHUNKS_RECORD *chunks_record, HEAP *pain_points) {
|
|
if (col < 0 || row >= chunks_record->ratings->dimension() ||
|
|
chunks_record->ratings->get(col, row) != NOT_CLASSIFIED) {
|
|
return;
|
|
}
|
|
if (language_model_debug_level > 3) {
|
|
tprintf("\nGenerating pain point for col=%d row=%d priority=%g parent=",
|
|
col, row, priority);
|
|
if (parent_vse != NULL) {
|
|
PrintViterbiStateEntry("", parent_vse, parent_b, chunks_record);
|
|
} else {
|
|
tprintf("NULL");
|
|
}
|
|
tprintf("\n");
|
|
}
|
|
|
|
AssociateStats associate_stats;
|
|
ComputeAssociateStats(col, row, max_char_wh_ratio, parent_vse,
|
|
chunks_record, &associate_stats);
|
|
// For fixed-pitch fonts/languages: if the current combined blob overlaps
|
|
// the next blob on the right and it is ok to extend the blob, try expending
|
|
// the blob untill there is no overlap with the next blob on the right or
|
|
// until the width-to-hight ratio becomes too large.
|
|
if (ok_to_extend) {
|
|
while (associate_stats.bad_fixed_pitch_right_gap &&
|
|
row+1 < chunks_record->ratings->dimension() &&
|
|
!associate_stats.bad_fixed_pitch_wh_ratio) {
|
|
ComputeAssociateStats(col, ++row, max_char_wh_ratio, parent_vse,
|
|
chunks_record, &associate_stats);
|
|
}
|
|
}
|
|
|
|
if (associate_stats.bad_shape) {
|
|
if (language_model_debug_level > 3) {
|
|
tprintf("Discarded pain point with a bad shape\n");
|
|
}
|
|
return;
|
|
}
|
|
|
|
// Compute pain point priority.
|
|
if (associate_stats.shape_cost > 0) {
|
|
priority *= associate_stats.shape_cost;
|
|
}
|
|
if (worst_piece_cert < best_choice_cert) {
|
|
worst_piece_cert = best_choice_cert;
|
|
}
|
|
priority *= CertaintyScore(worst_piece_cert);
|
|
if (fragmented) priority /= kDefaultPainPointPriorityAdjustment;
|
|
|
|
if (language_model_debug_level > 3) {
|
|
tprintf("worst_piece_cert=%g fragmented=%d\n",
|
|
worst_piece_cert, fragmented);
|
|
}
|
|
|
|
if (parent_vse != NULL) {
|
|
priority *= sqrt(parent_vse->cost / static_cast<float>(col));
|
|
if (parent_vse->dawg_info != NULL) {
|
|
priority /= kDefaultPainPointPriorityAdjustment;
|
|
if (parent_vse->length > language_model_min_compound_length) {
|
|
priority /= sqrt(static_cast<double>(parent_vse->length));
|
|
}
|
|
}
|
|
}
|
|
|
|
MATRIX_COORD *pain_point = new MATRIX_COORD(col, row);
|
|
if (HeapPushCheckSize(pain_points, priority, pain_point)) {
|
|
if (language_model_debug_level) {
|
|
tprintf("Added pain point with priority %g\n", priority);
|
|
}
|
|
} else {
|
|
delete pain_point;
|
|
if (language_model_debug_level) tprintf("Pain points heap is full\n");
|
|
}
|
|
}
|
|
|
|
} // namespace tesseract
|