From 7434590b9a35c2428fac36b574e780516d202eaf Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Fri, 5 Oct 2018 20:38:02 +0200 Subject: [PATCH 1/2] lstmtraining: Check write permission for output model This is done by creating a temporary file. Report an error and terminate if that fails. Use also EXIT_SUCCESS and EXIT_FAILURE for the return values of main(). Signed-off-by: Stefan Weil --- src/training/lstmtraining.cpp | 37 +++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/training/lstmtraining.cpp b/src/training/lstmtraining.cpp index ac7017c47..cd4e1de0c 100644 --- a/src/training/lstmtraining.cpp +++ b/src/training/lstmtraining.cpp @@ -76,11 +76,11 @@ int main(int argc, char **argv) { // Purify the model name in case it is based on the network string. if (FLAGS_model_output.empty()) { tprintf("Must provide a --model_output!\n"); - return 1; + return EXIT_FAILURE; } if (FLAGS_traineddata.empty()) { tprintf("Must provide a --traineddata see training wiki\n"); - return 1; + return EXIT_FAILURE; } STRING model_output = FLAGS_model_output.c_str(); for (int i = 0; i < model_output.length(); ++i) { @@ -89,6 +89,19 @@ int main(int argc, char **argv) { if (model_output[i] == '(' || model_output[i] == ')') model_output[i] = '_'; } + + // Check write permissions. + STRING test_file = FLAGS_model_output.c_str(); + test_file += "_wtest"; + FILE* f = fopen(test_file.c_str(), "wb"); + if (f != nullptr) { + fclose(f); + remove(test_file.c_str()); + } else { + tprintf("Error, model output cannot be written: %s\n", strerror(errno)); + return EXIT_FAILURE; + } + // Setup the trainer. STRING checkpoint_file = FLAGS_model_output.c_str(); checkpoint_file += "_checkpoint"; @@ -105,7 +118,7 @@ int main(int argc, char **argv) { if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) { tprintf("Failed to read continue from: %s\n", FLAGS_continue_from.c_str()); - return 1; + return EXIT_FAILURE; } if (FLAGS_debug_network) { trainer.DebugNetwork(); @@ -116,20 +129,20 @@ int main(int argc, char **argv) { FLAGS_model_output.c_str()); } } - return 0; + return EXIT_SUCCESS; } // Get the list of files to process. if (FLAGS_train_listfile.empty()) { tprintf("Must supply a list of training filenames! --train_listfile\n"); - return 1; + return EXIT_FAILURE; } GenericVector filenames; if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(), &filenames)) { tprintf("Failed to load list of training filenames from %s\n", FLAGS_train_listfile.c_str()); - return 1; + return EXIT_FAILURE; } // Checkpoints always take priority if they are available. @@ -145,7 +158,7 @@ int main(int argc, char **argv) { ? FLAGS_continue_from.c_str() : FLAGS_old_traineddata.c_str())) { tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str()); - return 1; + return EXIT_FAILURE; } tprintf("Continuing from %s\n", FLAGS_continue_from.c_str()); trainer.InitIterations(); @@ -155,7 +168,7 @@ int main(int argc, char **argv) { tprintf("Appending a new network to an old one!!"); if (FLAGS_continue_from.empty()) { tprintf("Must set --continue_from for appending!\n"); - return 1; + return EXIT_FAILURE; } } // We are initializing from scratch. @@ -165,7 +178,7 @@ int main(int argc, char **argv) { FLAGS_adam_beta)) { tprintf("Failed to create network from spec: %s\n", FLAGS_net_spec.c_str()); - return 1; + return EXIT_FAILURE; } trainer.set_perfect_delay(FLAGS_perfect_sample_delay); } @@ -176,7 +189,7 @@ int main(int argc, char **argv) { : tesseract::CS_ROUND_ROBIN, FLAGS_randomly_rotate)) { tprintf("Load of images failed!!\n"); - return 1; + return EXIT_FAILURE; } tesseract::LSTMTester tester(static_cast(FLAGS_max_image_MB) * @@ -186,7 +199,7 @@ int main(int argc, char **argv) { if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) { tprintf("Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str()); - return 1; + return EXIT_FAILURE; } tester_callback = NewPermanentTessCallback(&tester, &tesseract::LSTMTester::RunEvalAsync); @@ -208,5 +221,5 @@ int main(int argc, char **argv) { FLAGS_max_iterations == 0)); delete tester_callback; tprintf("Finished! Error rate = %g\n", trainer.best_error_rate()); - return 0; + return EXIT_SUCCESS; } /* main */ From f4e982e041e9154f49cf0abefa9db850636cba18 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Fri, 5 Oct 2018 21:39:18 +0200 Subject: [PATCH 2/2] combine_tessdata: Handle failures when extracting Report an error and terminate if that fails. Use also EXIT_SUCCESS and EXIT_FAILURE for the return values of main() and add missing return at end of main(). Signed-off-by: Stefan Weil --- src/training/combine_tessdata.cpp | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/training/combine_tessdata.cpp b/src/training/combine_tessdata.cpp index 3eb8d8541..c0a2a7510 100644 --- a/src/training/combine_tessdata.cpp +++ b/src/training/combine_tessdata.cpp @@ -72,7 +72,7 @@ int main(int argc, char **argv) { tesseract::TessdataManager tm; if (argc > 1 && (!strcmp(argv[1], "-v") || !strcmp(argv[1], "--version"))) { printf("%s\n", tesseract::TessBaseAPI::Version()); - return 0; + return EXIT_SUCCESS; } else if (argc == 2) { printf("Combining tessdata files\n"); STRING lang = argv[1]; @@ -92,16 +92,22 @@ int main(int argc, char **argv) { // Initialize TessdataManager with the data in the given traineddata file. if (!tm.Init(argv[2])) { tprintf("Failed to read %s\n", argv[2]); - exit(1); + return EXIT_FAILURE; } printf("Extracting tessdata components from %s\n", argv[2]); if (strcmp(argv[1], "-e") == 0) { for (i = 3; i < argc; ++i) { + errno = 0; if (tm.ExtractToFile(argv[i])) { printf("Wrote %s\n", argv[i]); - } else { + } else if (errno == 0) { printf("Not extracting %s, since this component" " is not present\n", argv[i]); + return EXIT_FAILURE; + } else { + printf("Error, could not extract %s: %s\n", + argv[i], strerror(errno)); + return EXIT_FAILURE; } } } else { // extract all the components @@ -111,8 +117,13 @@ int main(int argc, char **argv) { if (*last != '.') filename += '.'; filename += tesseract::kTessdataFileSuffixes[i]; + errno = 0; if (tm.ExtractToFile(filename.string())) { printf("Wrote %s\n", filename.string()); + } else if (errno != 0) { + printf("Error, could not extract %s: %s\n", + filename.string(), strerror(errno)); + return EXIT_FAILURE; } } } @@ -124,7 +135,7 @@ int main(int argc, char **argv) { if (rename(new_traineddata_filename, traineddata_filename.string()) != 0) { tprintf("Failed to create a temporary file %s\n", traineddata_filename.string()); - exit(1); + return EXIT_FAILURE; } // Initialize TessdataManager with the data in the given traineddata file. @@ -135,17 +146,17 @@ int main(int argc, char **argv) { } else if (argc == 3 && strcmp(argv[1], "-c") == 0) { if (!tm.Init(argv[2])) { tprintf("Failed to read %s\n", argv[2]); - exit(1); + return EXIT_FAILURE; } tesseract::TFile fp; if (!tm.GetComponent(tesseract::TESSDATA_LSTM, &fp)) { tprintf("No LSTM Component found in %s!\n", argv[2]); - exit(1); + return EXIT_FAILURE; } tesseract::LSTMRecognizer recognizer; if (!recognizer.DeSerialize(&tm, &fp)) { tprintf("Failed to deserialize LSTM in %s!\n", argv[2]); - exit(1); + return EXIT_FAILURE; } recognizer.ConvertToInt(); GenericVector lstm_data; @@ -155,7 +166,7 @@ int main(int argc, char **argv) { lstm_data.size()); if (!tm.SaveFile(argv[2], nullptr)) { tprintf("Failed to write modified traineddata:%s!\n", argv[2]); - exit(1); + return EXIT_FAILURE; } } else if (argc == 3 && strcmp(argv[1], "-d") == 0) { // Initialize TessdataManager with the data in the given traineddata file. @@ -186,4 +197,5 @@ int main(int argc, char **argv) { return 1; } tm.Directory(); + return EXIT_SUCCESS; }