Simplify tanh and logistic functions and precompute function tables

Both functions are called very often, so computing the table values
at program start should be faster than computing them on demand.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2019-02-11 18:33:14 +01:00
parent 7ca27bb14a
commit f491eb6188
2 changed files with 25 additions and 29 deletions

View File

@ -2,7 +2,6 @@
// File: functions.cpp
// Description: Non-linearity function tables, precomputed at static initialization.
// Author: Ray Smith
// Created: Tue Jul 17 14:02:59 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@ -16,6 +15,7 @@
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include <cmath> // for exp, tanh
#include "functions.h"
namespace tesseract {
@ -23,4 +23,16 @@ namespace tesseract {
// Lookup tables for the tanh and logistic non-linearities. Each has
// kTableSize entries; the TableInit constructor below stores at index i the
// function value for the argument i / kScaleFactor.
double TanhTable[kTableSize];
double LogisticTable[kTableSize];
// Populates TanhTable and LogisticTable exactly once. The class has a single
// static instance (defined right after the class); constructing that instance
// during static initialization runs the loop below, so every table entry is
// available without any on-demand computation.
class TableInit {
  // Private constructor: only the static member below is ever constructed.
  TableInit() {
    for (int entry = 0; entry < kTableSize; ++entry) {
      // Argument represented by this table slot.
      const auto arg = entry / kScaleFactor;
      TanhTable[entry] = tanh(arg);
      // Logistic (sigmoid) function: 1 / (1 + e^-arg).
      LogisticTable[entry] = 1 / (1 + exp(-arg));
    }
  }

  // The one and only instance; exists purely for its constructor's side effect.
  static TableInit tableInit;
};

TableInit TableInit::tableInit;
} // namespace tesseract.

View File

@ -18,7 +18,6 @@
#ifndef TESSERACT_LSTM_FUNCTIONS_H_
#define TESSERACT_LSTM_FUNCTIONS_H_
#include <cmath>
#include "helpers.h"
// Setting this to 1 or more causes massive dumps of debug data: weights,
@ -42,39 +41,24 @@ extern double LogisticTable[];
// Non-linearity (sigmoid) functions with cache tables and clipping.
// Returns an approximation of tanh(x) by linear interpolation between two
// adjacent entries of TanhTable, where entry i holds tanh(i / kScaleFactor).
//
// - Negative arguments use the odd symmetry tanh(-x) == -tanh(x).
// - Arguments at or beyond the last table entry are clipped to 1.0.
//
// The table is expected to be completely filled before the first call (the
// TableInit object in functions.cpp does this during static initialization),
// so no entry is generated on demand here.
inline double Tanh(double x) {
  if (x < 0.0) return -Tanh(-x);
  // Convert the argument to table coordinates.
  x *= kScaleFactor;
  // Clip in floating point BEFORE converting to int: converting a double that
  // does not fit into int is undefined behavior and could produce a negative
  // index. The negated comparison also routes NaN to this branch.
  if (!(x < kTableSize - 1)) return 1.0;
  // x is now in [0, kTableSize - 1), so index + 1 is a valid table index.
  const int index = static_cast<int>(x);
  const double tanh_i0 = TanhTable[index];
  const double tanh_i1 = TanhTable[index + 1];
  // Linear interpolation.
  return tanh_i0 + (tanh_i1 - tanh_i0) * (x - index);
}
// Returns an approximation of the logistic function 1 / (1 + exp(-x)) by
// linear interpolation between two adjacent entries of LogisticTable, where
// entry i holds the function value for the argument i / kScaleFactor.
//
// - Negative arguments use the symmetry logistic(-x) == 1 - logistic(x).
// - Arguments at or beyond the last table entry are clipped to 1.0.
//
// The table is expected to be completely filled before the first call (the
// TableInit object in functions.cpp does this during static initialization),
// so no entry is generated on demand here.
inline double Logistic(double x) {
  if (x < 0.0) return 1.0 - Logistic(-x);
  // Convert the argument to table coordinates.
  x *= kScaleFactor;
  // Clip in floating point BEFORE converting to int: converting a double that
  // does not fit into int is undefined behavior and could produce a negative
  // index. The negated comparison also routes NaN to this branch.
  if (!(x < kTableSize - 1)) return 1.0;
  // x is now in [0, kTableSize - 1), so index + 1 is a valid table index.
  const int index = static_cast<int>(x);
  const double l0 = LogisticTable[index];
  const double l1 = LogisticTable[index + 1];
  // Linear interpolation.
  return l0 + (l1 - l0) * (x - index);
}
// Non-linearity (sigmoid) functions and their derivatives.