Better fix for re-enabling training

This commit is contained in:
Ray Smith 2017-05-08 14:26:09 -07:00
parent 0afd5939b1
commit b86b4fa06b
2 changed files with 2 additions and 2 deletions

View File

@ -65,7 +65,7 @@ void FullyConnected::SetEnableTraining(TrainingState state) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
if (state == TS_ENABLED && training_ == TS_DISABLED)
if (state == TS_ENABLED && training_ != TS_ENABLED)
weights_.InitBackward();
training_ = state;
}

View File

@ -113,7 +113,7 @@ void LSTM::SetEnableTraining(TrainingState state) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
if (state == TS_ENABLED && training_ == TS_DISABLED) {
if (state == TS_ENABLED && training_ != TS_ENABLED) {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].InitBackward();