diff --git a/src/lstm/convolve.h b/src/lstm/convolve.h index fcf5ccf0..be8beb1d 100644 --- a/src/lstm/convolve.h +++ b/src/lstm/convolve.h @@ -4,7 +4,6 @@ // and pulls in random data to fill out-of-input inputs. // Output is therefore same size as its input, but deeper. // Author: Ray Smith -// Created: Tue Mar 18 16:45:34 PST 2014 // // (C) Copyright 2014, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -61,6 +60,11 @@ class Convolve : public Network { NetworkScratch* scratch, NetworkIO* back_deltas) override; + private: + void DebugWeights() override { + tprintf("Must override Network::DebugWeights for type %d\n", type_); + } + protected: // Serialized data. int32_t half_x_; @@ -69,5 +73,4 @@ class Convolve : public Network { } // namespace tesseract. - #endif // TESSERACT_LSTM_SUBSAMPLE_H_ diff --git a/src/lstm/input.h b/src/lstm/input.h index cec22414..8054208b 100644 --- a/src/lstm/input.h +++ b/src/lstm/input.h @@ -2,7 +2,6 @@ // File: input.h // Description: Input layer class for neural network implementations. // Author: Ray Smith -// Created: Thu Mar 13 08:56:26 PDT 2014 // // (C) Copyright 2014, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,6 +92,10 @@ class Input : public Network { TRand* randomizer, NetworkIO* input); private: + void DebugWeights() override { + tprintf("Must override Network::DebugWeights for type %d\n", type_); + } + // Input shape determines how images are dealt with. StaticShape shape_; // Cached total network x scale factor for scaling bounding boxes. diff --git a/src/lstm/network.cpp b/src/lstm/network.cpp index eaa9baff..9500f4b3 100644 --- a/src/lstm/network.cpp +++ b/src/lstm/network.cpp @@ -2,7 +2,6 @@ // File: network.cpp // Description: Base class for neural network implementations. // Author: Ray Smith -// Created: Wed May 01 17:25:06 PST 2013 // // (C) Copyright 2013, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -53,10 +52,11 @@ const int kMaxWinSize = 2000; const int kXWinFrameSize = 30; const int kYWinFrameSize = 80; -// String names corresponding to the NetworkType enum. Keep in sync. +// String names corresponding to the NetworkType enum. +// Keep in sync with NetworkType. // Names used in Serialization to allow re-ordering/addition/deletion of // layer types in NetworkType without invalidating existing network files. -char const* const Network::kTypeNames[NT_COUNT] = { +static char const* const kTypeNames[NT_COUNT] = { "Invalid", "Input", "Convolve", "Maxpool", "Parallel", "Replicated", @@ -165,57 +165,63 @@ bool Network::Serialize(TFile* fp) const { return true; } -// Reads from the given file. Returns false in case of error. -// Should be overridden by subclasses, but NOT called by their DeSerialize. -bool Network::DeSerialize(TFile* fp) { +static NetworkType getNetworkType(TFile* fp) { int8_t data; - if (!fp->DeSerialize(&data)) return false; + if (!fp->DeSerialize(&data)) return NT_NONE; if (data == NT_NONE) { STRING type_name; - if (!type_name.DeSerialize(fp)) return false; + if (!type_name.DeSerialize(fp)) return NT_NONE; for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) { } if (data == NT_COUNT) { tprintf("Invalid network layer type:%s\n", type_name.string()); - return false; + return NT_NONE; } } - type_ = static_cast(data); - if (!fp->DeSerialize(&data)) return false; - training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED; - if (!fp->DeSerialize(&data)) return false; - needs_to_backprop_ = data != 0; - if (!fp->DeSerialize(&network_flags_)) return false; - if (!fp->DeSerialize(&ni_)) return false; - if (!fp->DeSerialize(&no_)) return false; - if (!fp->DeSerialize(&num_weights_)) return false; - if (!name_.DeSerialize(fp)) return false; - return true; + return static_cast(data); } // Reads from the given file. Returns nullptr in case of error. // Determines the type of the serialized class and calls its DeSerialize // on a new object of the appropriate type, which is returned. Network* Network::CreateFromFile(TFile* fp) { - Network stub; - if (!stub.DeSerialize(fp)) return nullptr; + NetworkType type; // Type of the derived network class. + TrainingState training; // Are we currently training? + bool needs_to_backprop; // This network needs to output back_deltas. + int32_t network_flags; // Behavior control flags in NetworkFlags. + int32_t ni; // Number of input values. + int32_t no; // Number of output values. + int32_t num_weights; // Number of weights in this and sub-network. + STRING name; // A unique name for this layer. + int8_t data; Network* network = nullptr; - switch (stub.type_) { + type = getNetworkType(fp); + if (!fp->DeSerialize(&data)) return nullptr; + training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED; + if (!fp->DeSerialize(&data)) return nullptr; + needs_to_backprop = data != 0; + if (!fp->DeSerialize(&network_flags)) return nullptr; + if (!fp->DeSerialize(&ni)) return nullptr; + if (!fp->DeSerialize(&no)) return nullptr; + if (!fp->DeSerialize(&num_weights)) return nullptr; + if (!name.DeSerialize(fp)) return nullptr; + + switch (type) { case NT_CONVOLVE: - network = new Convolve(stub.name_, stub.ni_, 0, 0); + network = new Convolve(name, ni, 0, 0); break; case NT_INPUT: - network = new Input(stub.name_, stub.ni_, stub.no_); + network = new Input(name, ni, no); break; case NT_LSTM: case NT_LSTM_SOFTMAX: case NT_LSTM_SOFTMAX_ENCODED: case NT_LSTM_SUMMARY: network = - new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_); + new LSTM(name, ni, no, no, false, type); break; case NT_MAXPOOL: - network = new Maxpool(stub.name_, stub.ni_, 0, 0); + network = new Maxpool(name, ni, 0, 0); break; // All variants of Parallel. case NT_PARALLEL: @@ -223,23 +229,23 @@ Network* Network::CreateFromFile(TFile* fp) { case NT_PAR_RL_LSTM: case NT_PAR_UD_LSTM: case NT_PAR_2D_LSTM: - network = new Parallel(stub.name_, stub.type_); + network = new Parallel(name, type); break; case NT_RECONFIG: - network = new Reconfig(stub.name_, stub.ni_, 0, 0); + network = new Reconfig(name, ni, 0, 0); break; // All variants of reversed. case NT_XREVERSED: case NT_YREVERSED: case NT_XYTRANSPOSE: - network = new Reversed(stub.name_, stub.type_); + network = new Reversed(name, type); break; case NT_SERIES: - network = new Series(stub.name_); + network = new Series(name); break; case NT_TENSORFLOW: #ifdef INCLUDE_TENSORFLOW - network = new TFNetwork(stub.name_); + network = new TFNetwork(name); #else tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n"); #endif @@ -253,16 +259,16 @@ Network* Network::CreateFromFile(TFile* fp) { case NT_LOGISTIC: case NT_POSCLIP: case NT_SYMCLIP: - network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_); + network = new FullyConnected(name, ni, no, type); break; default: break; } if (network) { - network->training_ = stub.training_; - network->needs_to_backprop_ = stub.needs_to_backprop_; - network->network_flags_ = stub.network_flags_; - network->num_weights_ = stub.num_weights_; + network->training_ = training; + network->needs_to_backprop_ = needs_to_backprop; + network->network_flags_ = network_flags; + network->num_weights_ = num_weights; if (!network->DeSerialize(fp)) { delete network; network = nullptr; diff --git a/src/lstm/network.h b/src/lstm/network.h index ba528f11..24e047d6 100644 --- a/src/lstm/network.h +++ b/src/lstm/network.h @@ -2,7 +2,6 @@ // File: network.h // Description: Base class for neural network implementations. // Author: Ray Smith -// Created: Wed May 01 16:38:06 PST 2013 // // (C) Copyright 2013, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -215,17 +214,16 @@ class Network { virtual void CacheXScaleFactor(int factor) {} // Provides debug output on the weights. - virtual void DebugWeights() { - tprintf("Must override Network::DebugWeights for type %d\n", type_); - } + virtual void DebugWeights() = 0; // Writes to the given file. Returns false in case of error. // Should be overridden by subclasses, but called by their Serialize. virtual bool Serialize(TFile* fp) const; // Reads from the given file. Returns false in case of error. // Should be overridden by subclasses, but NOT called by their DeSerialize. - virtual bool DeSerialize(TFile* fp); + virtual bool DeSerialize(TFile* fp) = 0; + public: // Updates the weights using the given learning rate, momentum and adam_beta. // num_samples is used in the adam computation iff use_adam_ is true. virtual void Update(float learning_rate, float momentum, float adam_beta, @@ -261,9 +259,7 @@ class Network { // instead of all the replicated networks having to do it. virtual void Forward(bool debug, const NetworkIO& input, const TransposedArray* input_transpose, - NetworkScratch* scratch, NetworkIO* output) { - tprintf("Must override Network::Forward for type %d\n", type_); - } + NetworkScratch* scratch, NetworkIO* output) = 0; // Runs backward propagation of errors on fwdX_deltas. // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward. @@ -272,10 +268,7 @@ class Network { // return false from Backward! virtual bool Backward(bool debug, const NetworkIO& fwd_deltas, NetworkScratch* scratch, - NetworkIO* back_deltas) { - tprintf("Must override Network::Backward for type %d\n", type_); - return false; - } + NetworkIO* back_deltas) = 0; // === Debug image display methods. === // Displays the image of the matrix to the forward window. @@ -309,12 +302,8 @@ class Network { ScrollView* forward_win_; // Recognition debug display window. ScrollView* backward_win_; // Training debug display window. TRand* randomizer_; // Random number generator. - - // Static serialized name/type_ mapping. Keep in sync with NetworkType. - static char const* const kTypeNames[NT_COUNT]; }; - } // namespace tesseract. #endif // TESSERACT_LSTM_NETWORK_H_ diff --git a/src/lstm/reconfig.h b/src/lstm/reconfig.h index 6e26399d..86e07252 100644 --- a/src/lstm/reconfig.h +++ b/src/lstm/reconfig.h @@ -3,7 +3,6 @@ // Description: Network layer that reconfigures the scaling vs feature // depth. // Author: Ray Smith -// Created: Wed Feb 26 15:37:42 PST 2014 // // (C) Copyright 2014, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,10 +15,10 @@ // See the License for the specific language governing permissions and // limitations under the License. /////////////////////////////////////////////////////////////////////// + #ifndef TESSERACT_LSTM_RECONFIG_H_ #define TESSERACT_LSTM_RECONFIG_H_ - #include "genericvector.h" #include "matrix.h" #include "network.h" @@ -71,6 +70,11 @@ class Reconfig : public Network { NetworkScratch* scratch, NetworkIO* back_deltas) override; + private: + void DebugWeights() override { + tprintf("Must override Network::DebugWeights for type %d\n", type_); + } + protected: // Non-serialized data used to store parameters between forward and back. StrideMap back_map_; @@ -81,5 +85,4 @@ class Reconfig : public Network { } // namespace tesseract. - #endif // TESSERACT_LSTM_SUBSAMPLE_H_