diff --git a/lstm/lstmtrainer.cpp b/lstm/lstmtrainer.cpp index 036199694..5678c5b87 100644 --- a/lstm/lstmtrainer.cpp +++ b/lstm/lstmtrainer.cpp @@ -918,6 +918,10 @@ bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount, // Reads previously saved trainer from memory. bool LSTMTrainer::ReadTrainingDump(const GenericVector& data, LSTMTrainer* trainer) { + if (data.size() == 0) { + tprintf("Warning: data size is zero in LSTMTrainer::ReadTrainingDump\n"); + return false; + } return trainer->ReadSizedTrainingDump(&data[0], data.size()); } @@ -1298,8 +1302,9 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate, if (error_rate < best_error_rate_) { // This is a new (global) minimum. if (tester != NULL) { - result = tester->Run(worst_iteration_, worst_error_rates_, - worst_model_data_, CurrentTrainingStage()); + if (worst_model_data_.size() != 0) + result = tester->Run(worst_iteration_, worst_error_rates_, + worst_model_data_, CurrentTrainingStage()); worst_model_data_.truncate(0); best_model_data_ = model_data; }