mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-06-19 02:39:12 +08:00
Merge pull request #1455 from stweil/cov
Overload method ForwardTimeStep (CID 1385636 Explicit null dereferenced)
This commit is contained in:
commit
4b50f3f46f
@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
|
|||||||
int thread_id = 0;
|
int thread_id = 0;
|
||||||
#endif
|
#endif
|
||||||
double* temp_line = temp_lines[thread_id];
|
double* temp_line = temp_lines[thread_id];
|
||||||
const double* d_input = nullptr;
|
|
||||||
const int8_t* i_input = nullptr;
|
|
||||||
if (input.int_mode()) {
|
if (input.int_mode()) {
|
||||||
i_input = input.i(t);
|
ForwardTimeStep(input.i(t), t, temp_line);
|
||||||
} else {
|
} else {
|
||||||
input.ReadTimeStep(t, curr_input[thread_id]);
|
input.ReadTimeStep(t, curr_input[thread_id]);
|
||||||
d_input = curr_input[thread_id];
|
ForwardTimeStep(curr_input[thread_id], t, temp_line);
|
||||||
}
|
}
|
||||||
ForwardTimeStep(d_input, i_input, t, temp_line);
|
|
||||||
output->WriteTimeStep(t, temp_line);
|
output->WriteTimeStep(t, temp_line);
|
||||||
if (IsTraining() && type_ != NT_SOFTMAX) {
|
if (IsTraining() && type_ != NT_SOFTMAX) {
|
||||||
acts_.CopyTimeStepFrom(t, *output, t);
|
acts_.CopyTimeStepFrom(t, *output, t);
|
||||||
@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_input,
|
void FullyConnected::ForwardTimeStep(int t, double* output_line) {
|
||||||
int t, double* output_line) {
|
|
||||||
// input is copied to source_ line-by-line for cache coherency.
|
|
||||||
if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
|
|
||||||
source_t_.WriteStrided(t, d_input);
|
|
||||||
if (d_input != nullptr)
|
|
||||||
weights_.MatrixDotVector(d_input, output_line);
|
|
||||||
else
|
|
||||||
weights_.MatrixDotVector(i_input, output_line);
|
|
||||||
if (type_ == NT_TANH) {
|
if (type_ == NT_TANH) {
|
||||||
FuncInplace<GFunc>(no_, output_line);
|
FuncInplace<GFunc>(no_, output_line);
|
||||||
} else if (type_ == NT_LOGISTIC) {
|
} else if (type_ == NT_LOGISTIC) {
|
||||||
@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FullyConnected::ForwardTimeStep(const double* d_input,
|
||||||
|
int t, double* output_line) {
|
||||||
|
// input is copied to source_ line-by-line for cache coherency.
|
||||||
|
if (IsTraining() && external_source_ == NULL)
|
||||||
|
source_t_.WriteStrided(t, d_input);
|
||||||
|
weights_.MatrixDotVector(d_input, output_line);
|
||||||
|
ForwardTimeStep(t, output_line);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FullyConnected::ForwardTimeStep(const int8_t* i_input,
|
||||||
|
int t, double* output_line) {
|
||||||
|
// input is copied to source_ line-by-line for cache coherency.
|
||||||
|
weights_.MatrixDotVector(i_input, output_line);
|
||||||
|
ForwardTimeStep(t, output_line);
|
||||||
|
}
|
||||||
|
|
||||||
// Runs backward propagation of errors on the deltas line.
|
// Runs backward propagation of errors on the deltas line.
|
||||||
// See NetworkCpp for a detailed discussion of the arguments.
|
// See NetworkCpp for a detailed discussion of the arguments.
|
||||||
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
|
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
|
||||||
|
@ -91,8 +91,9 @@ class FullyConnected : public Network {
|
|||||||
// Components of Forward so FullyConnected can be reused inside LSTM.
|
// Components of Forward so FullyConnected can be reused inside LSTM.
|
||||||
void SetupForward(const NetworkIO& input,
|
void SetupForward(const NetworkIO& input,
|
||||||
const TransposedArray* input_transpose);
|
const TransposedArray* input_transpose);
|
||||||
void ForwardTimeStep(const double* d_input, const int8_t* i_input, int t,
|
void ForwardTimeStep(int t, double* output_line);
|
||||||
double* output_line);
|
void ForwardTimeStep(const double* d_input, int t, double* output_line);
|
||||||
|
void ForwardTimeStep(const int8_t* i_input, int t, double* output_line);
|
||||||
|
|
||||||
// Runs backward propagation of errors on the deltas line.
|
// Runs backward propagation of errors on the deltas line.
|
||||||
// See Network for a detailed discussion of the arguments.
|
// See Network for a detailed discussion of the arguments.
|
||||||
|
@ -396,9 +396,9 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
|
|||||||
if (softmax_ != nullptr) {
|
if (softmax_ != nullptr) {
|
||||||
if (input.int_mode()) {
|
if (input.int_mode()) {
|
||||||
int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
|
int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
|
||||||
softmax_->ForwardTimeStep(nullptr, int_output->i(0), t, softmax_output);
|
softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
|
||||||
} else {
|
} else {
|
||||||
softmax_->ForwardTimeStep(curr_output, nullptr, t, softmax_output);
|
softmax_->ForwardTimeStep(curr_output, t, softmax_output);
|
||||||
}
|
}
|
||||||
output->WriteTimeStep(t, softmax_output);
|
output->WriteTimeStep(t, softmax_output);
|
||||||
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
|
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
|
||||||
|
Loading…
Reference in New Issue
Block a user