From 00b19d6fbac7080df7b08f1fd4bbd601694e1cc6 Mon Sep 17 00:00:00 2001
From: Aleksandr Pertovskiy
Date: Wed, 6 May 2020 15:23:55 +0300
Subject: [PATCH] Add text recognition example

---
 samples/dnn/text_detection.cpp | 141 ++++++++++++++++++++++++++++-----
 1 file changed, 119 insertions(+), 22 deletions(-)

diff --git a/samples/dnn/text_detection.cpp b/samples/dnn/text_detection.cpp
index e7b0f237d3..706e2fe58b 100644
--- a/samples/dnn/text_detection.cpp
+++ b/samples/dnn/text_detection.cpp
@@ -1,3 +1,20 @@
+/*
+    Text detection model: https://github.com/argman/EAST
+    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
+
+    Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
+    How to convert from .pth to .onnx:
+    Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
+
+    import torch
+    import models.crnn as crnn
+
+    model = crnn.CRNN(32, 1, 37, 256)
+    model.load_state_dict(torch.load('crnn.pth'))
+    dummy_input = torch.randn(1, 1, 32, 100)
+    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
+*/
+
 #include <opencv2/imgproc.hpp>
 #include <opencv2/highgui.hpp>
 #include <opencv2/dnn.hpp>
@@ -8,21 +25,26 @@ using namespace cv::dnn;
 const char* keys =
     "{ help  h | | Print help message. }"
     "{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
-    "{ model m | | Path to a binary .pb file contains trained network.}"
+    "{ model m | | Path to a binary .pb file that contains the trained detector network.}"
+    "{ ocr     | | Path to a binary .pb or .onnx file that contains the trained recognition network.}"
     "{ width   | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
     "{ height  | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
     "{ thr     | 0.5 | Confidence threshold. }"
     "{ nms     | 0.4 | Non-maximum suppression threshold. }";

-void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
-            std::vector<RotatedRect>& detections, std::vector<float>& confidences);
+void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
+                         std::vector<RotatedRect>& detections, std::vector<float>& confidences);
+
+void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result);
+
+void decodeText(const Mat& scores, std::string& text);

 int main(int argc, char** argv)
 {
     // Parse command line arguments.
     CommandLineParser parser(argc, argv, keys);
     parser.about("Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
-                 "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
+                 "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
     if (argc == 1 || parser.has("help"))
     {
         parser.printMessage();
@@ -33,7 +55,8 @@ int main(int argc, char** argv)
     float nmsThreshold = parser.get<float>("nms");
     int inpWidth = parser.get<int>("width");
     int inpHeight = parser.get<int>("height");
-    String model = parser.get<String>("model");
+    String modelDecoder = parser.get<String>("model");
+    String modelRecognition = parser.get<String>("ocr");

     if (!parser.check())
     {
@@ -41,17 +64,19 @@ int main(int argc, char** argv)
         return 1;
     }

-    CV_Assert(!model.empty());
+    CV_Assert(!modelDecoder.empty());

-    // Load network.
-    Net net = readNet(model);
+    // Load networks.
+    Net detector = readNet(modelDecoder);
+    Net recognizer;
+
+    if (!modelRecognition.empty())
+        recognizer = readNet(modelRecognition);

     // Open a video file or an image file or a camera stream.
     VideoCapture cap;
-    if (parser.has("input"))
-        cap.open(parser.get<String>("input"));
-    else
-        cap.open(0);
+    bool openSuccess = parser.has("input") ? cap.open(parser.get<String>("input")) : cap.open(0);
+    CV_Assert(openSuccess);

     static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
     namedWindow(kWinName, WINDOW_NORMAL);
@@ -62,6 +87,7 @@ int main(int argc, char** argv)
     outNames[1] = "feature_fusion/concat_3";

     Mat frame, blob;
+    TickMeter tickMeter;
     while (waitKey(1) < 0)
     {
         cap >> frame;
@@ -72,8 +98,10 @@ int main(int argc, char** argv)
         }

         blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
-        net.setInput(blob);
-        net.forward(outs, outNames);
+        detector.setInput(blob);
+        tickMeter.start();
+        detector.forward(outs, outNames);
+        tickMeter.stop();

         Mat scores = outs[0];
         Mat geometry = outs[1];
@@ -81,43 +109,64 @@ int main(int argc, char** argv)
         // Decode predicted bounding boxes.
         std::vector<RotatedRect> boxes;
         std::vector<float> confidences;
-        decode(scores, geometry, confThreshold, boxes, confidences);
+        decodeBoundingBoxes(scores, geometry, confThreshold, boxes, confidences);

         // Apply non-maximum suppression procedure.
         std::vector<int> indices;
         NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);

-        // Render detections.
         Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
+
+        // Render text.
         for (size_t i = 0; i < indices.size(); ++i)
         {
             RotatedRect& box = boxes[indices[i]];
             Point2f vertices[4];
             box.points(vertices);
+
             for (int j = 0; j < 4; ++j)
             {
                 vertices[j].x *= ratio.x;
                 vertices[j].y *= ratio.y;
             }
+
+            if (!modelRecognition.empty())
+            {
+                Mat cropped;
+                fourPointsTransform(frame, vertices, cropped);
+
+                cvtColor(cropped, cropped, cv::COLOR_BGR2GRAY);
+
+                Mat blobCrop = blobFromImage(cropped, 1.0 / 127.5, Size(), Scalar::all(127.5));
+                recognizer.setInput(blobCrop);
+
+                tickMeter.start();
+                Mat result = recognizer.forward();
+                tickMeter.stop();
+
+                std::string wordRecognized = "";
+                decodeText(result, wordRecognized);
+                putText(frame, wordRecognized, vertices[1], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255));
+            }
+
             for (int j = 0; j < 4; ++j)
                 line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1);
         }

         // Put efficiency information.
-        std::vector<double> layersTimes;
-        double freq = getTickFrequency() / 1000;
-        double t = net.getPerfProfile(layersTimes) / freq;
-        std::string label = format("Inference time: %.2f ms", t);
+        std::string label = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
         putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));

         imshow(kWinName, frame);
+
+        tickMeter.reset();
     }
     return 0;
 }

-void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
-            std::vector<RotatedRect>& detections, std::vector<float>& confidences)
+void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
+                         std::vector<RotatedRect>& detections, std::vector<float>& confidences)
 {
     detections.clear();
     CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1);
@@ -159,3 +208,51 @@ void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
         }
     }
 }
+
+void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result)
+{
+    const Size outputSize = Size(100, 32);
+
+    Point2f targetVertices[4] = {
+        Point(0, outputSize.height - 1),
+        Point(0, 0),
+        Point(outputSize.width - 1, 0),
+        Point(outputSize.width - 1, outputSize.height - 1)
+    };
+    Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
+
+    warpPerspective(frame, result, rotationMatrix, outputSize);
+}
+
+void decodeText(const Mat& scores, std::string& text)
+{
+    static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz";
+    Mat scoresMat = scores.reshape(1, scores.size[0]);
+
+    std::vector<char> elements;
+    elements.reserve(scores.size[0]);
+
+    for (int rowIndex = 0; rowIndex < scoresMat.rows; ++rowIndex)
+    {
+        Point p;
+        minMaxLoc(scoresMat.row(rowIndex), 0, 0, 0, &p);
+        if (p.x > 0 && static_cast<size_t>(p.x) <= alphabet.size())
+        {
+            elements.push_back(alphabet[p.x - 1]);
+        }
+        else
+        {
+            elements.push_back('-');
+        }
+    }
+
+    if (elements.size() > 0 && elements[0] != '-')
+        text += elements[0];
+
+    for (size_t elementIndex = 1; elementIndex < elements.size(); ++elementIndex)
+    {
+        if (elements[elementIndex] != '-' &&
+            elements[elementIndex - 1] != elements[elementIndex])
+        {
+            text += elements[elementIndex];
+        }
+    }
+}
\ No newline at end of file
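
Reviewer note (not part of the patch): two details of the new code are worth spelling out.

First, fourPointsTransform() depends on vertex ordering: RotatedRect::points() fills the array as bottomLeft, topLeft, topRight, bottomRight, and targetVertices lists the corners of the 100x32 output rectangle in that same order, so getPerspectiveTransform() maps each detected corner to the matching upright corner.

Second, decodeText() is a greedy CTC-style decoder: at each time step it takes the argmax over the 37 classes, treats class 0 as the CTC blank ('-'), and emits an alphabet character only when it is not blank and differs from the previous step, collapsing repeated predictions. Below is a minimal standalone sketch of that logic; greedyCtcDecode and the index sequence are illustrative, not part of the sample:

    #include <iostream>
    #include <string>
    #include <vector>

    // Same convention as decodeText(): index 0 is the CTC blank,
    // indices 1..36 map into the 36-character alphabet.
    static std::string greedyCtcDecode(const std::vector<int>& indices)
    {
        static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz";
        std::string text;
        int prev = 0; // start from the blank
        for (size_t i = 0; i < indices.size(); ++i)
        {
            int idx = indices[i];
            // Emit only non-blank symbols that differ from the previous step.
            if (idx != 0 && idx != prev && static_cast<size_t>(idx) <= alphabet.size())
                text += alphabet[idx - 1];
            prev = idx;
        }
        return text;
    }

    int main()
    {
        // Per-step argmax indices: 'h'=18, 'e'=15, 'l'=22, 'o'=25, blank=0.
        // The doubled 18 collapses to one 'h'; the blank between the two 22s
        // keeps both 'l's, so this prints "hello".
        std::vector<int> steps = {18, 18, 0, 15, 0, 22, 0, 22, 25};
        std::cout << greedyCtcDecode(steps) << std::endl;
        return 0;
    }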
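
For reference, a typical invocation after downloading the EAST model and exporting crnn.onnx as described in the header comment might be ./text_detection --model=frozen_east_text_detection.pb --ocr=crnn.onnx --input=image.jpg (the binary name depends on how the samples are built, and the file paths are placeholders). Without --ocr the sample still runs detection-only, since the recognition branch is guarded by !modelRecognition.empty().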