Optimize LSTM code for builds without OpenMP

The constant value kNumThreads is not only used to configure the number
of threads but also to allocate vectors used in those threads.

There is only a single thread without OpenMP, so it is sufficient to
allocate vectors with only one element in that case.

Replace also the upper limit in the for loops by the known vector size.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2017-05-22 10:13:53 +02:00
parent 5a06417eb2
commit 15b3596ec4

View File

@ -28,7 +28,11 @@
#include "networkscratch.h" #include "networkscratch.h"
// Number of threads to use for parallel calculation of Forward and Backward. // Number of threads to use for parallel calculation of Forward and Backward.
#ifdef _OPENMP
const int kNumThreads = 4; const int kNumThreads = 4;
#else
const int kNumThreads = 1;
#endif
namespace tesseract { namespace tesseract {
@ -117,7 +121,7 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec()); temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
GenericVector<NetworkScratch::FloatVec> curr_input; GenericVector<NetworkScratch::FloatVec> curr_input;
curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec()); curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
for (int i = 0; i < temp_lines.size(); ++i) { for (int i = 0; i < kNumThreads; ++i) {
temp_lines[i].Init(no_, scratch); temp_lines[i].Init(no_, scratch);
curr_input[i].Init(ni_, scratch); curr_input[i].Init(ni_, scratch);
} }
@ -208,7 +212,7 @@ bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
back_deltas->Resize(fwd_deltas, ni_); back_deltas->Resize(fwd_deltas, ni_);
GenericVector<NetworkScratch::FloatVec> errors; GenericVector<NetworkScratch::FloatVec> errors;
errors.init_to_size(kNumThreads, NetworkScratch::FloatVec()); errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
for (int i = 0; i < errors.size(); ++i) errors[i].Init(no_, scratch); for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
GenericVector<NetworkScratch::FloatVec> temp_backprops; GenericVector<NetworkScratch::FloatVec> temp_backprops;
if (needs_to_backprop_) { if (needs_to_backprop_) {
temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec()); temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());