2012-02-02 10:57:42 +08:00
|
|
|
// Copyright 2011 Google Inc. All Rights Reserved.
|
|
|
|
// Author: rays@google.com (Ray Smith)
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
//
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
#include "sampleiterator.h"
|
|
|
|
|
|
|
|
#include "indexmapbidi.h"
|
|
|
|
#include "shapetable.h"
|
|
|
|
#include "trainingsample.h"
|
|
|
|
#include "trainingsampleset.h"
|
|
|
|
|
|
|
|
namespace tesseract {
|
|
|
|
|
|
|
|
// ================== SampleIterator Implementation =================
|
|
|
|
|
|
|
|
// Constructs an iterator that is not attached to any data: all pointer
// members are null, randomization is off and there are no shapes.
SampleIterator::SampleIterator()
    : charset_map_(nullptr),
      shape_table_(nullptr),
      sample_set_(nullptr),
      randomize_(false),
      owned_shape_table_(nullptr) {
  // With zero shapes and no shape table, the Begin() call below leaves the
  // iterator in a state where AtEnd() is already true.
  num_shapes_ = 0;
  Begin();
}
|
|
|
|
|
|
|
|
// Destructor: frees the shape table owned by this iterator, if there is one.
SampleIterator::~SampleIterator() { Clear(); }
|
|
|
|
|
|
|
|
// Frees the shape table owned by this iterator (if any) and nulls the owning
// pointer, so calling Clear() repeatedly is harmless.
void SampleIterator::Clear() {
  delete owned_shape_table_;  // Deleting a null pointer is a no-op.
  owned_shape_table_ = nullptr;
}
|
|
|
|
|
|
|
|
// See class comment for arguments.
|
|
|
|
void SampleIterator::Init(const IndexMapBiDi* charset_map,
|
|
|
|
const ShapeTable* shape_table,
|
|
|
|
bool randomize,
|
|
|
|
TrainingSampleSet* sample_set) {
|
|
|
|
Clear();
|
|
|
|
charset_map_ = charset_map;
|
|
|
|
shape_table_ = shape_table;
|
|
|
|
sample_set_ = sample_set;
|
|
|
|
randomize_ = randomize;
|
2016-12-13 00:21:24 +08:00
|
|
|
if (shape_table_ == nullptr && charset_map_ != nullptr) {
|
2012-02-02 10:57:42 +08:00
|
|
|
// The caller wishes to iterate by class. The easiest way to do this
|
|
|
|
// is to create a dummy shape_table_ that we will own.
|
|
|
|
int num_fonts = sample_set_->NumFonts();
|
|
|
|
owned_shape_table_ = new ShapeTable(sample_set_->unicharset());
|
|
|
|
int charsetsize = sample_set_->unicharset().size();
|
|
|
|
for (int c = 0; c < charsetsize; ++c) {
|
|
|
|
// We always add a shape for each character to keep the index in sync
|
|
|
|
// with the unichar_id.
|
|
|
|
int shape_id = owned_shape_table_->AddShape(c, 0);
|
|
|
|
for (int f = 1; f < num_fonts; ++f) {
|
|
|
|
if (sample_set_->NumClassSamples(f, c, true) > 0) {
|
|
|
|
owned_shape_table_->AddToShape(shape_id, c, f);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
shape_table_ = owned_shape_table_;
|
|
|
|
}
|
2016-12-13 00:21:24 +08:00
|
|
|
if (shape_table_ != nullptr) {
|
2012-02-02 10:57:42 +08:00
|
|
|
num_shapes_ = shape_table_->NumShapes();
|
|
|
|
} else {
|
|
|
|
num_shapes_ = randomize ? sample_set_->num_samples()
|
|
|
|
: sample_set_->num_raw_samples();
|
|
|
|
}
|
|
|
|
Begin();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Iterator functions designed for use with a simple for loop:
|
|
|
|
// for (it.Begin(); !it.AtEnd(); it.Next()) {
|
|
|
|
// const TrainingSample& sample = it.GetSample();
|
|
|
|
// }
|
|
|
|
void SampleIterator::Begin() {
|
|
|
|
shape_index_ = -1;
|
|
|
|
shape_char_index_ = 0;
|
|
|
|
num_shape_chars_ = 0;
|
|
|
|
shape_font_index_ = 0;
|
|
|
|
num_shape_fonts_ = 0;
|
|
|
|
sample_index_ = 0;
|
|
|
|
num_samples_ = 0;
|
|
|
|
// Find the first indexable sample.
|
|
|
|
Next();
|
|
|
|
}
|
|
|
|
|
|
|
|
bool SampleIterator::AtEnd() const {
|
|
|
|
return shape_index_ >= num_shapes_;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns a const reference to the sample at the current iterator position.
const TrainingSample& SampleIterator::GetSample() const {
  if (shape_table_ == nullptr) {
    // Without a shape table the top-level index is the sample index itself.
    return *sample_set_->GetSample(shape_index_);
  }
  // With a shape table the sample is addressed by (font, class, index).
  const UnicharAndFonts& entry = *GetShapeEntry();
  return *sample_set_->GetSample(entry.font_ids[shape_font_index_],
                                 entry.unichar_id, sample_index_);
}
|
|
|
|
|
|
|
|
// Returns a mutable pointer to the sample at the current iterator position.
TrainingSample* SampleIterator::MutableSample() const {
  if (shape_table_ == nullptr) {
    // Without a shape table the top-level index is the sample index itself.
    return sample_set_->mutable_sample(shape_index_);
  }
  // With a shape table the sample is addressed by (font, class, index).
  const UnicharAndFonts& entry = *GetShapeEntry();
  return sample_set_->MutableSample(entry.font_ids[shape_font_index_],
                                    entry.unichar_id, sample_index_);
}
|
|
|
|
|
|
|
|
// Returns the total index (from the original set of samples) of the current
|
|
|
|
// sample.
|
|
|
|
int SampleIterator::GlobalSampleIndex() const {
|
2016-12-13 00:21:24 +08:00
|
|
|
if (shape_table_ != nullptr) {
|
2012-02-02 10:57:42 +08:00
|
|
|
const UnicharAndFonts* shape_entry = GetShapeEntry();
|
|
|
|
int char_id = shape_entry->unichar_id;
|
|
|
|
int font_id = shape_entry->font_ids[shape_font_index_];
|
|
|
|
return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_);
|
|
|
|
} else {
|
|
|
|
return shape_index_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns the index of the current sample in compact charset space, so
|
|
|
|
// in a 2-class problem between x and y, the returned indices will all be
|
|
|
|
// 0 or 1, and have nothing to do with the unichar_ids.
|
2016-12-13 00:21:24 +08:00
|
|
|
// If the charset_map_ is nullptr, then this is equal to GetSparseClassID().
|
2012-02-02 10:57:42 +08:00
|
|
|
int SampleIterator::GetCompactClassID() const {
|
2016-12-13 00:21:24 +08:00
|
|
|
return charset_map_ != nullptr ? charset_map_->SparseToCompact(shape_index_)
|
2012-02-02 10:57:42 +08:00
|
|
|
: GetSparseClassID();
|
|
|
|
}
|
|
|
|
// Returns the index of the current sample in sparse charset space, so
|
|
|
|
// in a 2-class problem between x and y, the returned indices will all be
|
|
|
|
// x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids
|
|
|
|
// with a shape_table_.
|
|
|
|
int SampleIterator::GetSparseClassID() const {
|
2016-12-13 00:21:24 +08:00
|
|
|
return shape_table_ != nullptr ? shape_index_ : GetSample().class_id();
|
2012-02-02 10:57:42 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Moves on to the next indexable sample. If the end is reached, leaves
|
|
|
|
// the state such that AtEnd() is true.
|
|
|
|
void SampleIterator::Next() {
|
2016-12-13 00:21:24 +08:00
|
|
|
if (shape_table_ != nullptr) {
|
2012-02-02 10:57:42 +08:00
|
|
|
// Next sample in this class/font combination.
|
|
|
|
++sample_index_;
|
|
|
|
if (sample_index_ < num_samples_)
|
|
|
|
return;
|
|
|
|
// Next font in this class in this shape.
|
|
|
|
sample_index_ = 0;
|
|
|
|
do {
|
|
|
|
++shape_font_index_;
|
|
|
|
if (shape_font_index_ >= num_shape_fonts_) {
|
|
|
|
// Next unichar in this shape.
|
|
|
|
shape_font_index_ = 0;
|
|
|
|
++shape_char_index_;
|
|
|
|
if (shape_char_index_ >= num_shape_chars_) {
|
|
|
|
// Find the next shape that is mapped in the charset_map_.
|
|
|
|
shape_char_index_ = 0;
|
|
|
|
do {
|
|
|
|
++shape_index_;
|
|
|
|
} while (shape_index_ < num_shapes_ &&
|
2016-12-13 00:21:24 +08:00
|
|
|
charset_map_ != nullptr &&
|
2012-02-02 10:57:42 +08:00
|
|
|
charset_map_->SparseToCompact(shape_index_) < 0);
|
|
|
|
if (shape_index_ >= num_shapes_)
|
|
|
|
return; // The end.
|
|
|
|
num_shape_chars_ = shape_table_->GetShape(shape_index_).size();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
const UnicharAndFonts* shape_entry = GetShapeEntry();
|
|
|
|
num_shape_fonts_ = shape_entry->font_ids.size();
|
|
|
|
int char_id = shape_entry->unichar_id;
|
|
|
|
int font_id = shape_entry->font_ids[shape_font_index_];
|
|
|
|
num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_);
|
|
|
|
} while (num_samples_ == 0);
|
|
|
|
} else {
|
|
|
|
// We are just iterating over the samples.
|
|
|
|
++shape_index_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns the size of the compact charset space.
|
|
|
|
int SampleIterator::CompactCharsetSize() const {
|
2016-12-13 00:21:24 +08:00
|
|
|
return charset_map_ != nullptr ? charset_map_->CompactSize()
|
2012-02-02 10:57:42 +08:00
|
|
|
: SparseCharsetSize();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns the size of the sparse charset space.
|
|
|
|
int SampleIterator::SparseCharsetSize() const {
|
2016-12-13 00:21:24 +08:00
|
|
|
return charset_map_ != nullptr
|
2012-02-02 10:57:42 +08:00
|
|
|
? charset_map_->SparseSize()
|
2016-12-13 00:21:24 +08:00
|
|
|
: (shape_table_ != nullptr ? shape_table_->NumShapes()
|
2012-02-02 10:57:42 +08:00
|
|
|
: sample_set_->charsetsize());
|
|
|
|
}
|
|
|
|
|
|
|
|
// Apply the supplied feature_space/feature_map transform to all samples
|
|
|
|
// accessed by this iterator.
|
|
|
|
void SampleIterator::MapSampleFeatures(const IntFeatureMap& feature_map) {
|
|
|
|
for (Begin(); !AtEnd(); Next()) {
|
|
|
|
TrainingSample* sample = MutableSample();
|
|
|
|
sample->MapFeatures(feature_map);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Adjust the weights of all the samples to be uniform in the given charset.
|
|
|
|
// Returns the number of samples in the iterator.
|
|
|
|
int SampleIterator::UniformSamples() {
|
|
|
|
int num_good_samples = 0;
|
|
|
|
for (Begin(); !AtEnd(); Next()) {
|
|
|
|
TrainingSample* sample = MutableSample();
|
|
|
|
sample->set_weight(1.0);
|
|
|
|
++num_good_samples;
|
|
|
|
}
|
|
|
|
NormalizeSamples();
|
|
|
|
return num_good_samples;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Normalize the weights of all the samples in the charset_map so they sum
|
|
|
|
// to 1. Returns the minimum assigned sample weight.
|
|
|
|
double SampleIterator::NormalizeSamples() {
|
|
|
|
double total_weight = 0.0;
|
|
|
|
int sample_count = 0;
|
|
|
|
for (Begin(); !AtEnd(); Next()) {
|
|
|
|
const TrainingSample& sample = GetSample();
|
|
|
|
total_weight += sample.weight();
|
|
|
|
++sample_count;
|
|
|
|
}
|
|
|
|
// Normalize samples.
|
|
|
|
double min_assigned_sample_weight = 1.0;
|
|
|
|
if (total_weight > 0.0) {
|
|
|
|
for (Begin(); !AtEnd(); Next()) {
|
|
|
|
TrainingSample* sample = MutableSample();
|
|
|
|
double weight = sample->weight() / total_weight;
|
|
|
|
if (weight < min_assigned_sample_weight)
|
|
|
|
min_assigned_sample_weight = weight;
|
|
|
|
sample->set_weight(weight);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return min_assigned_sample_weight;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Helper returns the current UnicharAndFont shape_entry.
|
|
|
|
const UnicharAndFonts* SampleIterator::GetShapeEntry() const {
|
|
|
|
const Shape& shape = shape_table_->GetShape(shape_index_);
|
|
|
|
return &shape[shape_char_index_];
|
|
|
|
}
|
|
|
|
|
|
|
|
} // namespace tesseract.
|
|
|
|
|