Mirror of https://github.com/tesseract-ocr/tesseract.git
Refactor class Network
Network is now an abstract class with several pure virtual functions.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
parent cf85054453
commit 98dd3b6351
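The change turns Network into an abstract base class: the tprintf placeholder bodies of DebugWeights, Forward, Backward and DeSerialize become pure virtual declarations, and CreateFromFile reads the serialized header fields into local variables (via a new static getNetworkType helper) instead of deserializing into a Network stub object, which is no longer constructible. A minimal sketch of that pattern follows; the names here (Layer, ConvLayer, InputLayer, SimpleFile) are illustrative stand-ins, not the actual tesseract classes.

// Sketch only; illustrative names, not the tesseract sources.
#include <cstdint>
#include <cstdio>
#include <memory>

enum LayerType : std::int8_t { LT_NONE = 0, LT_CONV, LT_INPUT, LT_COUNT };

struct SimpleFile {  // stand-in for tesseract's TFile
  bool DeSerialize(std::int8_t* v) {
    return fp != nullptr && std::fread(v, 1, 1, fp) == 1;
  }
  std::FILE* fp = nullptr;
};

class Layer {  // abstract base, like the refactored Network
 public:
  virtual ~Layer() = default;
  virtual void DebugWeights() = 0;  // was a tprintf stub, now pure virtual
  virtual bool DeSerialize(SimpleFile* fp) = 0;
};

class ConvLayer : public Layer {
 public:
  void DebugWeights() override { std::puts("conv weights"); }
  bool DeSerialize(SimpleFile* fp) override { return true; }
};

class InputLayer : public Layer {
 public:
  void DebugWeights() override { std::puts("input has no weights"); }
  bool DeSerialize(SimpleFile* fp) override { return true; }
};

// Read the type tag first, then construct the matching subclass and let it
// deserialize the rest; the abstract base itself is never instantiated.
std::unique_ptr<Layer> CreateFromFile(SimpleFile* fp) {
  std::int8_t data;
  if (!fp->DeSerialize(&data)) return nullptr;
  std::unique_ptr<Layer> layer;
  switch (static_cast<LayerType>(data)) {
    case LT_CONV:  layer = std::make_unique<ConvLayer>(); break;
    case LT_INPUT: layer = std::make_unique<InputLayer>(); break;
    default:       return nullptr;
  }
  if (!layer->DeSerialize(fp)) return nullptr;
  return layer;
}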
convolve.h

@@ -4,7 +4,6 @@
 // and pulls in random data to fill out-of-input inputs.
 // Output is therefore same size as its input, but deeper.
 // Author: Ray Smith
-// Created: Tue Mar 18 16:45:34 PST 2014
 //
 // (C) Copyright 2014, Google Inc.
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,6 +60,11 @@ class Convolve : public Network {
                 NetworkScratch* scratch,
                 NetworkIO* back_deltas) override;
 
+ private:
+  void DebugWeights() override {
+    tprintf("Must override Network::DebugWeights for type %d\n", type_);
+  }
+
  protected:
   // Serialized data.
   int32_t half_x_;
@@ -69,5 +73,4 @@ class Convolve : public Network {
 
 }  // namespace tesseract.
 
-
 #endif  // TESSERACT_LSTM_SUBSAMPLE_H_
input.h

@@ -2,7 +2,6 @@
 // File: input.h
 // Description: Input layer class for neural network implementations.
 // Author: Ray Smith
-// Created: Thu Mar 13 08:56:26 PDT 2014
 //
 // (C) Copyright 2014, Google Inc.
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,6 +92,10 @@ class Input : public Network {
                               TRand* randomizer, NetworkIO* input);
 
  private:
+  void DebugWeights() override {
+    tprintf("Must override Network::DebugWeights for type %d\n", type_);
+  }
+
   // Input shape determines how images are dealt with.
   StaticShape shape_;
   // Cached total network x scale factor for scaling bounding boxes.
network.cpp

@@ -2,7 +2,6 @@
 // File: network.cpp
 // Description: Base class for neural network implementations.
 // Author: Ray Smith
-// Created: Wed May 01 17:25:06 PST 2013
 //
 // (C) Copyright 2013, Google Inc.
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -53,10 +52,11 @@ const int kMaxWinSize = 2000;
 const int kXWinFrameSize = 30;
 const int kYWinFrameSize = 80;
 
-// String names corresponding to the NetworkType enum. Keep in sync.
+// String names corresponding to the NetworkType enum.
+// Keep in sync with NetworkType.
 // Names used in Serialization to allow re-ordering/addition/deletion of
 // layer types in NetworkType without invalidating existing network files.
-char const* const Network::kTypeNames[NT_COUNT] = {
+static char const* const kTypeNames[NT_COUNT] = {
     "Invalid", "Input",
     "Convolve", "Maxpool",
     "Parallel", "Replicated",
@@ -165,57 +165,63 @@ bool Network::Serialize(TFile* fp) const {
   return true;
 }
 
-// Reads from the given file. Returns false in case of error.
-// Should be overridden by subclasses, but NOT called by their DeSerialize.
-bool Network::DeSerialize(TFile* fp) {
+static NetworkType getNetworkType(TFile* fp) {
   int8_t data;
-  if (!fp->DeSerialize(&data)) return false;
+  if (!fp->DeSerialize(&data)) return NT_NONE;
   if (data == NT_NONE) {
     STRING type_name;
-    if (!type_name.DeSerialize(fp)) return false;
+    if (!type_name.DeSerialize(fp)) return NT_NONE;
     for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
     }
     if (data == NT_COUNT) {
       tprintf("Invalid network layer type:%s\n", type_name.string());
-      return false;
+      return NT_NONE;
     }
   }
-  type_ = static_cast<NetworkType>(data);
-  if (!fp->DeSerialize(&data)) return false;
-  training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
-  if (!fp->DeSerialize(&data)) return false;
-  needs_to_backprop_ = data != 0;
-  if (!fp->DeSerialize(&network_flags_)) return false;
-  if (!fp->DeSerialize(&ni_)) return false;
-  if (!fp->DeSerialize(&no_)) return false;
-  if (!fp->DeSerialize(&num_weights_)) return false;
-  if (!name_.DeSerialize(fp)) return false;
-  return true;
+  return static_cast<NetworkType>(data);
 }
 
 // Reads from the given file. Returns nullptr in case of error.
 // Determines the type of the serialized class and calls its DeSerialize
 // on a new object of the appropriate type, which is returned.
 Network* Network::CreateFromFile(TFile* fp) {
-  Network stub;
-  if (!stub.DeSerialize(fp)) return nullptr;
+  NetworkType type;        // Type of the derived network class.
+  TrainingState training;  // Are we currently training?
+  bool needs_to_backprop;  // This network needs to output back_deltas.
+  int32_t network_flags;   // Behavior control flags in NetworkFlags.
+  int32_t ni;              // Number of input values.
+  int32_t no;              // Number of output values.
+  int32_t num_weights;     // Number of weights in this and sub-network.
+  STRING name;             // A unique name for this layer.
+  int8_t data;
   Network* network = nullptr;
-  switch (stub.type_) {
+  type = getNetworkType(fp);
+  if (!fp->DeSerialize(&data)) return nullptr;
+  training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
+  if (!fp->DeSerialize(&data)) return nullptr;
+  needs_to_backprop = data != 0;
+  if (!fp->DeSerialize(&network_flags)) return nullptr;
+  if (!fp->DeSerialize(&ni)) return nullptr;
+  if (!fp->DeSerialize(&no)) return nullptr;
+  if (!fp->DeSerialize(&num_weights)) return nullptr;
+  if (!name.DeSerialize(fp)) return nullptr;
+
+  switch (type) {
     case NT_CONVOLVE:
-      network = new Convolve(stub.name_, stub.ni_, 0, 0);
+      network = new Convolve(name, ni, 0, 0);
       break;
     case NT_INPUT:
-      network = new Input(stub.name_, stub.ni_, stub.no_);
+      network = new Input(name, ni, no);
       break;
     case NT_LSTM:
    case NT_LSTM_SOFTMAX:
    case NT_LSTM_SOFTMAX_ENCODED:
    case NT_LSTM_SUMMARY:
      network =
-          new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
+          new LSTM(name, ni, no, no, false, type);
      break;
    case NT_MAXPOOL:
-      network = new Maxpool(stub.name_, stub.ni_, 0, 0);
+      network = new Maxpool(name, ni, 0, 0);
      break;
    // All variants of Parallel.
    case NT_PARALLEL:
@@ -223,23 +229,23 @@ Network* Network::CreateFromFile(TFile* fp) {
    case NT_PAR_RL_LSTM:
    case NT_PAR_UD_LSTM:
    case NT_PAR_2D_LSTM:
-      network = new Parallel(stub.name_, stub.type_);
+      network = new Parallel(name, type);
      break;
    case NT_RECONFIG:
-      network = new Reconfig(stub.name_, stub.ni_, 0, 0);
+      network = new Reconfig(name, ni, 0, 0);
      break;
    // All variants of reversed.
    case NT_XREVERSED:
    case NT_YREVERSED:
    case NT_XYTRANSPOSE:
-      network = new Reversed(stub.name_, stub.type_);
+      network = new Reversed(name, type);
      break;
    case NT_SERIES:
-      network = new Series(stub.name_);
+      network = new Series(name);
      break;
    case NT_TENSORFLOW:
 #ifdef INCLUDE_TENSORFLOW
-      network = new TFNetwork(stub.name_);
+      network = new TFNetwork(name);
 #else
      tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
 #endif
@@ -253,16 +259,16 @@ Network* Network::CreateFromFile(TFile* fp) {
    case NT_LOGISTIC:
    case NT_POSCLIP:
    case NT_SYMCLIP:
-      network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
+      network = new FullyConnected(name, ni, no, type);
      break;
    default:
      break;
  }
  if (network) {
-    network->training_ = stub.training_;
-    network->needs_to_backprop_ = stub.needs_to_backprop_;
-    network->network_flags_ = stub.network_flags_;
-    network->num_weights_ = stub.num_weights_;
+    network->training_ = training;
+    network->needs_to_backprop_ = needs_to_backprop;
+    network->network_flags_ = network_flags;
+    network->num_weights_ = num_weights;
    if (!network->DeSerialize(fp)) {
      delete network;
      network = nullptr;
network.h

@@ -2,7 +2,6 @@
 // File: network.h
 // Description: Base class for neural network implementations.
 // Author: Ray Smith
-// Created: Wed May 01 16:38:06 PST 2013
 //
 // (C) Copyright 2013, Google Inc.
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -215,17 +214,16 @@ class Network {
   virtual void CacheXScaleFactor(int factor) {}
 
   // Provides debug output on the weights.
-  virtual void DebugWeights() {
-    tprintf("Must override Network::DebugWeights for type %d\n", type_);
-  }
+  virtual void DebugWeights() = 0;
 
   // Writes to the given file. Returns false in case of error.
   // Should be overridden by subclasses, but called by their Serialize.
   virtual bool Serialize(TFile* fp) const;
   // Reads from the given file. Returns false in case of error.
   // Should be overridden by subclasses, but NOT called by their DeSerialize.
-  virtual bool DeSerialize(TFile* fp);
+  virtual bool DeSerialize(TFile* fp) = 0;
 
+ public:
   // Updates the weights using the given learning rate, momentum and adam_beta.
   // num_samples is used in the adam computation iff use_adam_ is true.
   virtual void Update(float learning_rate, float momentum, float adam_beta,
@@ -261,9 +259,7 @@ class Network {
   // instead of all the replicated networks having to do it.
   virtual void Forward(bool debug, const NetworkIO& input,
                        const TransposedArray* input_transpose,
-                       NetworkScratch* scratch, NetworkIO* output) {
-    tprintf("Must override Network::Forward for type %d\n", type_);
-  }
+                       NetworkScratch* scratch, NetworkIO* output) = 0;
 
   // Runs backward propagation of errors on fwdX_deltas.
   // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
@@ -272,10 +268,7 @@ class Network {
   // return false from Backward!
   virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                         NetworkScratch* scratch,
-                        NetworkIO* back_deltas) {
-    tprintf("Must override Network::Backward for type %d\n", type_);
-    return false;
-  }
+                        NetworkIO* back_deltas) = 0;
 
   // === Debug image display methods. ===
   // Displays the image of the matrix to the forward window.
@@ -309,12 +302,8 @@ class Network {
   ScrollView* forward_win_;   // Recognition debug display window.
   ScrollView* backward_win_;  // Training debug display window.
   TRand* randomizer_;         // Random number generator.
-
-  // Static serialized name/type_ mapping. Keep in sync with NetworkType.
-  static char const* const kTypeNames[NT_COUNT];
 };
 
-
 }  // namespace tesseract.
 
 #endif  // TESSERACT_LSTM_NETWORK_H_
reconfig.h

@@ -3,7 +3,6 @@
 // Description: Network layer that reconfigures the scaling vs feature
 //              depth.
 // Author: Ray Smith
-// Created: Wed Feb 26 15:37:42 PST 2014
 //
 // (C) Copyright 2014, Google Inc.
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,10 +15,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 ///////////////////////////////////////////////////////////////////////
+
 #ifndef TESSERACT_LSTM_RECONFIG_H_
 #define TESSERACT_LSTM_RECONFIG_H_
 
-
 #include "genericvector.h"
 #include "matrix.h"
 #include "network.h"
@@ -71,6 +70,11 @@ class Reconfig : public Network {
                 NetworkScratch* scratch,
                 NetworkIO* back_deltas) override;
 
+ private:
+  void DebugWeights() override {
+    tprintf("Must override Network::DebugWeights for type %d\n", type_);
+  }
+
  protected:
   // Non-serialized data used to store parameters between forward and back.
   StrideMap back_map_;
@@ -81,5 +85,4 @@ class Reconfig : public Network {
 
 }  // namespace tesseract.
 
-
 #endif  // TESSERACT_LSTM_SUBSAMPLE_H_