Mirror of https://github.com/tesseract-ocr/tesseract.git
Improve format of logging from lstmtraining

- always use C ("classic") locale
- limit output of floating point values to 3 digits
- remove unneeded linefeed after log message "wrote checkpoint"

Signed-off-by: Stefan Weil <sw@weilnetz.de>
commit 0f56340151 (parent ed69e574a9)
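The commit applies one pattern throughout: log messages are built in a std::stringstream that is imbued with the classic "C" locale and configured with std::fixed and std::setprecision(3), instead of concatenating std::to_string results. A minimal standalone sketch of that pattern (illustrative, not code from this commit):

#include <iomanip>  // for std::fixed, std::setprecision
#include <iostream>
#include <locale>   // for std::locale::classic
#include <sstream>  // for std::stringstream

int main() {
  double bcer = 2.3456789;
  std::stringstream log_msg;
  log_msg.imbue(std::locale::classic());         // decimal separator is always '.'
  log_msg << std::fixed << std::setprecision(3); // 3 digits after the point
  log_msg << "BCER eval=" << bcer;
  std::cout << log_msg.str() << "\n";            // prints: BCER eval=2.346
  // For comparison, std::to_string(bcer) formats with the current C locale
  // and a fixed 6 digits: "2.345679" (or "2,345679" under e.g. de_DE).
  return 0;
}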
lstmtraining.cpp:

@@ -16,6 +16,7 @@
 ///////////////////////////////////////////////////////////////////////
 
 #include <cerrno>
+#include <locale> // for std::locale::classic
 #if defined(__USE_GNU)
 # include <cfenv> // for feenableexcept
 #endif
@@ -222,9 +223,10 @@ int main(int argc, char **argv) {
          iteration = trainer.training_iteration()) {
       trainer.TrainOnLine(&trainer, false);
     }
-    std::string log_str;
+    std::stringstream log_str;
+    log_str.imbue(std::locale::classic());
     trainer.MaintainCheckpoints(tester_callback, log_str);
-    tprintf("%s\n", log_str.c_str());
+    tprintf("%s\n", log_str.str().c_str());
   } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
            (trainer.training_iteration() < max_iterations));
   tprintf("Finished! Selected model with minimal training error rate (BCER) = %g\n",
lstmtester.cpp:

@@ -16,6 +16,7 @@
 ///////////////////////////////////////////////////////////////////////
 
 #include "lstmtester.h"
+#include <iomanip> // for std::setprecision
 #include <thread> // for std::thread
 #include "fileio.h" // for LoadFileLinesToStrings
 
@@ -115,14 +116,15 @@ std::string LSTMTester::RunEvalSync(int iteration, const double *training_errors
   }
   char_error *= 100.0 / total_pages_;
   word_error *= 100.0 / total_pages_;
-  std::string result;
+  std::stringstream result;
+  result.imbue(std::locale::classic());
+  result << std::fixed << std::setprecision(3);
   if (iteration != 0 || training_stage != 0) {
-    result += "At iteration " + std::to_string(iteration);
-    result += ", stage " + std::to_string(training_stage) + ", ";
+    result << "At iteration " << iteration
+           << ", stage " << training_stage << ", ";
   }
-  result += "BCER eval=" + std::to_string(char_error);
-  result += ", BWER eval=" + std::to_string(word_error);
-  return result;
+  result << "BCER eval=" << char_error << ", BWER eval=" << word_error;
+  return result.str();
 }
 
 // Helper thread function for RunEvalAsync.
lstmtrainer.cpp:

@@ -23,6 +23,8 @@
 #endif
 
 #include <cmath>
+#include <iomanip> // for std::setprecision
+#include <locale>  // for std::locale::classic
 #include <string>
 #include "lstmtrainer.h"
 
@@ -305,7 +307,7 @@ bool LSTMTrainer::LoadAllTrainingData(const std::vector<std::string> &filenames,
 // Writes checkpoints at appropriate times and builds and returns a log message
 // to indicate progress. Returns false if nothing interesting happened.
 bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
-                                      std::string &log_msg) {
+                                      std::stringstream &log_msg) {
   PrepareLogMsg(log_msg);
   double error_rate = CharError();
   int iteration = learning_iteration();
@@ -330,35 +332,34 @@ bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
   std::vector<char> rec_model_data;
   if (error_rate < best_error_rate_) {
     SaveRecognitionDump(&rec_model_data);
-    log_msg += " New best BCER = " + std::to_string(error_rate);
-    log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
+    log_msg << " New best BCER = " << error_rate;
+    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
     // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
     // just overwrote *this. In either case, we have finished with it.
     sub_trainer_.reset();
     stall_iteration_ = learning_iteration() + kMinStallIterations;
     if (TransitionTrainingStage(kStageTransitionThreshold)) {
-      log_msg +=
-          " Transitioned to stage " + std::to_string(CurrentTrainingStage());
+      log_msg << " Transitioned to stage " << CurrentTrainingStage();
     }
     SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
     if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
       std::string best_model_name = DumpFilename();
       if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
-        log_msg += " failed to write best model:";
+        log_msg << " failed to write best model:";
       } else {
-        log_msg += " wrote best model:";
+        log_msg << " wrote best model:";
         error_rate_of_last_saved_best_ = best_error_rate_;
       }
-      log_msg += best_model_name;
+      log_msg << best_model_name;
     }
   } else if (error_rate > worst_error_rate_) {
     SaveRecognitionDump(&rec_model_data);
-    log_msg += " New worst BCER = " + std::to_string(error_rate);
-    log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
+    log_msg << " New worst BCER = " << error_rate;
+    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
     if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
         best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
       // Error rate has ballooned. Go back to the best model.
-      log_msg += "\nDivergence! ";
+      log_msg << "\nDivergence! ";
       // Copy best_trainer_ before reading it, as it will get overwritten.
       std::vector<char> revert_data(best_trainer_);
       if (ReadTrainingDump(revert_data, *this)) {
@@ -382,34 +383,33 @@ bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
     std::vector<char> checkpoint;
     if (!SaveTrainingDump(FULL, *this, &checkpoint) ||
         !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
-      log_msg += " failed to write checkpoint.";
+      log_msg << " failed to write checkpoint.";
     } else {
-      log_msg += " wrote checkpoint.";
+      log_msg << " wrote checkpoint.";
     }
   }
-  log_msg += "\n";
   return result;
 }
 
 // Builds a string containing a progress message with current error rates.
-void LSTMTrainer::PrepareLogMsg(std::string &log_msg) const {
+void LSTMTrainer::PrepareLogMsg(std::stringstream &log_msg) const {
   LogIterations("At", log_msg);
-  log_msg += ", Mean rms=" + std::to_string(error_rates_[ET_RMS]);
-  log_msg += "%, delta=" + std::to_string(error_rates_[ET_DELTA]);
-  log_msg += "%, BCER train=" + std::to_string(error_rates_[ET_CHAR_ERROR]);
-  log_msg += "%, BWER train=" + std::to_string(error_rates_[ET_WORD_RECERR]);
-  log_msg += "%, skip ratio=" + std::to_string(error_rates_[ET_SKIP_RATIO]);
-  log_msg += "%, ";
+  log_msg << std::fixed << std::setprecision(3)
+          << ", mean rms=" << error_rates_[ET_RMS]
+          << "%, delta=" << error_rates_[ET_DELTA]
+          << "%, BCER train=" << error_rates_[ET_CHAR_ERROR]
+          << "%, BWER train=" << error_rates_[ET_WORD_RECERR]
+          << "%, skip ratio=" << error_rates_[ET_SKIP_RATIO] << "%,";
 }
 
 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
 // sample_iteration() to the log_msg.
 void LSTMTrainer::LogIterations(const char *intro_str,
-                                std::string &log_msg) const {
-  log_msg += intro_str;
-  log_msg += " iteration " + std::to_string(learning_iteration());
-  log_msg += "/" + std::to_string(training_iteration());
-  log_msg += "/" + std::to_string(sample_iteration());
+                                std::stringstream &log_msg) const {
+  log_msg << intro_str
+          << " iteration " << learning_iteration()
+          << "/" << training_iteration()
+          << "/" << sample_iteration();
 }
 
 // Returns true and increments the training_stage_ if the error rate has just
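A detail of the PrepareLogMsg rewrite that is easy to miss: std::fixed and std::setprecision(3) are persistent stream state, not per-insertion modifiers, so once PrepareLogMsg sets them, every later insertion into the same stream (for example the " New best BCER = " message in MaintainCheckpoints) is formatted the same way. A standalone sketch, not from the commit:

#include <iomanip>
#include <iostream>
#include <sstream>

int main() {
  std::stringstream log_msg;
  log_msg << std::fixed << std::setprecision(3); // sticky: persists on the stream
  log_msg << "mean rms=" << 1.23456;             // "mean rms=1.235"
  log_msg << "%, delta=" << 0.5;                 // "%, delta=0.500" - still 3 digits
  std::cout << log_msg.str() << "\n";
  return 0;
}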
@@ -602,14 +602,14 @@ bool LSTMTrainer::DeSerialize(const TessdataManager *mgr, TFile *fp) {
 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
 // learning rates (by scaling reduction, or layer specific, according to
 // NF_LAYER_SPECIFIC_LR).
-void LSTMTrainer::StartSubtrainer(std::string &log_msg) {
+void LSTMTrainer::StartSubtrainer(std::stringstream &log_msg) {
   sub_trainer_ = std::make_unique<LSTMTrainer>();
   if (!ReadTrainingDump(best_trainer_, *sub_trainer_)) {
-    log_msg += " Failed to revert to previous best for trial!";
+    log_msg << " Failed to revert to previous best for trial!";
     sub_trainer_.reset();
   } else {
-    log_msg += " Trial sub_trainer_ from iteration " +
-               std::to_string(sub_trainer_->training_iteration());
+    log_msg << " Trial sub_trainer_ from iteration "
+            << sub_trainer_->training_iteration();
     // Reduce learning rate so it doesn't diverge this time.
     sub_trainer_->ReduceLearningRates(this, log_msg);
     // If it fails again, we will wait twice as long before reverting again.
@@ -630,14 +630,13 @@ void LSTMTrainer::StartSubtrainer(std::string &log_msg) {
 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
 // receive any training iterations.
-SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::string &log_msg) {
+SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::stringstream &log_msg) {
   double training_error = CharError();
   double sub_error = sub_trainer_->CharError();
   double sub_margin = (training_error - sub_error) / sub_error;
   if (sub_margin >= kSubTrainerMarginFraction) {
-    log_msg += " sub_trainer=" + std::to_string(sub_error);
-    log_msg += " margin=" + std::to_string(100.0 * sub_margin);
-    log_msg += "\n";
+    log_msg << " sub_trainer=" << sub_error
+            << " margin=" << 100.0 * sub_margin << "\n";
     // Catch up to current iteration.
     int end_iteration = training_iteration();
     while (sub_trainer_->training_iteration() < end_iteration &&
@@ -647,11 +646,12 @@ SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::string &log_msg) {
     while (sub_trainer_->training_iteration() < target_iteration) {
       sub_trainer_->TrainOnLine(this, false);
     }
-    std::string batch_log = "Sub:";
+    std::stringstream batch_log("Sub:");
+    batch_log.imbue(std::locale::classic());
     sub_trainer_->PrepareLogMsg(batch_log);
-    batch_log += "\n";
-    tprintf("UpdateSubtrainer:%s", batch_log.c_str());
-    log_msg += batch_log;
+    batch_log << "\n";
+    tprintf("UpdateSubtrainer:%s", batch_log.str().c_str());
+    log_msg << batch_log.str();
     sub_error = sub_trainer_->CharError();
     sub_margin = (training_error - sub_error) / sub_error;
   }
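A standard-library subtlety in the batch_log hunk above: constructing a std::stringstream from an initial string sets the buffer contents, but with the default open mode the write position starts at offset 0, so subsequent << insertions overwrite the "Sub:" prefix rather than appending to it; appending would require std::ios_base::ate. A standalone demonstration of that behavior (not code from the commit):

#include <iostream>
#include <sstream>

int main() {
  std::stringstream a("Sub:");   // default mode: write position at offset 0
  a << "xy";
  std::cout << a.str() << "\n";  // prints: xyb:  ("Su" was overwritten)

  std::stringstream b("Sub:",
                      std::ios_base::in | std::ios_base::out | std::ios_base::ate);
  b << "xy";
  std::cout << b.str() << "\n";  // prints: Sub:xy (ate starts writes at the end)
  return 0;
}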
@@ -661,9 +661,8 @@ SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::string &log_msg) {
     std::vector<char> updated_trainer;
     SaveTrainingDump(LIGHT, *sub_trainer_, &updated_trainer);
     ReadTrainingDump(updated_trainer, *this);
-    log_msg += " Sub trainer wins at iteration " +
-               std::to_string(training_iteration());
-    log_msg += "\n";
+    log_msg << " Sub trainer wins at iteration "
+            << training_iteration() << "\n";
     return STR_REPLACED;
   }
   return STR_UPDATED;
@@ -674,17 +673,16 @@ SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::string &log_msg) {
 // Reduces network learning rates, either for everything, or for layers
 // independently, according to NF_LAYER_SPECIFIC_LR.
 void LSTMTrainer::ReduceLearningRates(LSTMTrainer *samples_trainer,
-                                      std::string &log_msg) {
+                                      std::stringstream &log_msg) {
   if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
     int num_reduced = ReduceLayerLearningRates(
         kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
-    log_msg +=
-        "\nReduced learning rate on layers: " + std::to_string(num_reduced);
+    log_msg << "\nReduced learning rate on layers: " << num_reduced;
  } else {
     ScaleLearningRate(kLearningRateDecay);
-    log_msg += "\nReduced learning rate to :" + std::to_string(learning_rate_);
+    log_msg << "\nReduced learning rate to :" << learning_rate_;
  }
-  log_msg += "\n";
+  log_msg << "\n";
 }
 
 // Considers reducing the learning rate independently for each layer down by
lstmtrainer.h:

@@ -25,6 +25,7 @@
 #include "rect.h"
 
 #include <functional> // for std::function
+#include <sstream> // for std::stringstream
 
 namespace tesseract {
 
@@ -192,7 +193,7 @@ public:
 
   // Keeps track of best and locally worst error rate, using internally computed
   // values. See MaintainCheckpointsSpecific for more detail.
-  bool MaintainCheckpoints(const TestCallback &tester, std::string &log_msg);
+  bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg);
   // Keeps track of best and locally worst error_rate (whatever it is) and
   // launches tests using rec_model, when a new min or max is reached.
   // Writes checkpoints using train_model at appropriate times and builds and
@@ -201,12 +202,12 @@ public:
   bool MaintainCheckpointsSpecific(int iteration,
                                    const std::vector<char> *train_model,
                                    const std::vector<char> *rec_model,
-                                   TestCallback tester, std::string &log_msg);
-  // Builds a string containing a progress message with current error rates.
-  void PrepareLogMsg(std::string &log_msg) const;
+                                   TestCallback tester, std::stringstream &log_msg);
+  // Builds a progress message with current error rates.
+  void PrepareLogMsg(std::stringstream &log_msg) const;
   // Appends <intro_str> iteration learning_iteration()/training_iteration()/
   // sample_iteration() to the log_msg.
-  void LogIterations(const char *intro_str, std::string &log_msg) const;
+  void LogIterations(const char *intro_str, std::stringstream &log_msg) const;
 
   // TODO(rays) Add curriculum learning.
   // Returns true and increments the training_stage_ if the error rate has just
@@ -226,7 +227,7 @@ public:
   // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
   // learning rates (by scaling reduction, or layer specific, according to
   // NF_LAYER_SPECIFIC_LR).
-  void StartSubtrainer(std::string &log_msg);
+  void StartSubtrainer(std::stringstream &log_msg);
   // While the sub_trainer_ is behind the current training iteration and its
   // training error is at least kSubTrainerMarginFraction better than the
   // current training error, trains the sub_trainer_, and returns STR_UPDATED if
@@ -235,10 +236,10 @@ public:
   // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
   // returned. STR_NONE is returned if the subtrainer wasn't good enough to
   // receive any training iterations.
-  SubTrainerResult UpdateSubtrainer(std::string &log_msg);
+  SubTrainerResult UpdateSubtrainer(std::stringstream &log_msg);
   // Reduces network learning rates, either for everything, or for layers
   // independently, according to NF_LAYER_SPECIFIC_LR.
-  void ReduceLearningRates(LSTMTrainer *samples_trainer, std::string &log_msg);
+  void ReduceLearningRates(LSTMTrainer *samples_trainer, std::stringstream &log_msg);
   // Considers reducing the learning rate independently for each layer down by
   // factor(<1), or leaving it the same, by double-training the given number of
   // samples and minimizing the amount of changing of sign of weight updates.
lstm_test.h:

@@ -103,7 +103,7 @@ protected:
     int iteration_limit = iteration + max_iterations;
     double best_error = 100.0;
     do {
-      std::string log_str;
+      std::stringstream log_str;
       int target_iteration = iteration + kBatchIterations;
       // Train a few.
       double mean_error = 0.0;