Corrected SetEnableTraining for recovery from a recognize-only model.

This commit is contained in:
Ray Smith 2017-05-05 16:39:43 -07:00
parent 006a56c55a
commit 4fa463cd71
6 changed files with 29 additions and 16 deletions

View File

@@ -56,13 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
return result;
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
// Suspends/Enables training by setting the training_ flag.
void FullyConnected::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
if (training_ == TS_DISABLED) weights_.InitBackward(false);
training_ = TS_ENABLED;
// Enable only from temp disabled.
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
} else if (state == TS_TEMP_DISABLE) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
if (state == TS_ENABLED && training_ == TS_DISABLED)
weights_.InitBackward();
training_ = state;
}
}

View File

@@ -107,14 +107,18 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
if (training_ == TS_DISABLED) {
// Enable only from temp disabled.
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
} else if (state == TS_TEMP_DISABLE) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
if (state == TS_ENABLED && training_ == TS_DISABLED) {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].InitBackward(false);
gate_weights_[w].InitBackward();
}
}
training_ = TS_ENABLED;
} else {
training_ = state;
}
if (softmax_ != NULL) softmax_->SetEnableTraining(state);

View File

@@ -111,7 +111,11 @@ Network::~Network() {
// recognizer can be converted back to a trainer.
void Network::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
training_ = TS_ENABLED;
// Enable only from temp disabled.
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
} else if (state == TS_TEMP_DISABLE) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
training_ = state;
}

View File

@@ -93,9 +93,10 @@ enum TrainingState {
// Valid states of training_.
TS_DISABLED, // Disabled permanently.
TS_ENABLED, // Enabled for backprop and to write a training dump.
// Re-enable from ANY disabled state.
TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
// Valid only for SetEnableTraining.
TS_RE_ENABLE, // Re-Enable whatever the current state.
TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};
// Base class for network types. Not quite an abstract base class, but almost.

View File

@@ -47,7 +47,8 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
}
}
}
InitBackward(ada_grad);
use_ada_grad_ = ada_grad;
InitBackward();
return ni * no;
}
@@ -83,10 +84,9 @@ void WeightMatrix::ConvertToInt() {
// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void WeightMatrix::InitBackward(bool ada_grad) {
void WeightMatrix::InitBackward() {
int no = int_mode_ ? wi_.dim1() : wf_.dim1();
int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
use_ada_grad_ = ada_grad;
dw_.Resize(no, ni, 0.0);
updates_.Resize(no, ni, 0.0);
wf_t_.Transpose(wf_);
@@ -134,7 +134,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
} else {
if (!wf_.DeSerialize(fp)) return false;
if (training) {
InitBackward(use_ada_grad_);
InitBackward();
if (!updates_.DeSerialize(fp)) return false;
if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false;
}
@@ -157,7 +157,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) {
FloatToDouble(float_array, &wf_);
}
if (training) {
InitBackward(use_ada_grad_);
InitBackward();
if (!float_array.DeSerialize(fp)) return false;
FloatToDouble(float_array, &updates_);
// Errs was only used in int training, which is now dead.

View File

@@ -92,7 +92,7 @@ class WeightMatrix {
// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void InitBackward(bool ada_grad);
void InitBackward();
// Writes to the given file. Returns false in case of error.
bool Serialize(bool training, TFile* fp) const;