lstmrecognizer.cpp: Call OutputStats() only when 'invert' is true (#3387)

Co-authored-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Amit D 2021-04-08 18:55:23 +03:00 committed by GitHub
parent e6ce048426
commit a4a84c4c92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -343,33 +343,36 @@ bool LSTMRecognizer::RecognizeLine(const ImageData &image_data, bool invert, boo
Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs);
network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
// Check for auto inversion.
float pos_min, pos_mean, pos_sd;
OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
if (invert && pos_mean < 0.5) {
// Run again inverted and see if it is any better.
NetworkIO inv_inputs, inv_outputs;
inv_inputs.set_int_mode(IsIntMode());
SetRandomSeed();
pixInvert(pix, pix);
Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs);
network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs);
float inv_min, inv_mean, inv_sd;
OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
if (inv_mean > pos_mean) {
// Inverted did better. Use inverted data.
if (debug) {
tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
pos_sd, inv_min, inv_mean, inv_sd);
}
*outputs = inv_outputs;
*inputs = inv_inputs;
} else if (re_invert) {
// Inverting was not an improvement, so undo and run again, so the
// outputs match the best forward result.
if (invert) {
float pos_min, pos_mean, pos_sd;
OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
if (pos_mean < 0.5f) {
// Run again inverted and see if it is any better.
NetworkIO inv_inputs, inv_outputs;
inv_inputs.set_int_mode(IsIntMode());
SetRandomSeed();
network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
pixInvert(pix, pix);
Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs);
network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs);
float inv_min, inv_mean, inv_sd;
OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
if (inv_mean > pos_mean) {
// Inverted did better. Use inverted data.
if (debug) {
tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
pos_sd, inv_min, inv_mean, inv_sd);
}
*outputs = inv_outputs;
*inputs = inv_inputs;
} else if (re_invert) {
// Inverting was not an improvement, so undo and run again, so the
// outputs match the best forward result.
SetRandomSeed();
network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
}
}
}
pix.destroy();
if (debug) {
std::vector<int> labels, coords;