mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2024-12-12 07:29:07 +08:00
256 lines
8.7 KiB
C++
256 lines
8.7 KiB
C++
// Copyright 2008 Google Inc.
|
|
// All Rights Reserved.
|
|
// Author: ahmadab@google.com (Ahmad Abdulkader)
|
|
//
|
|
// neural_net.h: Declarations of a class for an object that
|
|
// represents an arbitrary network of neurons
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef NEURAL_NET_H
|
|
#define NEURAL_NET_H
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
#include "neuron.h"
|
|
#include "input_file_buffer.h"
|
|
|
|
namespace tesseract {
|
|
|
|
// Threshold on an input's range: when the range is smaller than this value,
// the weight of that input is set to zero.
static const float kMinInputRange = 0.000001f;
|
|
|
|
class NeuralNet {
|
|
public:
|
|
NeuralNet();
|
|
virtual ~NeuralNet();
|
|
// create a net object from a file. Uses stdio
|
|
static NeuralNet *FromFile(const string file_name);
|
|
// create a net object from an input buffer
|
|
static NeuralNet *FromInputBuffer(InputFileBuffer *ib);
|
|
// Different flavors of feed forward function
|
|
template <typename Type> bool FeedForward(const Type *inputs,
|
|
Type *outputs);
|
|
// Compute the output of a specific output node.
|
|
// This function is useful for application that are interested in a single
|
|
// output of the net and do not want to waste time on the rest
|
|
template <typename Type> bool GetNetOutput(const Type *inputs,
|
|
int output_id,
|
|
Type *output);
|
|
// Accessor functions
|
|
int in_cnt() const { return in_cnt_; }
|
|
int out_cnt() const { return out_cnt_; }
|
|
|
|
protected:
|
|
struct Node;
|
|
// A node-weight pair
|
|
struct WeightedNode {
|
|
Node *input_node;
|
|
float input_weight;
|
|
};
|
|
// node struct used for fast feedforward in
|
|
// Read only nets
|
|
struct Node {
|
|
float out;
|
|
float bias;
|
|
int fan_in_cnt;
|
|
WeightedNode *inputs;
|
|
};
|
|
// Read-Only flag (no training: On by default)
|
|
// will presumeably be set to false by
|
|
// the inherting TrainableNeuralNet class
|
|
bool read_only_;
|
|
// input count
|
|
int in_cnt_;
|
|
// output count
|
|
int out_cnt_;
|
|
// Total neuron count (including inputs)
|
|
int neuron_cnt_;
|
|
// count of unique weights
|
|
int wts_cnt_;
|
|
// Neuron vector
|
|
Neuron *neurons_;
|
|
// size of allocated weight chunk (in weights)
|
|
// This is basically the size of the biggest network
|
|
// that I have trained. However, the class will allow
|
|
// a bigger sized net if desired
|
|
static const int kWgtChunkSize = 0x10000;
|
|
// Magic number expected at the beginning of the NN
|
|
// binary file
|
|
static const unsigned int kNetSignature = 0xFEFEABD0;
|
|
// count of allocated wgts in the last chunk
|
|
int alloc_wgt_cnt_;
|
|
// vector of weights buffers
|
|
vector<vector<float> *>wts_vec_;
|
|
// Is the net an auto-encoder type
|
|
bool auto_encoder_;
|
|
// vector of input max values
|
|
vector<float> inputs_max_;
|
|
// vector of input min values
|
|
vector<float> inputs_min_;
|
|
// vector of input mean values
|
|
vector<float> inputs_mean_;
|
|
// vector of input standard deviation values
|
|
vector<float> inputs_std_dev_;
|
|
// vector of input offsets used by fast read-only
|
|
// feedforward function
|
|
vector<Node> fast_nodes_;
|
|
// Network Initialization function
|
|
void Init();
|
|
// Clears all neurons
|
|
void Clear() {
|
|
for (int node = 0; node < neuron_cnt_; node++) {
|
|
neurons_[node].Clear();
|
|
}
|
|
}
|
|
// Reads the net from an input buffer
|
|
template<class ReadBuffType> bool ReadBinary(ReadBuffType *input_buff) {
|
|
// Init vars
|
|
Init();
|
|
// is this an autoencoder
|
|
unsigned int read_val;
|
|
unsigned int auto_encode;
|
|
// read and verify signature
|
|
if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
|
|
return false;
|
|
}
|
|
if (read_val != kNetSignature) {
|
|
return false;
|
|
}
|
|
if (input_buff->Read(&auto_encode, sizeof(auto_encode)) !=
|
|
sizeof(auto_encode)) {
|
|
return false;
|
|
}
|
|
auto_encoder_ = auto_encode;
|
|
// read and validate total # of nodes
|
|
if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
|
|
return false;
|
|
}
|
|
neuron_cnt_ = read_val;
|
|
if (neuron_cnt_ <= 0) {
|
|
return false;
|
|
}
|
|
// set the size of the neurons vector
|
|
neurons_ = new Neuron[neuron_cnt_];
|
|
if (neurons_ == NULL) {
|
|
return false;
|
|
}
|
|
// read & validate inputs
|
|
if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
|
|
return false;
|
|
}
|
|
in_cnt_ = read_val;
|
|
if (in_cnt_ <= 0) {
|
|
return false;
|
|
}
|
|
// read outputs
|
|
if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
|
|
return false;
|
|
}
|
|
out_cnt_ = read_val;
|
|
if (out_cnt_ <= 0) {
|
|
return false;
|
|
}
|
|
// set neuron ids and types
|
|
for (int idx = 0; idx < neuron_cnt_; idx++) {
|
|
neurons_[idx].set_id(idx);
|
|
// input type
|
|
if (idx < in_cnt_) {
|
|
neurons_[idx].set_node_type(Neuron::Input);
|
|
} else if (idx >= (neuron_cnt_ - out_cnt_)) {
|
|
neurons_[idx].set_node_type(Neuron::Output);
|
|
} else {
|
|
neurons_[idx].set_node_type(Neuron::Hidden);
|
|
}
|
|
}
|
|
// read the connections
|
|
for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
|
|
// read fanout
|
|
if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
|
|
return false;
|
|
}
|
|
// read the neuron's info
|
|
int fan_out_cnt = read_val;
|
|
for (int fan_out_idx = 0; fan_out_idx < fan_out_cnt; fan_out_idx++) {
|
|
// read the neuron id
|
|
if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
|
|
return false;
|
|
}
|
|
// create the connection
|
|
if (!SetConnection(node_idx, read_val)) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
// read all the neurons' fan-in connections
|
|
for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
|
|
// read
|
|
if (!neurons_[node_idx].ReadBinary(input_buff)) {
|
|
return false;
|
|
}
|
|
}
|
|
// size input stats vector to expected input size
|
|
inputs_mean_.resize(in_cnt_);
|
|
inputs_std_dev_.resize(in_cnt_);
|
|
inputs_min_.resize(in_cnt_);
|
|
inputs_max_.resize(in_cnt_);
|
|
// read stats
|
|
if (input_buff->Read(&(inputs_mean_.front()),
|
|
sizeof(inputs_mean_[0]) * in_cnt_) !=
|
|
sizeof(inputs_mean_[0]) * in_cnt_) {
|
|
return false;
|
|
}
|
|
if (input_buff->Read(&(inputs_std_dev_.front()),
|
|
sizeof(inputs_std_dev_[0]) * in_cnt_) !=
|
|
sizeof(inputs_std_dev_[0]) * in_cnt_) {
|
|
return false;
|
|
}
|
|
if (input_buff->Read(&(inputs_min_.front()),
|
|
sizeof(inputs_min_[0]) * in_cnt_) !=
|
|
sizeof(inputs_min_[0]) * in_cnt_) {
|
|
return false;
|
|
}
|
|
if (input_buff->Read(&(inputs_max_.front()),
|
|
sizeof(inputs_max_[0]) * in_cnt_) !=
|
|
sizeof(inputs_max_[0]) * in_cnt_) {
|
|
return false;
|
|
}
|
|
// create a readonly version for fast feedforward
|
|
if (read_only_) {
|
|
return CreateFastNet();
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// creates a connection between two nodes
|
|
bool SetConnection(int from, int to);
|
|
// Create a read only version of the net that
|
|
// has faster feedforward performance
|
|
bool CreateFastNet();
|
|
// internal function to allocate a new set of weights
|
|
// Centralized weight allocation attempts to increase
|
|
// weights locality of reference making it more cache friendly
|
|
float *AllocWgt(int wgt_cnt);
|
|
// different flavors read-only feedforward function
|
|
template <typename Type> bool FastFeedForward(const Type *inputs,
|
|
Type *outputs);
|
|
// Compute the output of a specific output node.
|
|
// This function is useful for application that are interested in a single
|
|
// output of the net and do not want to waste time on the rest
|
|
// This is the fast-read-only version of this function
|
|
template <typename Type> bool FastGetNetOutput(const Type *inputs,
|
|
int output_id,
|
|
Type *output);
|
|
};
|
|
}
|
|
|
|
#endif  // NEURAL_NET_H
|