tesseract/unittest/networkio_test.cc

216 lines
6.8 KiB
C++
Raw Normal View History

// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "networkio.h"
2021-03-13 05:06:34 +08:00
#include "include_gunit.h"
#include "stridemap.h"
#ifdef INCLUDE_TENSORFLOW
2021-03-13 05:06:34 +08:00
# include <tensorflow/compiler/xla/array2d.h> // for xla::Array2D
#endif
namespace tesseract {
class NetworkioTest : public ::testing::Test {
2021-03-13 05:06:34 +08:00
protected:
void SetUp() override {
std::locale::global(std::locale(""));
}
#ifdef INCLUDE_TENSORFLOW
// Sets up an Array2d object of the given size, initialized to increasing
// values starting with start.
std::unique_ptr<xla::Array2D<int>> SetupArray(int ysize, int xsize, int start) {
std::unique_ptr<xla::Array2D<int>> a(new xla::Array2D<int>(ysize, xsize));
int value = start;
for (int y = 0; y < ysize; ++y) {
for (int x = 0; x < xsize; ++x) {
(*a)(y, x) = value++;
}
}
return a;
}
// Sets up a NetworkIO with a batch of 2 "images" of known values.
2021-03-13 05:06:34 +08:00
void SetupNetworkIO(NetworkIO *nio) {
std::vector<std::unique_ptr<xla::Array2D<int>>> arrays;
arrays.push_back(SetupArray(3, 4, 0));
arrays.push_back(SetupArray(4, 5, 12));
std::vector<std::pair<int, int>> h_w_sizes;
for (size_t i = 0; i < arrays.size(); ++i) {
2021-03-13 05:06:34 +08:00
h_w_sizes.emplace_back(arrays[i].get()->height(), arrays[i].get()->width());
}
StrideMap stride_map;
stride_map.SetStride(h_w_sizes);
nio->ResizeToMap(true, stride_map, 2);
// Iterate over the map, setting nio's contents from the arrays.
StrideMap::Index index(stride_map);
do {
2021-03-13 05:06:34 +08:00
int value = (*arrays[index.index(FD_BATCH)])(index.index(FD_HEIGHT), index.index(FD_WIDTH));
nio->SetPixel(index.t(), 0, 128 + value, 0.0f, 128.0f);
nio->SetPixel(index.t(), 1, 128 - value, 0.0f, 128.0f);
} while (index.Increment());
}
#endif
};
// Tests that the initialization via SetPixel works and the resize correctly
// fills with zero where image sizes don't match.
TEST_F(NetworkioTest, InitWithZeroFill) {
#ifdef INCLUDE_TENSORFLOW
NetworkIO nio;
nio.Resize2d(true, 32, 2);
int width = nio.Width();
for (int t = 0; t < width; ++t) {
nio.SetPixel(t, 0, 0, 0.0f, 128.0f);
nio.SetPixel(t, 1, 0, 0.0f, 128.0f);
}
// The initialization will wipe out all previously set values.
SetupNetworkIO(&nio);
nio.ZeroInvalidElements();
StrideMap::Index index(nio.stride_map());
int next_t = 0;
int pos = 0;
do {
int t = index.t();
// The indexed values just increase monotonically.
int value = nio.i(t)[0];
EXPECT_EQ(value, pos);
value = nio.i(t)[1];
EXPECT_EQ(value, -pos);
// When we skip t values, the data is always 0.
while (next_t < t) {
EXPECT_EQ(nio.i(next_t)[0], 0);
EXPECT_EQ(nio.i(next_t)[1], 0);
++next_t;
}
++pos;
++next_t;
} while (index.Increment());
EXPECT_EQ(pos, 32);
EXPECT_EQ(next_t, 40);
#else
LOG(INFO) << "Skip test because of missing xla::Array2D";
GTEST_SKIP();
#endif
}
// Tests that CopyWithYReversal works.
TEST_F(NetworkioTest, CopyWithYReversal) {
#ifdef INCLUDE_TENSORFLOW
NetworkIO nio;
SetupNetworkIO(&nio);
NetworkIO copy;
copy.CopyWithYReversal(nio);
StrideMap::Index index(copy.stride_map());
int next_t = 0;
int pos = 0;
2021-03-13 05:06:34 +08:00
std::vector<int> expected_values = {8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2,
3, 27, 28, 29, 30, 31, 22, 23, 24, 25, 26,
17, 18, 19, 20, 21, 12, 13, 14, 15, 16};
do {
int t = index.t();
// The indexed values match the expected values.
int value = copy.i(t)[0];
EXPECT_EQ(value, expected_values[pos]);
value = copy.i(t)[1];
EXPECT_EQ(value, -expected_values[pos]);
// When we skip t values, the data is always 0.
while (next_t < t) {
EXPECT_EQ(copy.i(next_t)[0], 0) << "Failure t = " << next_t;
EXPECT_EQ(copy.i(next_t)[1], 0) << "Failure t = " << next_t;
++next_t;
}
++pos;
++next_t;
} while (index.Increment());
EXPECT_EQ(pos, 32);
EXPECT_EQ(next_t, 40);
#else
LOG(INFO) << "Skip test because of missing xla::Array2D";
GTEST_SKIP();
#endif
}
// Tests that CopyWithXReversal works.
TEST_F(NetworkioTest, CopyWithXReversal) {
#ifdef INCLUDE_TENSORFLOW
NetworkIO nio;
SetupNetworkIO(&nio);
NetworkIO copy;
copy.CopyWithXReversal(nio);
StrideMap::Index index(copy.stride_map());
int next_t = 0;
int pos = 0;
2021-03-13 05:06:34 +08:00
std::vector<int> expected_values = {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9,
8, 16, 15, 14, 13, 12, 21, 20, 19, 18, 17,
26, 25, 24, 23, 22, 31, 30, 29, 28, 27};
do {
int t = index.t();
// The indexed values match the expected values.
int value = copy.i(t)[0];
EXPECT_EQ(value, expected_values[pos]);
value = copy.i(t)[1];
EXPECT_EQ(value, -expected_values[pos]);
// When we skip t values, the data is always 0.
while (next_t < t) {
EXPECT_EQ(copy.i(next_t)[0], 0) << "Failure t = " << next_t;
EXPECT_EQ(copy.i(next_t)[1], 0) << "Failure t = " << next_t;
++next_t;
}
++pos;
++next_t;
} while (index.Increment());
EXPECT_EQ(pos, 32);
EXPECT_EQ(next_t, 40);
#else
LOG(INFO) << "Skip test because of missing xla::Array2D";
GTEST_SKIP();
#endif
}
// Tests that CopyWithXYTranspose works.
TEST_F(NetworkioTest, CopyWithXYTranspose) {
#ifdef INCLUDE_TENSORFLOW
NetworkIO nio;
SetupNetworkIO(&nio);
NetworkIO copy;
copy.CopyWithXYTranspose(nio);
StrideMap::Index index(copy.stride_map());
int next_t = 0;
int pos = 0;
2021-03-13 05:06:34 +08:00
std::vector<int> expected_values = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7,
11, 12, 17, 22, 27, 13, 18, 23, 28, 14, 19,
24, 29, 15, 20, 25, 30, 16, 21, 26, 31};
do {
int t = index.t();
// The indexed values match the expected values.
int value = copy.i(t)[0];
EXPECT_EQ(value, expected_values[pos]);
value = copy.i(t)[1];
EXPECT_EQ(value, -expected_values[pos]);
// When we skip t values, the data is always 0.
while (next_t < t) {
EXPECT_EQ(copy.i(next_t)[0], 0);
EXPECT_EQ(copy.i(next_t)[1], 0);
++next_t;
}
++pos;
++next_t;
} while (index.Increment());
EXPECT_EQ(pos, 32);
EXPECT_EQ(next_t, 40);
#else
LOG(INFO) << "Skip test because of missing xla::Array2D";
GTEST_SKIP();
#endif
}
2021-03-13 05:06:34 +08:00
} // namespace tesseract