mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-01-18 06:30:14 +08:00
Fixed int types for imported tf networks
This commit is contained in:
parent
4fa463cd71
commit
d18931e86e
@ -91,33 +91,36 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input,
|
|||||||
// objects.
|
// objects.
|
||||||
if (!model_proto_.image_widths().empty()) {
|
if (!model_proto_.image_widths().empty()) {
|
||||||
TensorShape size_shape{1};
|
TensorShape size_shape{1};
|
||||||
Tensor width_tensor(tensorflow::DT_INT32, size_shape);
|
Tensor width_tensor(tensorflow::DT_INT64, size_shape);
|
||||||
auto eigen_wtensor = width_tensor.flat<int32>();
|
auto eigen_wtensor = width_tensor.flat<int64>();
|
||||||
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
|
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
|
||||||
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
|
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
|
||||||
}
|
}
|
||||||
if (!model_proto_.image_heights().empty()) {
|
if (!model_proto_.image_heights().empty()) {
|
||||||
TensorShape size_shape{1};
|
TensorShape size_shape{1};
|
||||||
Tensor height_tensor(tensorflow::DT_INT32, size_shape);
|
Tensor height_tensor(tensorflow::DT_INT64, size_shape);
|
||||||
auto eigen_htensor = height_tensor.flat<int32>();
|
auto eigen_htensor = height_tensor.flat<int64>();
|
||||||
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
|
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
|
||||||
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
|
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
|
||||||
}
|
}
|
||||||
std::vector<string> target_layers = {model_proto_.output_layer()};
|
std::vector<string> target_layers = {model_proto_.output_layer()};
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
|
Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
|
||||||
|
if (!s.ok()) tprintf("session->Run failed:%s\n", s.error_message().c_str());
|
||||||
ASSERT_HOST(s.ok());
|
ASSERT_HOST(s.ok());
|
||||||
ASSERT_HOST(outputs.size() == 1);
|
ASSERT_HOST(outputs.size() == 1);
|
||||||
const Tensor& output_tensor = outputs[0];
|
const Tensor& output_tensor = outputs[0];
|
||||||
// Check the dimensions of the output.
|
// Check the dimensions of the output.
|
||||||
ASSERT_HOST(output_tensor.shape().dims() == 2);
|
ASSERT_HOST(output_tensor.shape().dims() == 3);
|
||||||
int output_dim0 = output_tensor.shape().dim_size(0);
|
int output_batch = output_tensor.shape().dim_size(0);
|
||||||
int output_dim1 = output_tensor.shape().dim_size(1);
|
int output_steps = output_tensor.shape().dim_size(1);
|
||||||
ASSERT_HOST(output_dim1 == output_shape_.depth());
|
int output_depth = output_tensor.shape().dim_size(2);
|
||||||
output->Resize2d(false, output_dim0, output_dim1);
|
ASSERT_HOST(output_batch == 1);
|
||||||
|
ASSERT_HOST(output_depth == output_shape_.depth());
|
||||||
|
output->Resize2d(false, output_steps, output_depth);
|
||||||
auto eigen_output = output_tensor.flat<float>();
|
auto eigen_output = output_tensor.flat<float>();
|
||||||
memcpy(output->f(0), eigen_output.data(),
|
memcpy(output->f(0), eigen_output.data(),
|
||||||
output_dim0 * output_dim1 * sizeof(output->f(0)[0]));
|
output_steps * output_depth * sizeof(output->f(0)[0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
int TFNetwork::InitFromProto() {
|
int TFNetwork::InitFromProto() {
|
||||||
|
Loading…
Reference in New Issue
Block a user