Merge pull request #3313 from stweil/learning_rate

Add new checks for floating point errors and fix a division by zero
Egor Pugin 2021-02-27 23:20:09 +03:00 committed by GitHub
commit 838a754d24
5 changed files with 34 additions and 13 deletions

View File

@@ -22,6 +22,9 @@
 #endif
 #include <cerrno> // for errno
+#if defined(__USE_GNU)
+#include <cfenv> // for feenableexcept
+#endif
 #include <iostream>
 #include <allheaders.h>
@@ -617,6 +620,10 @@ static void PreloadRenderers(
 **********************************************************************/
 int main(int argc, char** argv) {
+#if defined(__USE_GNU)
+  // Raise SIGFPE.
+  feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
+#endif
   const char* lang = nullptr;
   const char* image = nullptr;
   const char* outputbase = nullptr;
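
For context: feenableexcept() is a glibc extension (hence the __USE_GNU guards) that unmasks the listed IEEE 754 exceptions, so the first faulting operation delivers SIGFPE at the offending instruction instead of quietly yielding inf or NaN. A minimal standalone sketch of the behaviour, not part of this commit, assuming glibc with g++ (which defines _GNU_SOURCE, and thus __USE_GNU, by default):

#include <cfenv>   // for feenableexcept and the FE_* macros
#include <cstdio>

int main() {
#if defined(__USE_GNU)
  // Unmask the same three exceptions the commit enables.
  feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
#endif
  volatile double zero = 0.0;  // volatile: keep the division at run time
  double inf = 1.0 / zero;     // with traps enabled this raises SIGFPE
  std::printf("only reached with traps disabled: %f\n", inf);
  return 0;
}

Without the feenableexcept() call the program prints inf; with it, the process dies at the division, which is the debugging behaviour this commit is after.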

View File

@@ -332,24 +332,25 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
 // Updates the weights using the given learning rate and momentum.
 // num_samples is the quotient to be used in the adam computation iff
 // use_adam_ is true.
-void WeightMatrix::Update(double learning_rate, double momentum,
-                          double adam_beta, int num_samples) {
+void WeightMatrix::Update(float learning_rate, float momentum,
+                          float adam_beta, int num_samples) {
   assert(!int_mode_);
-  if (use_adam_ && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
-    learning_rate *= sqrt(1.0 - pow(adam_beta, num_samples));
-    learning_rate /= 1.0 - pow(momentum, num_samples);
+  if (use_adam_ && momentum > 0.0f &&
+      num_samples > 0 && num_samples < kAdamCorrectionIterations) {
+    learning_rate *= sqrt(1.0f - pow(adam_beta, num_samples));
+    learning_rate /= 1.0f - pow(momentum, num_samples);
   }
-  if (use_adam_ && num_samples > 0 && momentum > 0.0) {
+  if (use_adam_ && num_samples > 0 && momentum > 0.0f) {
     dw_sq_sum_.SumSquares(dw_, adam_beta);
-    dw_ *= learning_rate * (1.0 - momentum);
+    dw_ *= learning_rate * (1.0f - momentum);
     updates_ *= momentum;
     updates_ += dw_;
     wf_.AdamUpdate(updates_, dw_sq_sum_, learning_rate * kAdamEpsilon);
   } else {
     dw_ *= learning_rate;
     updates_ += dw_;
-    if (momentum > 0.0) wf_ += updates_;
-    if (momentum >= 0.0) updates_ *= momentum;
+    if (momentum > 0.0f) wf_ += updates_;
+    if (momentum >= 0.0f) updates_ *= momentum;
   }
   wf_t_.Transpose(wf_);
 }
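
One reading of the change above: the two pow() calls implement Adam's bias correction, effective_lr = lr * sqrt(1 - beta2^t) / (1 - beta1^t), with adam_beta playing beta2 and momentum playing beta1. The added momentum > 0.0f guard aligns the correction with the Adam branch below, which already required momentum > 0.0: previously a zero-momentum run had its rate bias-corrected for an Adam step that the plain path then never took. A scalar paraphrase, not the library's API (WeightMatrix applies this to whole matrices, and kAdamCorrectionIterations' value is not shown here, so it is passed in):

#include <cmath>

// Sketch of the corrected learning rate after this commit.
float AdamCorrectedRate(float learning_rate, float momentum, float adam_beta,
                        int num_samples, int correction_iterations) {
  if (momentum > 0.0f && num_samples > 0 &&
      num_samples < correction_iterations) {
    // Early in training the running moment estimates are biased toward
    // zero; these factors compensate, approaching 1 as num_samples grows.
    learning_rate *= std::sqrt(1.0f - std::pow(adam_beta, num_samples));
    learning_rate /= 1.0f - std::pow(momentum, num_samples);
  }
  return learning_rate;
}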

View File

@@ -139,7 +139,7 @@ class WeightMatrix {
             bool parallel);
   // Updates the weights using the given learning rate, momentum and adam_beta.
   // num_samples is used in the Adam correction factor.
-  void Update(double learning_rate, double momentum, double adam_beta,
+  void Update(float learning_rate, float momentum, float adam_beta,
               int num_samples);
   // Adds the dw_ in other to the dw_ is *this.
   void AddDeltas(const WeightMatrix& other);
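
The header change mirrors the definition above: the declaration and the out-of-line definition must narrow from double to float together, or the definition would no longer match. The f-suffixed literals added in the .cpp then keep the whole update in float rather than silently promoting operands back to double; the commit itself gives no motivation, but consistency of precision across the update path is a plausible one.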

View File

@@ -16,6 +16,9 @@
 ///////////////////////////////////////////////////////////////////////
 #include <cerrno>
+#if defined(__USE_GNU)
+#include <cfenv> // for feenableexcept
+#endif
 #include "commontraining.h"
 #include "fileio.h" // for LoadFileLinesToStrings
 #include "lstmtester.h"
@@ -44,6 +47,10 @@ static STRING_PARAM_FLAG(train_listfile, "",
                          "File listing training files in lstmf training format.");
 static STRING_PARAM_FLAG(eval_listfile, "",
                          "File listing eval files in lstmf training format.");
+#if defined(__USE_GNU)
+static BOOL_PARAM_FLAG(debug_float, false,
+                       "Raise error on certain float errors.");
+#endif
 static BOOL_PARAM_FLAG(stop_training, false,
                        "Just convert the training model to a runtime model.");
 static BOOL_PARAM_FLAG(convert_to_int, false,
@@ -73,6 +80,12 @@ const int kNumPagesPerBatch = 100;
 int main(int argc, char **argv) {
   tesseract::CheckSharedLibraryVersion();
   ParseArguments(&argc, &argv);
+#if defined(__USE_GNU)
+  if (FLAGS_debug_float) {
+    // Raise SIGFPE for unwanted floating point calculations.
+    feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
+  }
+#endif
   if (FLAGS_model_output.empty()) {
     tprintf("Must provide a --model_output!\n");
     return EXIT_FAILURE;
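
Unlike the unconditional trap in the tesseract binary above, here the trap is opt-in at run time, along the lines of: lstmtraining --debug_float --model_output ... --train_listfile ... (paths elided; an illustrative invocation, not one from the commit). The process then stops with SIGFPE at the first trapped operation, which is far easier to localize than NaNs surfacing in the loss many iterations later. On platforms without __USE_GNU neither the flag nor the feenableexcept() call is compiled in.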

View File

@@ -71,7 +71,7 @@ class LSTMTrainerTest : public testing::Test {
   }
   void SetupTrainer(const std::string& network_spec, const std::string& model_name,
                     const std::string& unicharset_file, const std::string& lstmf_file,
-                    bool recode, bool adam, double learning_rate,
+                    bool recode, bool adam, float learning_rate,
                     bool layer_specific, const std::string& kLang) {
     // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
     std::string unicharset_name = TestDataNameToPath(unicharset_file);
@@ -92,7 +92,7 @@ class LSTMTrainerTest : public testing::Test {
     int net_mode = adam ? NF_ADAM : 0;
     // Adam needs a higher learning rate, due to not multiplying the effective
     // rate by 1/(1-momentum).
-    if (adam) learning_rate *= 20.0;
+    if (adam) learning_rate *= 20.0f;
     if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR;
     EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
                                       learning_rate, 0.9, 0.999));
@@ -168,7 +168,7 @@ class LSTMTrainerTest : public testing::Test {
     std::string unicharset_name = lang + "/" + lang + ".unicharset";
     std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
    SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
-                 lstmf_name, recode, true, 5e-4, true, lang);
+                 lstmf_name, recode, true, 5e-4f, true, lang);
     std::vector<int> labels;
     EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
     STRING decoded = trainer_->DecodeLabels(labels);
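
A worked line on the 20.0f factor above, reading it against WeightMatrix::Update from this same commit (the exact headroom is a guess): the plain-momentum branch effectively applies updates_ = momentum * updates_ + lr * dw each step, whose steady state is lr * dw / (1 - momentum), a 10x amplification at the momentum of 0.9 passed to InitNetwork. The Adam branch instead scales dw_ by lr * (1.0f - momentum) and so never gains that 1/(1-momentum) factor; multiplying the test's learning rate by 20 restores rough parity, with roughly a factor of two to spare.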