Mirror of https://github.com/tesseract-ocr/tesseract.git (synced 2025-01-18 06:30:14 +08:00)
Corrected SetEnableTraining for recovery from a recognize-only model.
This commit is contained in:
parent 006a56c55a
commit 4fa463cd71
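
The corrected transition logic is the same in Network, FullyConnected and LSTM below; the following self-contained sketch condenses it for readers of the diff. ToyLayer, its InitBackward() stub and main() are illustrative placeholders only; the enum values, the branch structure and the comments are taken from the hunks that follow.

    #include <cstdio>

    // Mirrors the TrainingState enum from network.h (see the enum hunk below).
    enum TrainingState {
      TS_DISABLED,      // Disabled permanently (recognize-only model).
      TS_ENABLED,       // Enabled for backprop; re-enables from ANY disabled state.
      TS_TEMP_DISABLE,  // Temporarily disabled to write a recognition dump.
      TS_RE_ENABLE,     // Argument only: re-enable from TS_TEMP_DISABLE, not TS_DISABLED.
    };

    // Toy stand-in for a network layer; InitBackward() represents re-allocating
    // the training-only buffers (deltas/updates), as WeightMatrix::InitBackward does.
    class ToyLayer {
     public:
      void SetEnableTraining(TrainingState state) {
        if (state == TS_RE_ENABLE) {
          // Enable only from temp disabled.
          if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
        } else if (state == TS_TEMP_DISABLE) {
          // Temp disable only from enabled.
          if (training_ == TS_ENABLED) training_ = state;
        } else {
          // An explicit TS_ENABLED on a permanently disabled (recognize-only)
          // model rebuilds the backward-pass buffers before training resumes.
          if (state == TS_ENABLED && training_ == TS_DISABLED) InitBackward();
          training_ = state;
        }
      }
      TrainingState training() const { return training_; }

     private:
      void InitBackward() { std::printf("InitBackward: deltas re-allocated\n"); }
      TrainingState training_ = TS_DISABLED;  // e.g. loaded from a recognize-only dump.
    };

    int main() {
      ToyLayer layer;
      layer.SetEnableTraining(TS_RE_ENABLE);    // No effect: TS_DISABLED stays disabled.
      layer.SetEnableTraining(TS_ENABLED);      // Recovers the recognize-only model.
      layer.SetEnableTraining(TS_TEMP_DISABLE); // Pause to write a recognition dump.
      layer.SetEnableTraining(TS_RE_ENABLE);    // Resume training afterwards.
      return layer.training() == TS_ENABLED ? 0 : 1;
    }
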
@@ -56,13 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
   return result;
 }
 
-// Suspends/Enables training by setting the training_ flag. Serialize and
-// DeSerialize only operate on the run-time data if state is false.
+// Suspends/Enables training by setting the training_ flag.
 void FullyConnected::SetEnableTraining(TrainingState state) {
   if (state == TS_RE_ENABLE) {
-    if (training_ == TS_DISABLED) weights_.InitBackward(false);
-    training_ = TS_ENABLED;
+    // Enable only from temp disabled.
+    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
+  } else if (state == TS_TEMP_DISABLE) {
+    // Temp disable only from enabled.
+    if (training_ == TS_ENABLED) training_ = state;
   } else {
+    if (state == TS_ENABLED && training_ == TS_DISABLED)
+      weights_.InitBackward();
     training_ = state;
   }
 }
@@ -107,14 +107,18 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
 // DeSerialize only operate on the run-time data if state is false.
 void LSTM::SetEnableTraining(TrainingState state) {
   if (state == TS_RE_ENABLE) {
-    if (training_ == TS_DISABLED) {
+    // Enable only from temp disabled.
+    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
+  } else if (state == TS_TEMP_DISABLE) {
+    // Temp disable only from enabled.
+    if (training_ == TS_ENABLED) training_ = state;
+  } else {
+    if (state == TS_ENABLED && training_ == TS_DISABLED) {
       for (int w = 0; w < WT_COUNT; ++w) {
         if (w == GFS && !Is2D()) continue;
-        gate_weights_[w].InitBackward(false);
+        gate_weights_[w].InitBackward();
       }
     }
-    training_ = TS_ENABLED;
-  } else {
     training_ = state;
   }
   if (softmax_ != NULL) softmax_->SetEnableTraining(state);
@@ -111,7 +111,11 @@ Network::~Network() {
 // recognizer can be converted back to a trainer.
 void Network::SetEnableTraining(TrainingState state) {
   if (state == TS_RE_ENABLE) {
-    training_ = TS_ENABLED;
+    // Enable only from temp disabled.
+    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
+  } else if (state == TS_TEMP_DISABLE) {
+    // Temp disable only from enabled.
+    if (training_ == TS_ENABLED) training_ = state;
   } else {
     training_ = state;
   }
@@ -93,9 +93,10 @@ enum TrainingState {
   // Valid states of training_.
   TS_DISABLED,      // Disabled permanently.
   TS_ENABLED,       // Enabled for backprop and to write a training dump.
+                    // Re-enable from ANY disabled state.
   TS_TEMP_DISABLE,  // Temporarily disabled to write a recognition dump.
   // Valid only for SetEnableTraining.
-  TS_RE_ENABLE,     // Re-Enable whatever the current state.
+  TS_RE_ENABLE,     // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
 };
 
 // Base class for network types. Not quite an abstract base class, but almost.
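
Taken together, these enum comments and the SetEnableTraining bodies above define a small state machine. The summary below is added for clarity and simply restates the branches in the hunks; "rebuilds deltas" refers to the InitBackward() calls in the FullyConnected and LSTM overrides (the base Network class only flips the flag).

    current training_   argument to SetEnableTraining   result
    -----------------   -----------------------------   ----------------------------------------------
    TS_TEMP_DISABLE     TS_RE_ENABLE                    TS_ENABLED (resume after a recognition dump)
    TS_DISABLED         TS_RE_ENABLE                    unchanged (recognize-only model stays disabled)
    TS_ENABLED          TS_TEMP_DISABLE                 TS_TEMP_DISABLE (pause to write a recognition dump)
    TS_DISABLED         TS_TEMP_DISABLE                 unchanged
    TS_DISABLED         TS_ENABLED                      TS_ENABLED, rebuilds deltas via InitBackward()
    any state           TS_DISABLED                     TS_DISABLED
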
@@ -47,7 +47,8 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
       }
     }
   }
-  InitBackward(ada_grad);
+  use_ada_grad_ = ada_grad;
+  InitBackward();
   return ni * no;
 }
 
@@ -83,10 +84,9 @@ void WeightMatrix::ConvertToInt() {
 
 // Allocates any needed memory for running Backward, and zeroes the deltas,
 // thus eliminating any existing momentum.
-void WeightMatrix::InitBackward(bool ada_grad) {
+void WeightMatrix::InitBackward() {
   int no = int_mode_ ? wi_.dim1() : wf_.dim1();
   int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
-  use_ada_grad_ = ada_grad;
   dw_.Resize(no, ni, 0.0);
   updates_.Resize(no, ni, 0.0);
   wf_t_.Transpose(wf_);
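
The WeightMatrix hunks follow the same recovery theme: the AdaGrad flag is now captured once in InitWeightsFloat (use_ada_grad_ = ada_grad), so the no-argument InitBackward() can be re-run later, for example when a deserialized recognize-only model is switched back to TS_ENABLED, or in DeSerialize with training set, without the caller re-supplying the flag. A minimal sketch of that pattern follows; SketchWeightMatrix and the std::vector members are placeholders standing in for the real matrix storage, not the actual WeightMatrix interface beyond what the hunks show.

    #include <vector>

    // Sketch of the refactor: the configuration flag is stored at init time so
    // the no-argument InitBackward() can be called again later to rebuild the
    // training-only buffers.
    class SketchWeightMatrix {
     public:
      int InitWeightsFloat(int no, int ni, bool ada_grad) {
        wf_.assign(no * ni, 0.0);
        use_ada_grad_ = ada_grad;  // Cached here instead of inside InitBackward().
        InitBackward();
        return ni * no;
      }
      // Allocates any needed memory for running Backward, and zeroes the deltas,
      // thus eliminating any existing momentum.  No flag argument needed anymore.
      void InitBackward() {
        dw_.assign(wf_.size(), 0.0);
        updates_.assign(wf_.size(), 0.0);
        if (use_ada_grad_) dw_sq_sum_.assign(wf_.size(), 0.0);
      }

     private:
      bool use_ada_grad_ = false;
      std::vector<double> wf_, dw_, updates_, dw_sq_sum_;
    };

    int main() {
      SketchWeightMatrix w;
      w.InitWeightsFloat(4, 3, /*ada_grad=*/true);
      w.InitBackward();  // Re-creating training buffers later needs no arguments.
      return 0;
    }
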
@@ -134,7 +134,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
   } else {
     if (!wf_.DeSerialize(fp)) return false;
     if (training) {
-      InitBackward(use_ada_grad_);
+      InitBackward();
       if (!updates_.DeSerialize(fp)) return false;
       if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false;
     }
@@ -157,7 +157,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) {
     FloatToDouble(float_array, &wf_);
   }
   if (training) {
-    InitBackward(use_ada_grad_);
+    InitBackward();
     if (!float_array.DeSerialize(fp)) return false;
     FloatToDouble(float_array, &updates_);
     // Errs was only used in int training, which is now dead.
@@ -92,7 +92,7 @@ class WeightMatrix {
 
   // Allocates any needed memory for running Backward, and zeroes the deltas,
   // thus eliminating any existing momentum.
-  void InitBackward(bool ada_grad);
+  void InitBackward();
 
   // Writes to the given file. Returns false in case of error.
   bool Serialize(bool training, TFile* fp) const;