Merge pull request #1455 from stweil/cov

Overload method ForwardTimeStep (CID 1385636 Explicit null dereferenced)
This commit is contained in:
zdenop 2018-04-09 11:53:55 +02:00 committed by GitHub
commit 4b50f3f46f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 18 deletions

View File

@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
int thread_id = 0;
#endif
double* temp_line = temp_lines[thread_id];
const double* d_input = nullptr;
const int8_t* i_input = nullptr;
if (input.int_mode()) {
i_input = input.i(t);
ForwardTimeStep(input.i(t), t, temp_line);
} else {
input.ReadTimeStep(t, curr_input[thread_id]);
d_input = curr_input[thread_id];
ForwardTimeStep(curr_input[thread_id], t, temp_line);
}
ForwardTimeStep(d_input, i_input, t, temp_line);
output->WriteTimeStep(t, temp_line);
if (IsTraining() && type_ != NT_SOFTMAX) {
acts_.CopyTimeStepFrom(t, *output, t);
@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
}
}
void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
source_t_.WriteStrided(t, d_input);
if (d_input != nullptr)
weights_.MatrixDotVector(d_input, output_line);
else
weights_.MatrixDotVector(i_input, output_line);
void FullyConnected::ForwardTimeStep(int t, double* output_line) {
if (type_ == NT_TANH) {
FuncInplace<GFunc>(no_, output_line);
} else if (type_ == NT_LOGISTIC) {
@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
}
}
void FullyConnected::ForwardTimeStep(const double* d_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (IsTraining() && external_source_ == NULL)
source_t_.WriteStrided(t, d_input);
weights_.MatrixDotVector(d_input, output_line);
ForwardTimeStep(t, output_line);
}
void FullyConnected::ForwardTimeStep(const int8_t* i_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
weights_.MatrixDotVector(i_input, output_line);
ForwardTimeStep(t, output_line);
}
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,

View File

@ -91,8 +91,9 @@ class FullyConnected : public Network {
// Components of Forward so FullyConnected can be reused inside LSTM.
void SetupForward(const NetworkIO& input,
const TransposedArray* input_transpose);
void ForwardTimeStep(const double* d_input, const int8_t* i_input, int t,
double* output_line);
void ForwardTimeStep(int t, double* output_line);
void ForwardTimeStep(const double* d_input, int t, double* output_line);
void ForwardTimeStep(const int8_t* i_input, int t, double* output_line);
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.

View File

@ -396,9 +396,9 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
if (softmax_ != nullptr) {
if (input.int_mode()) {
int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
softmax_->ForwardTimeStep(nullptr, int_output->i(0), t, softmax_output);
softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
} else {
softmax_->ForwardTimeStep(curr_output, nullptr, t, softmax_output);
softmax_->ForwardTimeStep(curr_output, t, softmax_output);
}
output->WriteTimeStep(t, softmax_output);
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {