lstmtraining: Check write permission for output model

This is done by creating a temporary file.
Report an error and terminate if that fails.

Use also EXIT_SUCCESS and EXIT_FAILURE for the return values of main().

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2018-10-05 20:38:02 +02:00
parent 660dbaa9d5
commit 7434590b9a

View File

@ -76,11 +76,11 @@ int main(int argc, char **argv) {
// Purify the model name in case it is based on the network string.
if (FLAGS_model_output.empty()) {
tprintf("Must provide a --model_output!\n");
return 1;
return EXIT_FAILURE;
}
if (FLAGS_traineddata.empty()) {
tprintf("Must provide a --traineddata see training wiki\n");
return 1;
return EXIT_FAILURE;
}
STRING model_output = FLAGS_model_output.c_str();
for (int i = 0; i < model_output.length(); ++i) {
@ -89,6 +89,19 @@ int main(int argc, char **argv) {
if (model_output[i] == '(' || model_output[i] == ')')
model_output[i] = '_';
}
// Check write permissions.
STRING test_file = FLAGS_model_output.c_str();
test_file += "_wtest";
FILE* f = fopen(test_file.c_str(), "wb");
if (f != nullptr) {
fclose(f);
remove(test_file.c_str());
} else {
tprintf("Error, model output cannot be written: %s\n", strerror(errno));
return EXIT_FAILURE;
}
// Setup the trainer.
STRING checkpoint_file = FLAGS_model_output.c_str();
checkpoint_file += "_checkpoint";
@ -105,7 +118,7 @@ int main(int argc, char **argv) {
if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
tprintf("Failed to read continue from: %s\n",
FLAGS_continue_from.c_str());
return 1;
return EXIT_FAILURE;
}
if (FLAGS_debug_network) {
trainer.DebugNetwork();
@ -116,20 +129,20 @@ int main(int argc, char **argv) {
FLAGS_model_output.c_str());
}
}
return 0;
return EXIT_SUCCESS;
}
// Get the list of files to process.
if (FLAGS_train_listfile.empty()) {
tprintf("Must supply a list of training filenames! --train_listfile\n");
return 1;
return EXIT_FAILURE;
}
GenericVector<STRING> filenames;
if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
&filenames)) {
tprintf("Failed to load list of training filenames from %s\n",
FLAGS_train_listfile.c_str());
return 1;
return EXIT_FAILURE;
}
// Checkpoints always take priority if they are available.
@ -145,7 +158,7 @@ int main(int argc, char **argv) {
? FLAGS_continue_from.c_str()
: FLAGS_old_traineddata.c_str())) {
tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
return 1;
return EXIT_FAILURE;
}
tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
trainer.InitIterations();
@ -155,7 +168,7 @@ int main(int argc, char **argv) {
tprintf("Appending a new network to an old one!!");
if (FLAGS_continue_from.empty()) {
tprintf("Must set --continue_from for appending!\n");
return 1;
return EXIT_FAILURE;
}
}
// We are initializing from scratch.
@ -165,7 +178,7 @@ int main(int argc, char **argv) {
FLAGS_adam_beta)) {
tprintf("Failed to create network from spec: %s\n",
FLAGS_net_spec.c_str());
return 1;
return EXIT_FAILURE;
}
trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
}
@ -176,7 +189,7 @@ int main(int argc, char **argv) {
: tesseract::CS_ROUND_ROBIN,
FLAGS_randomly_rotate)) {
tprintf("Load of images failed!!\n");
return 1;
return EXIT_FAILURE;
}
tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) *
@ -186,7 +199,7 @@ int main(int argc, char **argv) {
if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
tprintf("Failed to load eval data from: %s\n",
FLAGS_eval_listfile.c_str());
return 1;
return EXIT_FAILURE;
}
tester_callback =
NewPermanentTessCallback(&tester, &tesseract::LSTMTester::RunEvalAsync);
@ -208,5 +221,5 @@ int main(int argc, char **argv) {
FLAGS_max_iterations == 0));
delete tester_callback;
tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
return 0;
return EXIT_SUCCESS;
} /* main */