mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-01-18 22:43:45 +08:00
4523ce9f7d
git-svn-id: https://tesseract-ocr.googlecode.com/svn/trunk@526 d0cd1f9f-072b-0410-8dd7-cf729c803f20
148 lines
3.8 KiB
C++
148 lines
3.8 KiB
C++
// Copyright 2008 Google Inc.
|
|
// All Rights Reserved.
|
|
// Author: ahmadab@google.com (Ahmad Abdulkader)
|
|
//
|
|
// neuron.h: Declarations of a class for an object that
|
|
// represents a single neuron in a neural network
|
|
//
|
|
|
|
#ifndef NEURON_H
|
|
#define NEURON_H
|
|
|
|
#include <math.h>
|
|
#include <vector>
|
|
|
|
#ifdef USE_STD_NAMESPACE
|
|
using std::vector;
|
|
#endif
|
|
|
|
namespace tesseract {
|
|
|
|
// Input Node bias values
|
|
static const float kInputNodeBias = 0.0f;
|
|
|
|
class Neuron {
|
|
public:
|
|
// Types of nodes
|
|
enum NeuronTypes {
|
|
Unknown = 0,
|
|
Input,
|
|
Hidden,
|
|
Output
|
|
};
|
|
Neuron();
|
|
~Neuron();
|
|
// set the forward dirty flag indicating that the
|
|
// activation of the net is not fresh
|
|
void Clear() {
|
|
frwd_dirty_ = true;
|
|
}
|
|
// Read a binary representation of the neuron info from
|
|
// an input buffer.
|
|
template <class BuffType> bool ReadBinary(BuffType *input_buff) {
|
|
float val;
|
|
if (input_buff->Read(&val, sizeof(val)) != sizeof(val)) {
|
|
return false;
|
|
}
|
|
// input nodes should have no biases
|
|
if (node_type_ == Input) {
|
|
bias_ = kInputNodeBias;
|
|
} else {
|
|
bias_ = val;
|
|
}
|
|
// read fanin count
|
|
int fan_in_cnt;
|
|
if (input_buff->Read(&fan_in_cnt, sizeof(fan_in_cnt)) !=
|
|
sizeof(fan_in_cnt)) {
|
|
return false;
|
|
}
|
|
// validate fan-in cnt
|
|
if (fan_in_cnt != fan_in_.size()) {
|
|
return false;
|
|
}
|
|
// read the weights
|
|
for (int in = 0; in < fan_in_cnt; in++) {
|
|
if (input_buff->Read(&val, sizeof(val)) != sizeof(val)) {
|
|
return false;
|
|
}
|
|
*(fan_in_weights_[in]) = val;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Add a new connection from this neuron *From*
|
|
// a target neuron using specfied params
|
|
// Note that what is actually copied in this function are pointers to the
|
|
// specified Neurons and weights and not the actualt values. This is by
|
|
// design to centralize the alloction of neurons and weights and so
|
|
// increase the locality of reference and improve cache-hits resulting
|
|
// in a faster net. This technique resulted in a 2X-10X speedup
|
|
// (depending on network size and processor)
|
|
void AddFromConnection(Neuron *neuron_vec,
|
|
float *wts_offset,
|
|
int from_cnt);
|
|
// Set the type of a neuron
|
|
void set_node_type(NeuronTypes type);
|
|
// Computes the output of the node by
|
|
// "pulling" the output of the fan-in nodes
|
|
void FeedForward();
|
|
// fast computation of sigmoid function using a lookup table
|
|
// defined in sigmoid_table.cpp
|
|
static float Sigmoid(float activation);
|
|
// Accessor functions
|
|
float output() const {
|
|
return output_;
|
|
}
|
|
void set_output(float out_val) {
|
|
output_ = out_val;
|
|
}
|
|
int id() const {
|
|
return id_;
|
|
}
|
|
int fan_in_cnt() const {
|
|
return fan_in_.size();
|
|
}
|
|
Neuron * fan_in(int idx) const {
|
|
return fan_in_[idx];
|
|
}
|
|
float fan_in_wts(int idx) const {
|
|
return *(fan_in_weights_[idx]);
|
|
}
|
|
void set_id(int id) {
|
|
id_ = id;
|
|
}
|
|
float bias() const {
|
|
return bias_;
|
|
}
|
|
Neuron::NeuronTypes node_type() const {
|
|
return node_type_;
|
|
}
|
|
|
|
protected:
|
|
// Type of Neuron
|
|
NeuronTypes node_type_;
|
|
// unqique id of the neuron
|
|
int id_;
|
|
// node bias
|
|
float bias_;
|
|
// node net activation
|
|
float activation_;
|
|
// node output
|
|
float output_;
|
|
// pointers to fanin nodes
|
|
vector<Neuron *> fan_in_;
|
|
// pointers to fanin weights
|
|
vector<float *> fan_in_weights_;
|
|
// Sigmoid function lookup table used for fast computation
|
|
// of sigmoid function
|
|
static const float kSigmoidTable[];
|
|
// flag determining if the activation of the node
|
|
// is fresh or not (dirty)
|
|
bool frwd_dirty_;
|
|
// Initializer
|
|
void Init();
|
|
};
|
|
}
|
|
|
|
#endif // NEURON_H__
|