Merge pull request #2271 from stweil/refactor

Refactor class Network
This commit is contained in:
zdenop 2019-02-27 07:43:13 +01:00 committed by GitHub
commit 12c1225a5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 59 deletions

View File

@@ -4,7 +4,6 @@
// and pulls in random data to fill out-of-input inputs.
// Output is therefore same size as its input, but deeper.
// Author: Ray Smith
// Created: Tue Mar 18 16:45:34 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,6 +60,11 @@ class Convolve : public Network {
NetworkScratch* scratch,
NetworkIO* back_deltas) override;
private:
void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}
protected:
// Serialized data.
int32_t half_x_;
@@ -69,5 +73,4 @@ class Convolve : public Network {
} // namespace tesseract.
#endif // TESSERACT_LSTM_SUBSAMPLE_H_

View File

@@ -2,7 +2,6 @@
// File: input.h
// Description: Input layer class for neural network implementations.
// Author: Ray Smith
// Created: Thu Mar 13 08:56:26 PDT 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,6 +92,10 @@ class Input : public Network {
TRand* randomizer, NetworkIO* input);
private:
void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}
// Input shape determines how images are dealt with.
StaticShape shape_;
// Cached total network x scale factor for scaling bounding boxes.

View File

@@ -2,7 +2,6 @@
// File: network.cpp
// Description: Base class for neural network implementations.
// Author: Ray Smith
// Created: Wed May 01 17:25:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -53,10 +52,11 @@ const int kMaxWinSize = 2000;
const int kXWinFrameSize = 30;
const int kYWinFrameSize = 80;
// String names corresponding to the NetworkType enum. Keep in sync.
// String names corresponding to the NetworkType enum.
// Keep in sync with NetworkType.
// Names used in Serialization to allow re-ordering/addition/deletion of
// layer types in NetworkType without invalidating existing network files.
char const* const Network::kTypeNames[NT_COUNT] = {
static char const* const kTypeNames[NT_COUNT] = {
"Invalid", "Input",
"Convolve", "Maxpool",
"Parallel", "Replicated",
@@ -165,57 +165,63 @@ bool Network::Serialize(TFile* fp) const {
return true;
}
// Reads from the given file. Returns false in case of error.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
bool Network::DeSerialize(TFile* fp) {
static NetworkType getNetworkType(TFile* fp) {
int8_t data;
if (!fp->DeSerialize(&data)) return false;
if (!fp->DeSerialize(&data)) return NT_NONE;
if (data == NT_NONE) {
STRING type_name;
if (!type_name.DeSerialize(fp)) return false;
if (!type_name.DeSerialize(fp)) return NT_NONE;
for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
}
if (data == NT_COUNT) {
tprintf("Invalid network layer type:%s\n", type_name.string());
return false;
return NT_NONE;
}
}
type_ = static_cast<NetworkType>(data);
if (!fp->DeSerialize(&data)) return false;
training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
if (!fp->DeSerialize(&data)) return false;
needs_to_backprop_ = data != 0;
if (!fp->DeSerialize(&network_flags_)) return false;
if (!fp->DeSerialize(&ni_)) return false;
if (!fp->DeSerialize(&no_)) return false;
if (!fp->DeSerialize(&num_weights_)) return false;
if (!name_.DeSerialize(fp)) return false;
return true;
return static_cast<NetworkType>(data);
}
// Reads from the given file. Returns nullptr in case of error.
// Determines the type of the serialized class and calls its DeSerialize
// on a new object of the appropriate type, which is returned.
Network* Network::CreateFromFile(TFile* fp) {
Network stub;
if (!stub.DeSerialize(fp)) return nullptr;
NetworkType type; // Type of the derived network class.
TrainingState training; // Are we currently training?
bool needs_to_backprop; // This network needs to output back_deltas.
int32_t network_flags; // Behavior control flags in NetworkFlags.
int32_t ni; // Number of input values.
int32_t no; // Number of output values.
int32_t num_weights; // Number of weights in this and sub-network.
STRING name; // A unique name for this layer.
int8_t data;
Network* network = nullptr;
switch (stub.type_) {
type = getNetworkType(fp);
if (!fp->DeSerialize(&data)) return nullptr;
training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
if (!fp->DeSerialize(&data)) return nullptr;
needs_to_backprop = data != 0;
if (!fp->DeSerialize(&network_flags)) return nullptr;
if (!fp->DeSerialize(&ni)) return nullptr;
if (!fp->DeSerialize(&no)) return nullptr;
if (!fp->DeSerialize(&num_weights)) return nullptr;
if (!name.DeSerialize(fp)) return nullptr;
switch (type) {
case NT_CONVOLVE:
network = new Convolve(stub.name_, stub.ni_, 0, 0);
network = new Convolve(name, ni, 0, 0);
break;
case NT_INPUT:
network = new Input(stub.name_, stub.ni_, stub.no_);
network = new Input(name, ni, no);
break;
case NT_LSTM:
case NT_LSTM_SOFTMAX:
case NT_LSTM_SOFTMAX_ENCODED:
case NT_LSTM_SUMMARY:
network =
new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
new LSTM(name, ni, no, no, false, type);
break;
case NT_MAXPOOL:
network = new Maxpool(stub.name_, stub.ni_, 0, 0);
network = new Maxpool(name, ni, 0, 0);
break;
// All variants of Parallel.
case NT_PARALLEL:
@@ -223,23 +229,23 @@ Network* Network::CreateFromFile(TFile* fp) {
case NT_PAR_RL_LSTM:
case NT_PAR_UD_LSTM:
case NT_PAR_2D_LSTM:
network = new Parallel(stub.name_, stub.type_);
network = new Parallel(name, type);
break;
case NT_RECONFIG:
network = new Reconfig(stub.name_, stub.ni_, 0, 0);
network = new Reconfig(name, ni, 0, 0);
break;
// All variants of reversed.
case NT_XREVERSED:
case NT_YREVERSED:
case NT_XYTRANSPOSE:
network = new Reversed(stub.name_, stub.type_);
network = new Reversed(name, type);
break;
case NT_SERIES:
network = new Series(stub.name_);
network = new Series(name);
break;
case NT_TENSORFLOW:
#ifdef INCLUDE_TENSORFLOW
network = new TFNetwork(stub.name_);
network = new TFNetwork(name);
#else
tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
#endif
@@ -253,16 +259,16 @@ Network* Network::CreateFromFile(TFile* fp) {
case NT_LOGISTIC:
case NT_POSCLIP:
case NT_SYMCLIP:
network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
network = new FullyConnected(name, ni, no, type);
break;
default:
break;
}
if (network) {
network->training_ = stub.training_;
network->needs_to_backprop_ = stub.needs_to_backprop_;
network->network_flags_ = stub.network_flags_;
network->num_weights_ = stub.num_weights_;
network->training_ = training;
network->needs_to_backprop_ = needs_to_backprop;
network->network_flags_ = network_flags;
network->num_weights_ = num_weights;
if (!network->DeSerialize(fp)) {
delete network;
network = nullptr;

View File

@@ -2,7 +2,6 @@
// File: network.h
// Description: Base class for neural network implementations.
// Author: Ray Smith
// Created: Wed May 01 16:38:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -215,17 +214,16 @@ class Network {
virtual void CacheXScaleFactor(int factor) {}
// Provides debug output on the weights.
virtual void DebugWeights() {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}
virtual void DebugWeights() = 0;
// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
virtual bool DeSerialize(TFile* fp);
virtual bool DeSerialize(TFile* fp) = 0;
public:
// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
virtual void Update(float learning_rate, float momentum, float adam_beta,
@@ -261,9 +259,7 @@ class Network {
// instead of all the replicated networks having to do it.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
tprintf("Must override Network::Forward for type %d\n", type_);
}
NetworkScratch* scratch, NetworkIO* output) = 0;
// Runs backward propagation of errors on fwdX_deltas.
// Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
@@ -272,10 +268,7 @@ class Network {
// return false from Backward!
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
tprintf("Must override Network::Backward for type %d\n", type_);
return false;
}
NetworkIO* back_deltas) = 0;
// === Debug image display methods. ===
// Displays the image of the matrix to the forward window.
@@ -309,12 +302,8 @@ class Network {
ScrollView* forward_win_; // Recognition debug display window.
ScrollView* backward_win_; // Training debug display window.
TRand* randomizer_; // Random number generator.
// Static serialized name/type_ mapping. Keep in sync with NetworkType.
static char const* const kTypeNames[NT_COUNT];
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_NETWORK_H_

View File

@@ -3,7 +3,6 @@
// Description: Network layer that reconfigures the scaling vs feature
// depth.
// Author: Ray Smith
// Created: Wed Feb 26 15:37:42 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,10 +15,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_RECONFIG_H_
#define TESSERACT_LSTM_RECONFIG_H_
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
@@ -71,6 +70,11 @@ class Reconfig : public Network {
NetworkScratch* scratch,
NetworkIO* back_deltas) override;
private:
void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}
protected:
// Non-serialized data used to store parameters between forward and back.
StrideMap back_map_;
@@ -81,5 +85,4 @@ class Reconfig : public Network {
} // namespace tesseract.
#endif // TESSERACT_LSTM_SUBSAMPLE_H_