Mirror of https://github.com/tesseract-ocr/tesseract.git, synced 2025-01-22 09:53:03 +08:00

Merge pull request #3313 from stweil/learning_rate

Add new checks for floating point errors and fix a division by zero

Commit: 838a754d24

The commit does two things: it enables trapping of floating-point errors (SIGFPE via glibc's feenableexcept) in the tesseract binary and, behind a new --debug_float flag, in lstmtraining; and it converts WeightMatrix::Update() and its callers from double to float, adding a momentum guard to the Adam rate correction.
The first hunk (apparently tesseractmain.cpp, given the PreloadRenderers/main() context below) adds a glibc-guarded include of <cfenv>:

@@ -22,6 +22,9 @@
 #endif

 #include <cerrno> // for errno
+#if defined(__USE_GNU)
+#include <cfenv> // for feenableexcept
+#endif
 #include <iostream>

 #include <allheaders.h>
Later in the same file, main() enables the floating-point traps unconditionally:

@@ -617,6 +620,10 @@ static void PreloadRenderers(
 **********************************************************************/

 int main(int argc, char** argv) {
+#if defined(__USE_GNU)
+  // Raise SIGFPE.
+  feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
+#endif
   const char* lang = nullptr;
   const char* image = nullptr;
   const char* outputbase = nullptr;
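For context, feenableexcept() is a glibc extension (hence the __USE_GNU guard): it unmasks the selected IEEE 754 exceptions so that the offending instruction raises SIGFPE instead of silently producing inf or NaN. A minimal standalone sketch of the effect, separate from the commit itself:

#include <cfenv>   // for feenableexcept (glibc extension)
#include <cstdio>

int main() {
#if defined(__USE_GNU)
  // Trap division by zero, overflow and invalid operations, as above.
  feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
#endif
  volatile double zero = 0.0;
  // Without the trap this prints "inf"; with it, the division itself
  // raises SIGFPE and the process aborts, pinpointing the bad operation.
  std::printf("%f\n", 1.0 / zero);
  return 0;
}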
In weightmatrix.cpp, WeightMatrix::Update() moves from double to float arguments, and the Adam rate correction is now applied only when momentum is positive:

@@ -332,24 +332,25 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
 // Updates the weights using the given learning rate and momentum.
 // num_samples is the quotient to be used in the adam computation iff
 // use_adam_ is true.
-void WeightMatrix::Update(double learning_rate, double momentum,
-                          double adam_beta, int num_samples) {
+void WeightMatrix::Update(float learning_rate, float momentum,
+                          float adam_beta, int num_samples) {
   assert(!int_mode_);
-  if (use_adam_ && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
-    learning_rate *= sqrt(1.0 - pow(adam_beta, num_samples));
-    learning_rate /= 1.0 - pow(momentum, num_samples);
+  if (use_adam_ && momentum > 0.0f &&
+      num_samples > 0 && num_samples < kAdamCorrectionIterations) {
+    learning_rate *= sqrt(1.0f - pow(adam_beta, num_samples));
+    learning_rate /= 1.0f - pow(momentum, num_samples);
   }
-  if (use_adam_ && num_samples > 0 && momentum > 0.0) {
+  if (use_adam_ && num_samples > 0 && momentum > 0.0f) {
     dw_sq_sum_.SumSquares(dw_, adam_beta);
-    dw_ *= learning_rate * (1.0 - momentum);
+    dw_ *= learning_rate * (1.0f - momentum);
     updates_ *= momentum;
     updates_ += dw_;
     wf_.AdamUpdate(updates_, dw_sq_sum_, learning_rate * kAdamEpsilon);
   } else {
     dw_ *= learning_rate;
     updates_ += dw_;
-    if (momentum > 0.0) wf_ += updates_;
-    if (momentum >= 0.0) updates_ *= momentum;
+    if (momentum > 0.0f) wf_ += updates_;
+    if (momentum >= 0.0f) updates_ *= momentum;
   }
   wf_t_.Transpose(wf_);
 }
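Beyond the double-to-float switch, the behavioral change is the new momentum > 0.0f clause: the Adam warm-up correction, whose second step divides by 1 - momentum^num_samples, is now skipped entirely when momentum is zero, which, as far as this hunk shows, is the division-by-zero fix named in the commit message. A self-contained sketch of that correction; the cutoff constant is illustrative, since the value of kAdamCorrectionIterations is not part of the diff:

#include <cmath>
#include <cstdio>

constexpr int kCorrectionIterations = 100000;  // illustrative cutoff

float CorrectedRate(float learning_rate, float momentum, float adam_beta,
                    int num_samples) {
  // Mirrors the guarded block above (minus the use_adam_ member):
  // apply the warm-up correction only for positive momentum and
  // 0 < num_samples < the cutoff.
  if (momentum > 0.0f && num_samples > 0 &&
      num_samples < kCorrectionIterations) {
    learning_rate *= std::sqrt(1.0f - std::pow(adam_beta, num_samples));
    learning_rate /= 1.0f - std::pow(momentum, num_samples);
  }
  return learning_rate;
}

int main() {
  // The correction factor approaches 1 as num_samples grows.
  const int samples[] = {1, 10, 1000};
  for (int n : samples) {
    std::printf("n=%4d corrected=%g\n", n,
                CorrectedRate(1e-3f, 0.9f, 0.999f, n));
  }
  // With momentum == 0 the correction is now skipped: rate unchanged.
  std::printf("momentum=0: %g\n", CorrectedRate(1e-3f, 0.0f, 0.999f, 10));
  return 0;
}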
The declaration in weightmatrix.h changes to match:

@@ -139,7 +139,7 @@ class WeightMatrix {
                  bool parallel);
   // Updates the weights using the given learning rate, momentum and adam_beta.
   // num_samples is used in the Adam correction factor.
-  void Update(double learning_rate, double momentum, double adam_beta,
+  void Update(float learning_rate, float momentum, float adam_beta,
               int num_samples);
   // Adds the dw_ in other to the dw_ is *this.
   void AddDeltas(const WeightMatrix& other);
lstmtraining.cpp gets the same guarded include:

@@ -16,6 +16,9 @@
 ///////////////////////////////////////////////////////////////////////

 #include <cerrno>
+#if defined(__USE_GNU)
+#include <cfenv> // for feenableexcept
+#endif
 #include "commontraining.h"
 #include "fileio.h" // for LoadFileLinesToStrings
 #include "lstmtester.h"
A new debug_float flag is declared alongside the existing trainer flags:

@@ -44,6 +47,10 @@ static STRING_PARAM_FLAG(train_listfile, "",
                          "File listing training files in lstmf training format.");
 static STRING_PARAM_FLAG(eval_listfile, "",
                          "File listing eval files in lstmf training format.");
+#if defined(__USE_GNU)
+static BOOL_PARAM_FLAG(debug_float, false,
+                       "Raise error on certain float errors.");
+#endif
 static BOOL_PARAM_FLAG(stop_training, false,
                        "Just convert the training model to a runtime model.");
 static BOOL_PARAM_FLAG(convert_to_int, false,
And lstmtraining's main() enables the traps only when that flag is set:

@@ -73,6 +80,12 @@ const int kNumPagesPerBatch = 100;
 int main(int argc, char **argv) {
   tesseract::CheckSharedLibraryVersion();
   ParseArguments(&argc, &argv);
+#if defined(__USE_GNU)
+  if (FLAGS_debug_float) {
+    // Raise SIGFPE for unwanted floating point calculations.
+    feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
+  }
+#endif
   if (FLAGS_model_output.empty()) {
     tprintf("Must provide a --model_output!\n");
     return EXIT_FAILURE;
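In practice (an assumed invocation, other required flags elided): running lstmtraining --debug_float --model_output ... on a glibc build would abort with SIGFPE at the first trapped operation instead of continuing to train on NaNs, while the default of false leaves normal training runs untouched. On platforms where __USE_GNU is not defined, neither the flag nor the traps exist.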
Finally, the LSTMTrainerTest fixture in lstm_test.h follows the float signature:

@@ -71,7 +71,7 @@ class LSTMTrainerTest : public testing::Test {
   }
   void SetupTrainer(const std::string& network_spec, const std::string& model_name,
                     const std::string& unicharset_file, const std::string& lstmf_file,
-                    bool recode, bool adam, double learning_rate,
+                    bool recode, bool adam, float learning_rate,
                     bool layer_specific, const std::string& kLang) {
     // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
     std::string unicharset_name = TestDataNameToPath(unicharset_file);

@@ -92,7 +92,7 @@ class LSTMTrainerTest : public testing::Test {
     int net_mode = adam ? NF_ADAM : 0;
     // Adam needs a higher learning rate, due to not multiplying the effective
     // rate by 1/(1-momentum).
-    if (adam) learning_rate *= 20.0;
+    if (adam) learning_rate *= 20.0f;
     if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR;
     EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
                                       learning_rate, 0.9, 0.999));

@@ -168,7 +168,7 @@ class LSTMTrainerTest : public testing::Test {
     std::string unicharset_name = lang + "/" + lang + ".unicharset";
     std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
     SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
-                 lstmf_name, recode, true, 5e-4, true, lang);
+                 lstmf_name, recode, true, 5e-4f, true, lang);
     std::vector<int> labels;
     EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
     STRING decoded = trainer_->DecodeLabels(labels);