tesseract/lstm/lstm.cpp

729 lines
26 KiB
C++
Raw Normal View History

///////////////////////////////////////////////////////////////////////
// File: lstm.cpp
// Description: Long-term-short-term-memory Recurrent neural network.
// Author: Ray Smith
// Created: Wed May 01 17:43:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "lstm.h"
#ifndef ANDROID_BUILD
#include <omp.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"
// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads) \
PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
PRAGMA(omp sections nowait) { \
PRAGMA(omp section) {
#define SECTION_IF_OPENMP \
} \
PRAGMA(omp section) \
{
#define END_PARALLEL_IF_OPENMP \
} \
} /* end of sections */ \
} /* end of parallel section */
// Define the portable PRAGMA macro.
#ifdef _MSC_VER // Different _Pragma
#define PRAGMA(x) __pragma(x)
#else
#define PRAGMA(x) _Pragma(#x)
#endif // _MSC_VER
#else // _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
#define SECTION_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
#endif // _OPENMP
namespace tesseract {
// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const double kErrClip = 1.0f;
LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
NetworkType type)
: Network(type, name, ni, no),
na_(ni + ns),
ns_(ns),
nf_(0),
is_2d_(two_dimensional),
softmax_(NULL),
input_width_(0) {
if (two_dimensional) na_ += ns_;
if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
nf_ = 0;
// networkbuilder ensures this is always true.
ASSERT_HOST(no == ns);
} else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_)));
softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
} else {
tprintf("%d is invalid type of LSTM!\n", type);
ASSERT_HOST(false);
}
na_ += nf_;
}
LSTM::~LSTM() { delete softmax_; }
// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
StaticShape result = input_shape;
result.set_depth(no_);
if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
if (softmax_ != NULL) return softmax_->OutputShape(result);
return result;
}
// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
if (training_ == TS_DISABLED) {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].InitBackward(false);
}
}
training_ = TS_ENABLED;
} else {
training_ = state;
}
if (softmax_ != NULL) softmax_->SetEnableTraining(state);
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
Network::SetRandomizer(randomizer);
num_weights_ = 0;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
num_weights_ += gate_weights_[w].InitWeightsFloat(
ns_, na_ + 1, TestFlag(NF_ADA_GRAD), range, randomizer);
}
if (softmax_ != NULL) {
num_weights_ += softmax_->InitWeights(range, randomizer);
}
return num_weights_;
}
// Converts a float network to an int network.
void LSTM::ConvertToInt() {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].ConvertToInt();
}
if (softmax_ != NULL) {
softmax_->ConvertToInt();
}
}
// Sets up the network for training using the given weight_range.
void LSTM::DebugWeights() {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
STRING msg = name_;
msg.add_str_int(" Gate weights ", w);
gate_weights_[w].Debug2D(msg.string());
}
if (softmax_ != NULL) {
softmax_->DebugWeights();
}
}
// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
}
if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
return true;
}
// Reads from the given file. Returns false in case of error.
// If swap is true, assumes a big/little-endian swap is needed.
bool LSTM::DeSerialize(bool swap, TFile* fp) {
if (fp->FRead(&na_, sizeof(na_), 1) != 1) return false;
if (swap) ReverseN(&na_, sizeof(na_));
if (type_ == NT_LSTM_SOFTMAX) {
nf_ = no_;
} else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
nf_ = IntCastRounded(ceil(log2(no_)));
} else {
nf_ = 0;
}
is_2d_ = false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].DeSerialize(IsTraining(), swap, fp)) return false;
if (w == CI) {
ns_ = gate_weights_[CI].NumOutputs();
is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
}
}
delete softmax_;
if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
softmax_ =
reinterpret_cast<FullyConnected*>(Network::CreateFromFile(swap, fp));
if (softmax_ == NULL) return false;
} else {
softmax_ = NULL;
}
return true;
}
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
input_map_ = input.stride_map();
input_width_ = input.Width();
if (softmax_ != NULL)
output->ResizeFloat(input, no_);
else if (type_ == NT_LSTM_SUMMARY)
output->ResizeXTo1(input, no_);
else
output->Resize(input, no_);
ResizeForward(input);
// Temporary storage of forward computation for each gate.
NetworkScratch::FloatVec temp_lines[WT_COUNT];
for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
// Single timestep buffers for the current/recurrent output and state.
NetworkScratch::FloatVec curr_state, curr_output;
curr_state.Init(ns_, scratch);
ZeroVector<double>(ns_, curr_state);
curr_output.Init(ns_, scratch);
ZeroVector<double>(ns_, curr_output);
// Rotating buffers of width buf_width allow storage of the state and output
// for the other dimension, used only when working in true 2D mode. The width
// is enough to hold an entire strip of the major direction.
int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
GenericVector<NetworkScratch::FloatVec> states, outputs;
if (Is2D()) {
states.init_to_size(buf_width, NetworkScratch::FloatVec());
outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
for (int i = 0; i < buf_width; ++i) {
states[i].Init(ns_, scratch);
ZeroVector<double>(ns_, states[i]);
outputs[i].Init(ns_, scratch);
ZeroVector<double>(ns_, outputs[i]);
}
}
// Used only if a softmax LSTM.
NetworkScratch::FloatVec softmax_output;
NetworkScratch::IO int_output;
if (softmax_ != NULL) {
softmax_output.Init(no_, scratch);
ZeroVector<double>(no_, softmax_output);
if (input.int_mode()) int_output.Resize2d(true, 1, ns_, scratch);
softmax_->SetupForward(input, NULL);
}
NetworkScratch::FloatVec curr_input;
curr_input.Init(na_, scratch);
StrideMap::Index src_index(input_map_);
// Used only by NT_LSTM_SUMMARY.
StrideMap::Index dest_index(output->stride_map());
do {
int t = src_index.t();
// True if there is a valid old state for the 2nd dimension.
bool valid_2d = Is2D();
if (valid_2d) {
StrideMap::Index dim_index(src_index);
if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
}
// Index of the 2-D revolving buffers (outputs, states).
int mod_t = Modulo(t, buf_width); // Current timestep.
// Setup the padded input in source.
source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
if (softmax_ != NULL) {
source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
}
source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
if (Is2D())
source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
// Matrix multiply the inputs with the source.
PARALLEL_IF_OPENMP(GFS)
// It looks inefficient to create the threads on each t iteration, but the
// alternative of putting the parallel outside the t loop, a single around
// the t-loop and then tasks in place of the sections is a *lot* slower.
// Cell inputs.
if (source_.int_mode())
gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
else
gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
FuncInplace<GFunc>(ns_, temp_lines[CI]);
SECTION_IF_OPENMP
// Input Gates.
if (source_.int_mode())
gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
else
gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
FuncInplace<FFunc>(ns_, temp_lines[GI]);
SECTION_IF_OPENMP
// 1-D forget gates.
if (source_.int_mode())
gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
else
gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
FuncInplace<FFunc>(ns_, temp_lines[GF1]);
// 2-D forget gates.
if (Is2D()) {
if (source_.int_mode())
gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
else
gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
FuncInplace<FFunc>(ns_, temp_lines[GFS]);
}
SECTION_IF_OPENMP
// Output gates.
if (source_.int_mode())
gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
else
gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
FuncInplace<FFunc>(ns_, temp_lines[GO]);
END_PARALLEL_IF_OPENMP
// Apply forget gate to state.
MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
if (Is2D()) {
// Max-pool the forget gates (in 2-d) instead of blindly adding.
inT8* which_fg_col = which_fg_[t];
memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
if (valid_2d) {
const double* stepped_state = states[mod_t];
for (int i = 0; i < ns_; ++i) {
if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
which_fg_col[i] = 2;
}
}
}
}
MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
// Clip curr_state to a sane range.
ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
if (IsTraining()) {
// Save the gate node values.
node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
}
FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
if (IsTraining()) state_.WriteTimeStep(t, curr_state);
if (softmax_ != NULL) {
if (input.int_mode()) {
int_output->WriteTimeStep(0, curr_output);
softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
} else {
softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
}
output->WriteTimeStep(t, softmax_output);
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
CodeInBinary(no_, nf_, softmax_output);
}
} else if (type_ == NT_LSTM_SUMMARY) {
// Output only at the end of a row.
if (src_index.IsLast(FD_WIDTH)) {
output->WriteTimeStep(dest_index.t(), curr_output);
dest_index.Increment();
}
} else {
output->WriteTimeStep(t, curr_output);
}
// Save states for use by the 2nd dimension only if needed.
if (Is2D()) {
CopyVector(ns_, curr_state, states[mod_t]);
CopyVector(ns_, curr_output, outputs[mod_t]);
}
// Always zero the states at the end of every row, but only for the major
// direction. The 2-D state remains intact.
if (src_index.IsLast(FD_WIDTH)) {
ZeroVector<double>(ns_, curr_state);
ZeroVector<double>(ns_, curr_output);
}
} while (src_index.Increment());
#if DEBUG_DETAIL > 0
tprintf("Source:%s\n", name_.string());
source_.Print(10);
tprintf("State:%s\n", name_.string());
state_.Print(10);
tprintf("Output:%s\n", name_.string());
output->Print(10);
#endif
if (debug) DisplayForward(*output);
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
if (debug) DisplayBackward(fwd_deltas);
back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
// ======Scratch space.======
// Output errors from deltas with recurrence from sourceerr.
NetworkScratch::FloatVec outputerr;
outputerr.Init(ns_, scratch);
// Recurrent error in the state/source.
NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
curr_stateerr.Init(ns_, scratch);
curr_sourceerr.Init(na_, scratch);
ZeroVector<double>(ns_, curr_stateerr);
ZeroVector<double>(na_, curr_sourceerr);
// Errors in the gates.
NetworkScratch::FloatVec gate_errors[WT_COUNT];
for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
// Rotating buffers of width buf_width allow storage of the recurrent time-
// steps used only for true 2-D. Stores one full strip of the major direction.
int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
if (Is2D()) {
stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
for (int t = 0; t < buf_width; ++t) {
stateerr[t].Init(ns_, scratch);
sourceerr[t].Init(na_, scratch);
ZeroVector<double>(ns_, stateerr[t]);
ZeroVector<double>(na_, sourceerr[t]);
}
}
// Parallel-generated sourceerr from each of the gates.
NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
for (int w = 0; w < WT_COUNT; ++w)
sourceerr_temps[w].Init(na_, scratch);
int width = input_width_;
// Transposed gate errors stored over all timesteps for sum outer.
NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
for (int w = 0; w < WT_COUNT; ++w) {
gate_errors_t[w].Init(ns_, width, scratch);
}
// Used only if softmax_ != NULL.
NetworkScratch::FloatVec softmax_errors;
NetworkScratch::GradientStore softmax_errors_t;
if (softmax_ != NULL) {
softmax_errors.Init(no_, scratch);
softmax_errors_t.Init(no_, width, scratch);
}
double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
tprintf("fwd_deltas:%s\n", name_.string());
fwd_deltas.Print(10);
#endif
StrideMap::Index dest_index(input_map_);
dest_index.InitToLast();
// Used only by NT_LSTM_SUMMARY.
StrideMap::Index src_index(fwd_deltas.stride_map());
src_index.InitToLast();
do {
int t = dest_index.t();
bool at_last_x = dest_index.IsLast(FD_WIDTH);
// up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
// valid if >= 0, which is true if 2d and not on the top/bottom.
int up_pos = -1;
int down_pos = -1;
if (Is2D()) {
if (dest_index.index(FD_HEIGHT) > 0) {
StrideMap::Index up_index(dest_index);
if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
}
if (!dest_index.IsLast(FD_HEIGHT)) {
StrideMap::Index down_index(dest_index);
if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
}
}
// Index of the 2-D revolving buffers (sourceerr, stateerr).
int mod_t = Modulo(t, buf_width); // Current timestep.
// Zero the state in the major direction only at the end of every row.
if (at_last_x) {
ZeroVector<double>(na_, curr_sourceerr);
ZeroVector<double>(ns_, curr_stateerr);
}
// Setup the outputerr.
if (type_ == NT_LSTM_SUMMARY) {
if (dest_index.IsLast(FD_WIDTH)) {
fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
src_index.Decrement();
} else {
ZeroVector<double>(ns_, outputerr);
}
} else if (softmax_ == NULL) {
fwd_deltas.ReadTimeStep(t, outputerr);
} else {
softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
softmax_errors_t.get(), outputerr);
}
if (!at_last_x)
AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
if (down_pos >= 0)
AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
// Apply the 1-d forget gates.
if (!at_last_x) {
const float* next_node_gf1 = node_values_[GF1].f(t + 1);
for (int i = 0; i < ns_; ++i) {
curr_stateerr[i] *= next_node_gf1[i];
}
}
if (Is2D() && t + 1 < width) {
for (int i = 0; i < ns_; ++i) {
if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
}
if (down_pos >= 0) {
const float* right_node_gfs = node_values_[GFS].f(down_pos);
const double* right_stateerr = stateerr[mod_t];
for (int i = 0; i < ns_; ++i) {
if (which_fg_[down_pos][i] == 2) {
curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
}
}
}
}
state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
curr_stateerr);
// Clip stateerr_ to a sane range.
ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
if (t + 10 > width) {
tprintf("t=%d, stateerr=", t);
for (int i = 0; i < ns_; ++i)
tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
curr_sourceerr[ni_ + nf_ + i]);
tprintf("\n");
}
#endif
// Matrix multiply to get the source errors.
PARALLEL_IF_OPENMP(GFS)
// Cell inputs.
node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
curr_stateerr, gate_errors[CI]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
SECTION_IF_OPENMP
// Input Gates.
node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
curr_stateerr, gate_errors[GI]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
SECTION_IF_OPENMP
// 1-D forget Gates.
if (t > 0) {
node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
gate_errors[GF1]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
sourceerr_temps[GF1]);
} else {
memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
}
gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
// 2-D forget Gates.
if (up_pos >= 0) {
node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
gate_errors[GFS]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
sourceerr_temps[GFS]);
} else {
memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
}
if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
SECTION_IF_OPENMP
// Output gates.
state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
gate_errors[GO]);
ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
END_PARALLEL_IF_OPENMP
SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
curr_sourceerr);
back_deltas->WriteTimeStep(t, curr_sourceerr);
// Save states for use by the 2nd dimension only if needed.
if (Is2D()) {
CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
}
} while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
for (int w = 0; w < WT_COUNT; ++w) {
tprintf("%s gate errors[%d]\n", name_.string(), w);
gate_errors_t[w].get()->PrintUnTransposed(10);
}
#endif
// Transposed source_ used to speed-up SumOuter.
NetworkScratch::GradientStore source_t, state_t;
source_t.Init(na_, width, scratch);
source_.Transpose(source_t.get());
state_t.Init(ns_, width, scratch);
state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
}
if (softmax_ != NULL) {
softmax_->FinishBackward(*softmax_errors_t);
}
if (needs_to_backprop_) {
// Normalize the inputerr in back_deltas.
back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
return true;
}
return false;
}
// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void LSTM::Update(float learning_rate, float momentum, int num_samples) {
#if DEBUG_DETAIL > 3
PrintW();
#endif
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].Update(learning_rate, momentum, num_samples);
}
if (softmax_ != NULL) {
softmax_->Update(learning_rate, momentum, num_samples);
}
#if DEBUG_DETAIL > 3
PrintDW();
#endif
}
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network& other, double* same,
double* changed) const {
ASSERT_HOST(other.type() == type_);
const LSTM* lstm = reinterpret_cast<const LSTM*>(&other);
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
}
if (softmax_ != NULL) {
softmax_->CountAlternators(*lstm->softmax_, same, changed);
}
}
// Prints the weights for debug purposes.
void LSTM::PrintW() {
tprintf("Weight state:%s\n", name_.string());
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
tprintf("Gate %d, inputs\n", w);
for (int i = 0; i < ni_; ++i) {
tprintf("Row %d:", i);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
tprintf("\n");
}
tprintf("Gate %d, outputs\n", w);
for (int i = ni_; i < ni_ + ns_; ++i) {
tprintf("Row %d:", i - ni_);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
tprintf("\n");
}
tprintf("Gate %d, bias\n", w);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
tprintf("\n");
}
}
// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
tprintf("Delta state:%s\n", name_.string());
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
tprintf("Gate %d, inputs\n", w);
for (int i = 0; i < ni_; ++i) {
tprintf("Row %d:", i);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetDW(s, i));
tprintf("\n");
}
tprintf("Gate %d, outputs\n", w);
for (int i = ni_; i < ni_ + ns_; ++i) {
tprintf("Row %d:", i - ni_);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetDW(s, i));
tprintf("\n");
}
tprintf("Gate %d, bias\n", w);
for (int s = 0; s < ns_; ++s)
tprintf(" %g", gate_weights_[w].GetDW(s, na_));
tprintf("\n");
}
}
// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO& input) {
source_.Resize(input, na_);
which_fg_.ResizeNoInit(input.Width(), ns_);
if (IsTraining()) {
state_.ResizeFloat(input, ns_);
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
node_values_[w].ResizeFloat(input, ns_);
}
}
}
} // namespace tesseract.