mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-06-11 20:53:24 +08:00
Merge pull request #1455 from stweil/cov
Overload method ForwardTimeStep (CID 1385636 Explicit null dereferenced)
This commit is contained in:
commit
4b50f3f46f
@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
|
||||
int thread_id = 0;
|
||||
#endif
|
||||
double* temp_line = temp_lines[thread_id];
|
||||
const double* d_input = nullptr;
|
||||
const int8_t* i_input = nullptr;
|
||||
if (input.int_mode()) {
|
||||
i_input = input.i(t);
|
||||
ForwardTimeStep(input.i(t), t, temp_line);
|
||||
} else {
|
||||
input.ReadTimeStep(t, curr_input[thread_id]);
|
||||
d_input = curr_input[thread_id];
|
||||
ForwardTimeStep(curr_input[thread_id], t, temp_line);
|
||||
}
|
||||
ForwardTimeStep(d_input, i_input, t, temp_line);
|
||||
output->WriteTimeStep(t, temp_line);
|
||||
if (IsTraining() && type_ != NT_SOFTMAX) {
|
||||
acts_.CopyTimeStepFrom(t, *output, t);
|
||||
@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
|
||||
}
|
||||
}
|
||||
|
||||
void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_input,
|
||||
int t, double* output_line) {
|
||||
// input is copied to source_ line-by-line for cache coherency.
|
||||
if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
|
||||
source_t_.WriteStrided(t, d_input);
|
||||
if (d_input != nullptr)
|
||||
weights_.MatrixDotVector(d_input, output_line);
|
||||
else
|
||||
weights_.MatrixDotVector(i_input, output_line);
|
||||
void FullyConnected::ForwardTimeStep(int t, double* output_line) {
|
||||
if (type_ == NT_TANH) {
|
||||
FuncInplace<GFunc>(no_, output_line);
|
||||
} else if (type_ == NT_LOGISTIC) {
|
||||
@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
|
||||
}
|
||||
}
|
||||
|
||||
void FullyConnected::ForwardTimeStep(const double* d_input,
|
||||
int t, double* output_line) {
|
||||
// input is copied to source_ line-by-line for cache coherency.
|
||||
if (IsTraining() && external_source_ == NULL)
|
||||
source_t_.WriteStrided(t, d_input);
|
||||
weights_.MatrixDotVector(d_input, output_line);
|
||||
ForwardTimeStep(t, output_line);
|
||||
}
|
||||
|
||||
void FullyConnected::ForwardTimeStep(const int8_t* i_input,
|
||||
int t, double* output_line) {
|
||||
// input is copied to source_ line-by-line for cache coherency.
|
||||
weights_.MatrixDotVector(i_input, output_line);
|
||||
ForwardTimeStep(t, output_line);
|
||||
}
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See NetworkCpp for a detailed discussion of the arguments.
|
||||
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||
|
@ -91,8 +91,9 @@ class FullyConnected : public Network {
|
||||
// Components of Forward so FullyConnected can be reused inside LSTM.
|
||||
void SetupForward(const NetworkIO& input,
|
||||
const TransposedArray* input_transpose);
|
||||
void ForwardTimeStep(const double* d_input, const int8_t* i_input, int t,
|
||||
double* output_line);
|
||||
void ForwardTimeStep(int t, double* output_line);
|
||||
void ForwardTimeStep(const double* d_input, int t, double* output_line);
|
||||
void ForwardTimeStep(const int8_t* i_input, int t, double* output_line);
|
||||
|
||||
// Runs backward propagation of errors on the deltas line.
|
||||
// See Network for a detailed discussion of the arguments.
|
||||
|
@ -396,9 +396,9 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
|
||||
if (softmax_ != nullptr) {
|
||||
if (input.int_mode()) {
|
||||
int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
|
||||
softmax_->ForwardTimeStep(nullptr, int_output->i(0), t, softmax_output);
|
||||
softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
|
||||
} else {
|
||||
softmax_->ForwardTimeStep(curr_output, nullptr, t, softmax_output);
|
||||
softmax_->ForwardTimeStep(curr_output, t, softmax_output);
|
||||
}
|
||||
output->WriteTimeStep(t, softmax_output);
|
||||
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
|
||||
|
Loading…
Reference in New Issue
Block a user