diff --git a/lstm/tfnetwork.cpp b/lstm/tfnetwork.cpp index abc8ba4a1..68a7ca930 100644 --- a/lstm/tfnetwork.cpp +++ b/lstm/tfnetwork.cpp @@ -91,33 +91,36 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input, // objects. if (!model_proto_.image_widths().empty()) { TensorShape size_shape{1}; - Tensor width_tensor(tensorflow::DT_INT32, size_shape); - auto eigen_wtensor = width_tensor.flat(); + Tensor width_tensor(tensorflow::DT_INT64, size_shape); + auto eigen_wtensor = width_tensor.flat(); *eigen_wtensor.data() = stride_map.Size(FD_WIDTH); tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor); } if (!model_proto_.image_heights().empty()) { TensorShape size_shape{1}; - Tensor height_tensor(tensorflow::DT_INT32, size_shape); - auto eigen_htensor = height_tensor.flat(); + Tensor height_tensor(tensorflow::DT_INT64, size_shape); + auto eigen_htensor = height_tensor.flat(); *eigen_htensor.data() = stride_map.Size(FD_HEIGHT); tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor); } std::vector target_layers = {model_proto_.output_layer()}; std::vector outputs; Status s = session_->Run(tf_inputs, target_layers, {}, &outputs); + if (!s.ok()) tprintf("session->Run failed:%s\n", s.error_message().c_str()); ASSERT_HOST(s.ok()); ASSERT_HOST(outputs.size() == 1); const Tensor& output_tensor = outputs[0]; // Check the dimensions of the output. - ASSERT_HOST(output_tensor.shape().dims() == 2); - int output_dim0 = output_tensor.shape().dim_size(0); - int output_dim1 = output_tensor.shape().dim_size(1); - ASSERT_HOST(output_dim1 == output_shape_.depth()); - output->Resize2d(false, output_dim0, output_dim1); + ASSERT_HOST(output_tensor.shape().dims() == 3); + int output_batch = output_tensor.shape().dim_size(0); + int output_steps = output_tensor.shape().dim_size(1); + int output_depth = output_tensor.shape().dim_size(2); + ASSERT_HOST(output_batch == 1); + ASSERT_HOST(output_depth == output_shape_.depth()); + output->Resize2d(false, output_steps, output_depth); auto eigen_output = output_tensor.flat(); memcpy(output->f(0), eigen_output.data(), - output_dim0 * output_dim1 * sizeof(output->f(0)[0])); + output_steps * output_depth * sizeof(output->f(0)[0])); } int TFNetwork::InitFromProto() {