Fixed int types for imported tf networks

This commit is contained in:
Ray Smith 2017-05-05 16:42:44 -07:00
parent 4fa463cd71
commit d18931e86e

View File

@ -91,33 +91,36 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input,
// objects. // objects.
if (!model_proto_.image_widths().empty()) { if (!model_proto_.image_widths().empty()) {
TensorShape size_shape{1}; TensorShape size_shape{1};
Tensor width_tensor(tensorflow::DT_INT32, size_shape); Tensor width_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_wtensor = width_tensor.flat<int32>(); auto eigen_wtensor = width_tensor.flat<int64>();
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH); *eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor); tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
} }
if (!model_proto_.image_heights().empty()) { if (!model_proto_.image_heights().empty()) {
TensorShape size_shape{1}; TensorShape size_shape{1};
Tensor height_tensor(tensorflow::DT_INT32, size_shape); Tensor height_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_htensor = height_tensor.flat<int32>(); auto eigen_htensor = height_tensor.flat<int64>();
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT); *eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor); tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
} }
std::vector<string> target_layers = {model_proto_.output_layer()}; std::vector<string> target_layers = {model_proto_.output_layer()};
std::vector<Tensor> outputs; std::vector<Tensor> outputs;
Status s = session_->Run(tf_inputs, target_layers, {}, &outputs); Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
if (!s.ok()) tprintf("session->Run failed:%s\n", s.error_message().c_str());
ASSERT_HOST(s.ok()); ASSERT_HOST(s.ok());
ASSERT_HOST(outputs.size() == 1); ASSERT_HOST(outputs.size() == 1);
const Tensor& output_tensor = outputs[0]; const Tensor& output_tensor = outputs[0];
// Check the dimensions of the output. // Check the dimensions of the output.
ASSERT_HOST(output_tensor.shape().dims() == 2); ASSERT_HOST(output_tensor.shape().dims() == 3);
int output_dim0 = output_tensor.shape().dim_size(0); int output_batch = output_tensor.shape().dim_size(0);
int output_dim1 = output_tensor.shape().dim_size(1); int output_steps = output_tensor.shape().dim_size(1);
ASSERT_HOST(output_dim1 == output_shape_.depth()); int output_depth = output_tensor.shape().dim_size(2);
output->Resize2d(false, output_dim0, output_dim1); ASSERT_HOST(output_batch == 1);
ASSERT_HOST(output_depth == output_shape_.depth());
output->Resize2d(false, output_steps, output_depth);
auto eigen_output = output_tensor.flat<float>(); auto eigen_output = output_tensor.flat<float>();
memcpy(output->f(0), eigen_output.data(), memcpy(output->f(0), eigen_output.data(),
output_dim0 * output_dim1 * sizeof(output->f(0)[0])); output_steps * output_depth * sizeof(output->f(0)[0]));
} }
int TFNetwork::InitFromProto() { int TFNetwork::InitFromProto() {