/////////////////////////////////////////////////////////////////////// // File: networkbuilder.cpp // Description: Class to parse the network description language and // build a corresponding network. // Author: Ray Smith // Created: Wed Jul 16 18:35:38 PST 2014 // // (C) Copyright 2014, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /////////////////////////////////////////////////////////////////////// #include "networkbuilder.h" #include "convolve.h" #include "fullyconnected.h" #include "input.h" #include "lstm.h" #include "maxpool.h" #include "network.h" #include "parallel.h" #include "reconfig.h" #include "reversed.h" #include "series.h" #include "unicharset.h" namespace tesseract { // Builds a network with a network_spec in the network description // language, to recognize a character set of num_outputs size. // If append_index is non-negative, then *network must be non-null and the // given network_spec will be appended to *network AFTER append_index, with // the top of the input *network discarded. // Note that network_spec is call by value to allow a non-const char* pointer // into the string for BuildFromString. // net_flags control network behavior according to the NetworkFlags enum. // The resulting network is returned via **network. // Returns false if something failed. 
bool NetworkBuilder::InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand* randomizer, Network** network) { NetworkBuilder builder(num_outputs); Series* bottom_series = nullptr; StaticShape input_shape; if (append_index >= 0) { // Split the current network after the given append_index. ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES); Series* series = static_cast(*network); Series* top_series = nullptr; series->SplitAt(append_index, &bottom_series, &top_series); if (bottom_series == nullptr || top_series == nullptr) { tprintf("Yikes! Splitting current network failed!!\n"); return false; } input_shape = bottom_series->OutputShape(input_shape); delete top_series; } char* str_ptr = &network_spec[0]; *network = builder.BuildFromString(input_shape, &str_ptr); if (*network == nullptr) return false; (*network)->SetNetworkFlags(net_flags); (*network)->InitWeights(weight_range, randomizer); (*network)->SetupNeedsBackprop(false); if (bottom_series != nullptr) { bottom_series->AppendSeries(*network); *network = bottom_series; } (*network)->CacheXScaleFactor((*network)->XScaleFactor()); return true; } // Helper skips whitespace. static void SkipWhitespace(char** str) { while (**str == ' ' || **str == '\t' || **str == '\n') ++*str; } // Parses the given string and returns a network according to the network // description language in networkbuilder.h Network* NetworkBuilder::BuildFromString(const StaticShape& input_shape, char** str) { SkipWhitespace(str); char code_ch = **str; if (code_ch == '[') { return ParseSeries(input_shape, nullptr, str); } if (input_shape.depth() == 0) { // There must be an input at this point. 
return ParseInput(str); } switch (code_ch) { case '(': return ParseParallel(input_shape, str); case 'R': return ParseR(input_shape, str); case 'S': return ParseS(input_shape, str); case 'C': return ParseC(input_shape, str); case 'M': return ParseM(input_shape, str); case 'L': return ParseLSTM(input_shape, str); case 'F': return ParseFullyConnected(input_shape, str); case 'O': return ParseOutput(input_shape, str); default: tprintf("Invalid network spec:%s\n", *str); return nullptr; } return nullptr; } // Parses an input specification and returns the result, which may include a // series. Network* NetworkBuilder::ParseInput(char** str) { // There must be an input at this point. int length = 0; int batch, height, width, depth; int num_converted = sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length); StaticShape shape; shape.SetShape(batch, height, width, depth); // num_converted may or may not include the length. if (num_converted != 4 && num_converted != 5) { tprintf("Must specify an input layer as the first layer, not %s!!\n", *str); return nullptr; } *str += length; Input* input = new Input("Input", shape); // We want to allow [rest of net... or [rest of net... so we // have to check explicitly for '[' here. SkipWhitespace(str); if (**str == '[') return ParseSeries(shape, input, str); return input; } // Parses a sequential series of networks, defined by [...]. 
Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape, Input* input_layer, char** str) { StaticShape shape = input_shape; Series* series = new Series("Series"); ++*str; if (input_layer != nullptr) { series->AddToStack(input_layer); shape = input_layer->OutputShape(shape); } Network* network = nullptr; while (**str != '\0' && **str != ']' && (network = BuildFromString(shape, str)) != nullptr) { shape = network->OutputShape(shape); series->AddToStack(network); } if (**str != ']') { tprintf("Missing ] at end of [Series]!\n"); delete series; return nullptr; } ++*str; return series; } // Parses a parallel set of networks, defined by (...). Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape, char** str) { Parallel* parallel = new Parallel("Parallel", NT_PARALLEL); ++*str; Network* network = nullptr; while (**str != '\0' && **str != ')' && (network = BuildFromString(input_shape, str)) != nullptr) { parallel->AddToStack(network); } if (**str != ')') { tprintf("Missing ) at end of (Parallel)!\n"); delete parallel; return nullptr; } ++*str; return parallel; } // Parses a network that begins with 'R'. Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) { char dir = (*str)[1]; if (dir == 'x' || dir == 'y') { STRING name = "Reverse"; name += dir; *str += 2; Network* network = BuildFromString(input_shape, str); if (network == nullptr) return nullptr; Reversed* rev = new Reversed(name, dir == 'y' ? 
NT_YREVERSED : NT_XREVERSED); rev->SetNetwork(network); return rev; } int replicas = strtol(*str + 1, str, 10); if (replicas <= 0) { tprintf("Invalid R spec!:%s\n", *str); return nullptr; } Parallel* parallel = new Parallel("Replicated", NT_REPLICATED); char* str_copy = *str; for (int i = 0; i < replicas; ++i) { str_copy = *str; Network* network = BuildFromString(input_shape, &str_copy); if (network == nullptr) { tprintf("Invalid replicated network!\n"); delete parallel; return nullptr; } parallel->AddToStack(network); } *str = str_copy; return parallel; } // Parses a network that begins with 'S'. Network* NetworkBuilder::ParseS(const StaticShape& input_shape, char** str) { int y = strtol(*str + 1, str, 10); if (**str == ',') { int x = strtol(*str + 1, str, 10); if (y <= 0 || x <= 0) { tprintf("Invalid S spec!:%s\n", *str); return nullptr; } return new Reconfig("Reconfig", input_shape.depth(), x, y); } else if (**str == '(') { // TODO(rays) Add Generic reshape. tprintf("Generic reshape not yet implemented!!\n"); return nullptr; } tprintf("Invalid S spec!:%s\n", *str); return nullptr; } // Helper returns the fully-connected type for the character code. static NetworkType NonLinearity(char func) { switch (func) { case 's': return NT_LOGISTIC; case 't': return NT_TANH; case 'r': return NT_RELU; case 'l': return NT_LINEAR; case 'm': return NT_SOFTMAX; case 'p': return NT_POSCLIP; case 'n': return NT_SYMCLIP; default: return NT_NONE; } } // Parses a network that begins with 'C'. 
// Spec layout: 'C', a nonlinearity code, then "<y>,<x>,<d>", all positive.
// A 1x1 window yields a bare FullyConnected; anything else yields a Series of
// a Convolve followed by a FullyConnected.
Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) {
  const NetworkType nonlinearity = NonLinearity((*str)[1]);
  if (nonlinearity == NT_NONE) {
    tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
    return nullptr;
  }
  // Read the three numbers one at a time, stopping at the first bad field so
  // that *str is left exactly where parsing gave up.
  int win_y = strtol(*str + 2, str, 10);
  bool valid = win_y > 0 && **str == ',';
  int win_x = 0;
  if (valid) {
    win_x = strtol(*str + 1, str, 10);
    valid = win_x > 0 && **str == ',';
  }
  int out_depth = 0;
  if (valid) {
    out_depth = strtol(*str + 1, str, 10);
    valid = out_depth > 0;
  }
  if (!valid) {
    tprintf("Invalid C spec!:%s\n", *str);
    return nullptr;
  }
  if (win_x == 1 && win_y == 1) {
    // No actual convolution. Just a FullyConnected on the current depth, to
    // be slid over all batch,y,x.
    return new FullyConnected("Conv1x1", input_shape.depth(), out_depth,
                              nonlinearity);
  }
  Series* conv_series = new Series("ConvSeries");
  // Convolve is given half of each window dimension (integer division).
  Convolve* conv =
      new Convolve("Convolve", input_shape.depth(), win_x / 2, win_y / 2);
  conv_series->AddToStack(conv);
  // The nonlinear layer takes its input depth from the convolve's output.
  const StaticShape conv_out = conv->OutputShape(input_shape);
  conv_series->AddToStack(
      new FullyConnected("ConvNL", conv_out.depth(), out_depth, nonlinearity));
  return conv_series;
}

// Parses a network that begins with 'M'.
// Spec layout: "Mp<y>,<x>" with both values positive; yields a Maxpool.
Network* NetworkBuilder::ParseM(const StaticShape& input_shape, char** str) {
  int pool_y = 0;
  int pool_x = 0;
  bool valid = (*str)[1] == 'p';
  if (valid) {
    pool_y = strtol(*str + 2, str, 10);
    valid = pool_y > 0 && **str == ',';
  }
  if (valid) {
    pool_x = strtol(*str + 1, str, 10);
    valid = pool_x > 0;
  }
  if (!valid) {
    tprintf("Invalid Mp spec!:%s\n", *str);
    return nullptr;
  }
  return new Maxpool("Maxpool", input_shape.depth(), pool_x, pool_y);
}

// Parses an LSTM network, either individual, bi- or quad-directional.
// Accepted prefixes after the leading 'L':
//   S / E        softmax / encoded-softmax LSTM with num_softmax_outputs_.
//   2xy / 2yx    a quad of 2-D LSTMs built by BuildLSTMXYQuad.
//   (f|r|b)(x|y) forward, reversed or bidirectional, optionally followed by
//                's' for the summary type.
// The prefix is followed by the (positive) number of states.
Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) {
  char* const spec_begin = *str;
  const char key = spec_begin[1];
  NetworkType lstm_type = NT_LSTM;
  bool is_2d = false;
  char direction = 'f';
  char dimension = 'x';
  int softmax_outputs = 0;
  int prefix_len = 1;
  switch (key) {
    case 'S':
      lstm_type = NT_LSTM_SOFTMAX;
      softmax_outputs = num_softmax_outputs_;
      prefix_len = 2;
      break;
    case 'E':
      lstm_type = NT_LSTM_SOFTMAX_ENCODED;
      softmax_outputs = num_softmax_outputs_;
      prefix_len = 2;
      break;
    case 'f':
    case 'r':
    case 'b':
      direction = key;
      dimension = spec_begin[2];
      if (dimension != 'x' && dimension != 'y') {
        tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
        return nullptr;
      }
      prefix_len = 3;
      if (spec_begin[prefix_len] == 's') {
        ++prefix_len;
        lstm_type = NT_LSTM_SUMMARY;
      }
      break;
    default: {
      // Only "2xy" and "2yx" remain valid. The && chains stop at the first
      // mismatch, so nothing past the string terminator is ever read.
      const bool is_xy =
          key == '2' && spec_begin[2] == 'x' && spec_begin[3] == 'y';
      const bool is_yx =
          key == '2' && spec_begin[2] == 'y' && spec_begin[3] == 'x';
      if (!is_xy && !is_yx) {
        tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
        return nullptr;
      }
      prefix_len = 4;
      dimension = spec_begin[3];
      is_2d = true;
      break;
    }
  }
  const int num_states = strtol(spec_begin + prefix_len, str, 10);
  if (num_states <= 0) {
    tprintf("Invalid number of states in L Spec!:%s\n", *str);
    return nullptr;
  }
  Network* result = nullptr;
  if (is_2d) {
    result = BuildLSTMXYQuad(input_shape.depth(), num_states);
  } else {
    // Without a softmax output count, the LSTM outputs its states.
    const int num_outputs = softmax_outputs == 0 ? num_states : softmax_outputs;
    // The layer is named after the exact spec text that was consumed.
    STRING name(spec_begin, *str - spec_begin);
    result = new LSTM(name, input_shape.depth(), num_states, num_outputs, false,
                      lstm_type);
    if (direction != 'f') {
      // Both 'r' and 'b' run this LSTM inside an x-reversal.
      Reversed* reversed = new Reversed("RevLSTM", NT_XREVERSED);
      reversed->SetNetwork(result);
      result = reversed;
    }
    if (direction == 'b') {
      // Bidirectional: pair a second, unreversed LSTM with the reversed one.
      name += "LTR";
      Parallel* bidi = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
      bidi->AddToStack(new LSTM(name, input_shape.depth(), num_states,
                                num_outputs, false, lstm_type));
      bidi->AddToStack(result);
      result = bidi;
    }
  }
  if (dimension == 'y') {
    Reversed* transposed = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
    transposed->SetNetwork(result);
    result = transposed;
  }
  return result;
}

// Builds a set of 4 lstms with x and y reversal, running in true parallel.
// All four LSTMs take num_inputs inputs and use num_states for both states
// and outputs. They are stacked into one Parallel of type NT_PAR_2D_LSTM as:
// plain, x-reversed, x-reversed around y-reversed, and y-reversed.
Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
  Parallel* quad = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);

  // 1: unwrapped.
  quad->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states, num_states,
                            true, NT_LSTM));

  // 2: wrapped in an x-reversal.
  Reversed* x_flip = new Reversed("L2DLTRXRev", NT_XREVERSED);
  x_flip->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states,
                              true, NT_LSTM));
  quad->AddToStack(x_flip);

  // 3: wrapped in a y-reversal, which is itself wrapped in an x-reversal.
  Reversed* inner_y_flip = new Reversed("L2DRTLYRev", NT_YREVERSED);
  inner_y_flip->SetNetwork(new LSTM("L2DRTLUp", num_inputs, num_states,
                                    num_states, true, NT_LSTM));
  Reversed* outer_x_flip = new Reversed("L2DXRevU", NT_XREVERSED);
  outer_x_flip->SetNetwork(inner_y_flip);
  quad->AddToStack(outer_x_flip);

  // 4: wrapped in a y-reversal.
  Reversed* y_flip = new Reversed("L2DXRevY", NT_YREVERSED);
  y_flip->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states,
                              true, NT_LSTM));
  quad->AddToStack(y_flip);

  return quad;
}

// Helper builds a truly (0-d) fully connected layer of the given type.
// Height and width must both be non-zero. If together they cover more than
// one position, a Reconfig is placed in front of the FullyConnected and the
// two are returned as a Series; otherwise the FullyConnected is returned
// alone. Returns nullptr if height or width is zero.
static Network* BuildFullyConnected(const StaticShape& input_shape,
                                    NetworkType type, const STRING& name,
                                    int depth) {
  const int height = input_shape.height();
  const int width = input_shape.width();
  if (height == 0 || width == 0) {
    tprintf("Fully connected requires positive height and width, had %d,%d\n",
            height, width);
    return nullptr;
  }
  const int area = height * width;
  // The layer sees every position's depth at once.
  Network* fc =
      new FullyConnected(name, area * input_shape.depth(), depth, type);
  if (area <= 1) return fc;
  Series* wrapped = new Series("FCSeries");
  wrapped->AddToStack(
      new Reconfig("FCReconfig", input_shape.depth(), width, height));
  wrapped->AddToStack(fc);
  return wrapped;
}

// Parses a Fully connected network.
// Spec layout: 'F', a nonlinearity code, then the (positive) number of
// outputs. The layer is named after the spec text that was consumed.
// Returns nullptr on a bad nonlinearity code or output count, or if
// BuildFullyConnected rejects the input shape.
Network* NetworkBuilder::ParseFullyConnected(const StaticShape& input_shape,
                                             char** str) {
  // Get the nonlinearity.
  char* spec_start = *str;
  NetworkType type = NonLinearity((*str)[1]);
  if (type == NT_NONE) {
    tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
    return nullptr;
  }
  // Get the number of outputs. The digits start after both the 'F' and the
  // nonlinearity code; starting at the code itself would make strtol convert
  // nothing and return 0, rejecting every F spec.
  int depth = strtol(*str + 2, str, 10);
  if (depth <= 0) {
    tprintf("Invalid F spec!:%s\n", *str);
    return nullptr;
  }
  STRING name(spec_start, *str - spec_start);
  return BuildFullyConnected(input_shape, type, name, depth);
}

// Parses an Output spec.
// Spec layout: 'O', a dims digit (2|1|0), a type code (l|s|c), then the number
// of outputs. The number is overridden, with a warning, if it differs from
// num_softmax_outputs_.
Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape,
                                     char** str) {
  char dims_ch = (*str)[1];
  if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
    tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
    return nullptr;
  }
  char type_ch = (*str)[2];
  if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
    tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
    return nullptr;
  }
  int depth = strtol(*str + 3, str, 10);
  if (depth != num_softmax_outputs_) {
    tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
            num_softmax_outputs_);
    depth = num_softmax_outputs_;
  }
  // 'c' keeps the default NT_SOFTMAX; 'l' and 's' select the alternatives.
  NetworkType type = NT_SOFTMAX;
  if (type_ch == 'l')
    type = NT_LOGISTIC;
  else if (type_ch == 's')
    type = NT_SOFTMAX_NO_CTC;
  if (dims_ch == '0') {
    // Same as standard fully connected.
    return BuildFullyConnected(input_shape, type, "Output", depth);
  } else if (dims_ch == '2') {
    // We don't care if x and/or y are variable.
    return new FullyConnected("Output2d", input_shape.depth(), depth, type);
  }
  // For 1-d y has to be fixed, and if not 1, moved to depth.
  if (input_shape.height() == 0) {
    tprintf("Fully connected requires fixed height!\n");
    return nullptr;
  }
  int input_size = input_shape.height();
  int input_depth = input_size * input_shape.depth();
  Network* fc = new FullyConnected("Output", input_depth, depth, type);
  if (input_size > 1) {
    // A Reconfig in front of the layer folds the height into the depth.
    Series* series = new Series("FCSeries");
    series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1,
                                    input_shape.height()));
    series->AddToStack(fc);
    fc = series;
  }
  return fc;
}

}  // namespace tesseract.