// Copyright 2011 Google Inc. All Rights Reserved.
// Author: rays@google.com (Ray Smith)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////

#include <time.h>

#include "errorcounter.h"

#include "fontinfo.h"
#include "ndminx.h"
#include "sampleiterator.h"
#include "shapeclassifier.h"
#include "shapetable.h"
#include "trainingsample.h"
#include "trainingsampleset.h"
#include "unicity_table.h"

namespace tesseract {

// Tests a classifier, computing its error rate.
// See errorcounter.h for description of arguments.
// Iterates over the samples, calling the classifier in normal/silent mode.
// If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate
// report_level is set (4 or greater), it will then call the classifier again
// with a debug flag and a keep_this argument to find out what is going on.
double ErrorCounter::ComputeErrorRate(ShapeClassifier* classifier, int report_level, CountTypes boosting_mode, const UnicityTable& fontinfo_table, const GenericVector& page_images, SampleIterator* it, double* unichar_error, double* scaled_error, STRING* fonts_report) { int charsetsize = it->shape_table()->unicharset().size(); int shapesize = it->CompactCharsetSize(); int fontsize = it->sample_set()->NumFonts(); ErrorCounter counter(charsetsize, shapesize, fontsize); GenericVector results; clock_t start = clock(); int total_samples = 0; double unscaled_error = 0.0; // Set a number of samples on which to run the classify debug mode. int error_samples = report_level > 3 ? report_level * report_level : 0; // Iterate over all the samples, accumulating errors. for (it->Begin(); !it->AtEnd(); it->Next()) { TrainingSample* mutable_sample = it->MutableSample(); int page_index = mutable_sample->page_num(); Pix* page_pix = 0 <= page_index && page_index < page_images.size() ? page_images[page_index] : NULL; // No debug, no keep this. classifier->ClassifySample(*mutable_sample, page_pix, 0, INVALID_UNICHAR_ID, &results); if (mutable_sample->class_id() == 0) { // This is junk so use the special counter. counter.AccumulateJunk(*it->shape_table(), results, mutable_sample); } else if (counter.AccumulateErrors(report_level > 3, boosting_mode, fontinfo_table, *it->shape_table(), results, mutable_sample) && error_samples > 0) { // Running debug, keep the correct answer, and debug the classifier. tprintf("Error on sample %d: Classifier debug output:\n", it->GlobalSampleIndex()); int keep_this = it->GetSparseClassID(); classifier->ClassifySample(*mutable_sample, page_pix, 1, keep_this, &results); --error_samples; } ++total_samples; } double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC; // Create the appropriate error report. 
unscaled_error = counter.ReportErrors(report_level, boosting_mode, fontinfo_table, *it, unichar_error, fonts_report); if (scaled_error != NULL) *scaled_error = counter.scaled_error_; if (report_level > 1) { // It is useful to know the time in microseconds/char. tprintf("Errors computed in %.2fs at %.1f μs/char\n", total_time, 1000000.0 * total_time / total_samples); } return unscaled_error; } // Constructor is private. Only anticipated use of ErrorCounter is via // the static ComputeErrorRate. ErrorCounter::ErrorCounter(int charsetsize, int shapesize, int fontsize) : scaled_error_(0.0), unichar_counts_(charsetsize, shapesize, 0) { Counts empty_counts; font_counts_.init_to_size(fontsize, empty_counts); } ErrorCounter::~ErrorCounter() { } // Accumulates the errors from the classifier results on a single sample. // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred. // boosting_mode selects the type of error to be used for boosting and the // is_error_ member of sample is set according to whether the required type // of error occurred. The font_table provides access to font properties // for error counting and shape_table is used to understand the relationship // between unichar_ids and shape_ids in the results bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode, const UnicityTable& font_table, const ShapeTable& shape_table, const GenericVector& results, TrainingSample* sample) { int num_results = results.size(); int res_index = 0; bool debug_it = false; int font_id = sample->font_id(); int unichar_id = sample->class_id(); sample->set_is_error(false); if (num_results == 0) { // Reject. We count rejects as a separate category, but still mark the // sample as an error in case any training module wants to use that to // improve the classifier. sample->set_is_error(true); ++font_counts_[font_id].n[CT_REJECT]; } else if (shape_table.GetShape(results[0].shape_id). 
ContainsUnicharAndFont(unichar_id, font_id)) { ++font_counts_[font_id].n[CT_SHAPE_TOP_CORRECT]; // Unichar and font OK, but count if multiple unichars. if (shape_table.GetShape(results[0].shape_id).size() > 1) ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; } else { // This is a top shape error. ++font_counts_[font_id].n[CT_SHAPE_TOP_ERR]; // Check to see if any font in the top choice has attributes that match. bool attributes_match = false; uinT32 font_props = font_table.get(font_id).properties; const Shape& shape = shape_table.GetShape(results[0].shape_id); for (int c = 0; c < shape.size() && !attributes_match; ++c) { for (int f = 0; f < shape[c].font_ids.size(); ++f) { if (font_table.get(shape[c].font_ids[f]).properties == font_props) { attributes_match = true; break; } } } // TODO(rays) It is easy to add counters for individual font attributes // here if we want them. if (!attributes_match) ++font_counts_[font_id].n[CT_FONT_ATTR_ERR]; if (boosting_mode == CT_SHAPE_TOP_ERR) sample->set_is_error(true); // Find rank of correct unichar answer. (Ignoring the font.) while (res_index < num_results && !shape_table.GetShape(results[res_index].shape_id). ContainsUnichar(unichar_id)) { ++res_index; } if (res_index == 0) { // Unichar OK, but count if multiple unichars. if (shape_table.GetShape(results[res_index].shape_id).size() > 1) { ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; } } else { // Count maps from unichar id to shape id. if (num_results > 0) ++unichar_counts_(unichar_id, results[0].shape_id); // This is a unichar error. ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR]; if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true); if (res_index >= MIN(2, num_results)) { // It is also a 2nd choice unichar error. ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR]; if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true); } if (res_index >= num_results) { // It is also a top-n choice unichar error. 
++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR]; if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true); debug_it = debug; } } } // Compute mean number of return values and mean rank of correct answer. font_counts_[font_id].n[CT_NUM_RESULTS] += num_results; font_counts_[font_id].n[CT_RANK] += res_index; // If it was an error for boosting then sum the weight. if (sample->is_error()) { scaled_error_ += sample->weight(); } if (debug_it) { tprintf("%d results for char %s font %d :", num_results, shape_table.unicharset().id_to_unichar(unichar_id), font_id); for (int i = 0; i < num_results; ++i) { tprintf(" %.3f/%.3f:%s", results[i].rating, results[i].font, shape_table.DebugStr(results[i].shape_id).string()); } tprintf("\n"); return true; } return false; } // Accumulates counts for junk. Counts only whether the junk was correctly // rejected or not. void ErrorCounter::AccumulateJunk(const ShapeTable& shape_table, const GenericVector& results, TrainingSample* sample) { // For junk we accept no answer, or an explicit shape answer matching the // class id of the sample. int num_results = results.size(); int font_id = sample->font_id(); int unichar_id = sample->class_id(); if (num_results > 0 && !shape_table.GetShape(results[0].shape_id).ContainsUnichar(unichar_id)) { // This is a junk error. ++font_counts_[font_id].n[CT_ACCEPTED_JUNK]; sample->set_is_error(true); // It counts as an error for boosting too so sum the weight. scaled_error_ += sample->weight(); } else { // Correctly rejected. ++font_counts_[font_id].n[CT_REJECTED_JUNK]; sample->set_is_error(false); } } // Creates a report of the error rate. The report_level controls the detail // that is reported to stderr via tprintf: // 0 -> no output. // >=1 -> bottom-line error rate. // >=3 -> font-level error rate. // boosting_mode determines the return value. It selects which (un-weighted) // error rate to return. // The fontinfo_table from MasterTrainer provides the names of fonts. 
// The it determines the current subset of the training samples. // If not NULL, the top-choice unichar error rate is saved in unichar_error. // If not NULL, the report string is saved in fonts_report. // (Ignoring report_level). double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode, const UnicityTable& fontinfo_table, const SampleIterator& it, double* unichar_error, STRING* fonts_report) { // Compute totals over all the fonts and report individual font results // when required. Counts totals; int fontsize = font_counts_.size(); for (int f = 0; f < fontsize; ++f) { // Accumulate counts over fonts. totals += font_counts_[f]; STRING font_report; if (ReportString(font_counts_[f], &font_report)) { if (fonts_report != NULL) { *fonts_report += fontinfo_table.get(f).name; *fonts_report += ": "; *fonts_report += font_report; *fonts_report += "\n"; } if (report_level > 2) { // Report individual font error rates. tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string()); } } } if (report_level > 0) { // Report the totals. STRING total_report; if (ReportString(totals, &total_report)) { tprintf("TOTAL Scaled Err=%.4g%%, %s\n", scaled_error_ * 100.0, total_report.string()); } // Report the worst substitution error only for now. 
if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) { const UNICHARSET& unicharset = it.shape_table()->unicharset(); int charsetsize = unicharset.size(); int shapesize = it.CompactCharsetSize(); int worst_uni_id = 0; int worst_shape_id = 0; int worst_err = 0; for (int u = 0; u < charsetsize; ++u) { for (int s = 0; s < shapesize; ++s) { if (unichar_counts_(u, s) > worst_err) { worst_err = unichar_counts_(u, s); worst_uni_id = u; worst_shape_id = s; } } } if (worst_err > 0) { tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n", worst_uni_id, unicharset.id_to_unichar(worst_uni_id), it.shape_table()->DebugStr(worst_shape_id).string(), worst_err, totals.n[CT_UNICHAR_TOP1_ERR], 100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]); } } } double rates[CT_SIZE]; if (!ComputeRates(totals, rates)) return 0.0; // Set output values if asked for. if (unichar_error != NULL) *unichar_error = rates[CT_UNICHAR_TOP1_ERR]; return rates[boosting_mode]; } // Sets the report string to a combined human and machine-readable report // string of the error rates. // Returns false if there is no data, leaving report unchanged. bool ErrorCounter::ReportString(const Counts& counts, STRING* report) { // Compute the error rates. double rates[CT_SIZE]; if (!ComputeRates(counts, rates)) return false; // Using %.4g%%, the length of the output string should exactly match the // length of the format string, but in case of overflow, allow for +eddd // on each number. const int kMaxExtraLength = 5; // Length of +eddd. // Keep this format string and the snprintf in sync with the CountTypes enum. 
const char* format_str = "ShapeErr=%.4g%%, FontAttr=%.4g%%, " "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], " "Multi=%.4g%%, Rej=%.4g%%, " "Answers=%.3g, Rank=%.3g, " "OKjunk=%.4g%%, Badjunk=%.4g%%"; int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1; char* formatted_str = new char[max_str_len]; snprintf(formatted_str, max_str_len, format_str, rates[CT_SHAPE_TOP_ERR] * 100.0, rates[CT_FONT_ATTR_ERR] * 100.0, rates[CT_UNICHAR_TOP1_ERR] * 100.0, rates[CT_UNICHAR_TOP2_ERR] * 100.0, rates[CT_UNICHAR_TOPN_ERR] * 100.0, rates[CT_OK_MULTI_UNICHAR] * 100.0, rates[CT_REJECT] * 100.0, rates[CT_NUM_RESULTS], rates[CT_RANK], 100.0 * rates[CT_REJECTED_JUNK], 100.0 * rates[CT_ACCEPTED_JUNK]); *report = formatted_str; delete [] formatted_str; // Now append each field of counts with a tab in front so the result can // be loaded into a spreadsheet. for (int ct = 0; ct < CT_SIZE; ++ct) report->add_str_int("\t", counts.n[ct]); return true; } // Computes the error rates and returns in rates which is an array of size // CT_SIZE. Returns false if there is no data, leaving rates unchanged. bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) { int ok_samples = counts.n[CT_SHAPE_TOP_CORRECT] + counts.n[CT_SHAPE_TOP_ERR] + counts.n[CT_REJECT]; int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK]; if (ok_samples == 0 && junk_samples == 0) { // There is no data. return false; } // Compute rates for normal chars. double denominator = static_cast(MAX(ok_samples, 1)); for (int ct = 0; ct <= CT_RANK; ++ct) rates[ct] = counts.n[ct] / denominator; // Compute rates for junk. denominator = static_cast(MAX(junk_samples, 1)); for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct) rates[ct] = counts.n[ct] / denominator; return true; } ErrorCounter::Counts::Counts() { memset(n, 0, sizeof(n[0]) * CT_SIZE); } // Adds other into this for computing totals. 
void ErrorCounter::Counts::operator+=(const Counts& other) {
  // Element-wise sum: each of the CT_SIZE counters gains the matching
  // counter from other.
  for (int type = 0; type != CT_SIZE; ++type) {
    n[type] = n[type] + other.n[type];
  }
}

}  // namespace tesseract.