mirror of
https://github.com/opencv/opencv.git
synced 2025-06-28 07:23:30 +08:00
Add text recognition example
This commit is contained in:
parent
f6b2b49e4a
commit
00b19d6fba
@ -1,3 +1,20 @@
|
|||||||
|
/*
|
||||||
|
Text detection model: https://github.com/argman/EAST
|
||||||
|
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
|
||||||
|
|
||||||
|
Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
|
||||||
|
How to convert from pb to onnx:
|
||||||
|
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import models.crnn as crnn
|
||||||
|
|
||||||
|
model = CRNN(32, 1, 37, 256)
|
||||||
|
model.load_state_dict(torch.load('crnn.pth'))
|
||||||
|
dummy_input = torch.randn(1, 1, 32, 100)
|
||||||
|
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
|
||||||
|
*/
|
||||||
|
|
||||||
#include <opencv2/imgproc.hpp>
|
#include <opencv2/imgproc.hpp>
|
||||||
#include <opencv2/highgui.hpp>
|
#include <opencv2/highgui.hpp>
|
||||||
#include <opencv2/dnn.hpp>
|
#include <opencv2/dnn.hpp>
|
||||||
@ -8,21 +25,26 @@ using namespace cv::dnn;
|
|||||||
const char* keys =
|
const char* keys =
|
||||||
"{ help h | | Print help message. }"
|
"{ help h | | Print help message. }"
|
||||||
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
|
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
|
||||||
"{ model m | | Path to a binary .pb file contains trained network.}"
|
"{ model m | | Path to a binary .pb file contains trained detector network.}"
|
||||||
|
"{ ocr | | Path to a binary .pb or .onnx file contains trained recognition network.}"
|
||||||
"{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
|
"{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
|
||||||
"{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
|
"{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
|
||||||
"{ thr | 0.5 | Confidence threshold. }"
|
"{ thr | 0.5 | Confidence threshold. }"
|
||||||
"{ nms | 0.4 | Non-maximum suppression threshold. }";
|
"{ nms | 0.4 | Non-maximum suppression threshold. }";
|
||||||
|
|
||||||
void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
|
void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
|
||||||
std::vector<RotatedRect>& detections, std::vector<float>& confidences);
|
std::vector<RotatedRect>& detections, std::vector<float>& confidences);
|
||||||
|
|
||||||
|
void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result);
|
||||||
|
|
||||||
|
void decodeText(const Mat& scores, std::string& text);
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
// Parse command line arguments.
|
// Parse command line arguments.
|
||||||
CommandLineParser parser(argc, argv, keys);
|
CommandLineParser parser(argc, argv, keys);
|
||||||
parser.about("Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
|
parser.about("Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
|
||||||
"EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
|
"EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
|
||||||
if (argc == 1 || parser.has("help"))
|
if (argc == 1 || parser.has("help"))
|
||||||
{
|
{
|
||||||
parser.printMessage();
|
parser.printMessage();
|
||||||
@ -33,7 +55,8 @@ int main(int argc, char** argv)
|
|||||||
float nmsThreshold = parser.get<float>("nms");
|
float nmsThreshold = parser.get<float>("nms");
|
||||||
int inpWidth = parser.get<int>("width");
|
int inpWidth = parser.get<int>("width");
|
||||||
int inpHeight = parser.get<int>("height");
|
int inpHeight = parser.get<int>("height");
|
||||||
String model = parser.get<String>("model");
|
String modelDecoder = parser.get<String>("model");
|
||||||
|
String modelRecognition = parser.get<String>("ocr");
|
||||||
|
|
||||||
if (!parser.check())
|
if (!parser.check())
|
||||||
{
|
{
|
||||||
@ -41,17 +64,19 @@ int main(int argc, char** argv)
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
CV_Assert(!model.empty());
|
CV_Assert(!modelDecoder.empty());
|
||||||
|
|
||||||
// Load network.
|
// Load networks.
|
||||||
Net net = readNet(model);
|
Net detector = readNet(modelDecoder);
|
||||||
|
Net recognizer;
|
||||||
|
|
||||||
|
if (!modelRecognition.empty())
|
||||||
|
recognizer = readNet(modelRecognition);
|
||||||
|
|
||||||
// Open a video file or an image file or a camera stream.
|
// Open a video file or an image file or a camera stream.
|
||||||
VideoCapture cap;
|
VideoCapture cap;
|
||||||
if (parser.has("input"))
|
bool openSuccess = parser.has("input") ? cap.open(parser.get<String>("input")) : cap.open(0);
|
||||||
cap.open(parser.get<String>("input"));
|
CV_Assert(openSuccess);
|
||||||
else
|
|
||||||
cap.open(0);
|
|
||||||
|
|
||||||
static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
|
static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
|
||||||
namedWindow(kWinName, WINDOW_NORMAL);
|
namedWindow(kWinName, WINDOW_NORMAL);
|
||||||
@ -62,6 +87,7 @@ int main(int argc, char** argv)
|
|||||||
outNames[1] = "feature_fusion/concat_3";
|
outNames[1] = "feature_fusion/concat_3";
|
||||||
|
|
||||||
Mat frame, blob;
|
Mat frame, blob;
|
||||||
|
TickMeter tickMeter;
|
||||||
while (waitKey(1) < 0)
|
while (waitKey(1) < 0)
|
||||||
{
|
{
|
||||||
cap >> frame;
|
cap >> frame;
|
||||||
@ -72,8 +98,10 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
|
|
||||||
blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
|
blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
|
||||||
net.setInput(blob);
|
detector.setInput(blob);
|
||||||
net.forward(outs, outNames);
|
tickMeter.start();
|
||||||
|
detector.forward(outs, outNames);
|
||||||
|
tickMeter.stop();
|
||||||
|
|
||||||
Mat scores = outs[0];
|
Mat scores = outs[0];
|
||||||
Mat geometry = outs[1];
|
Mat geometry = outs[1];
|
||||||
@ -81,43 +109,64 @@ int main(int argc, char** argv)
|
|||||||
// Decode predicted bounding boxes.
|
// Decode predicted bounding boxes.
|
||||||
std::vector<RotatedRect> boxes;
|
std::vector<RotatedRect> boxes;
|
||||||
std::vector<float> confidences;
|
std::vector<float> confidences;
|
||||||
decode(scores, geometry, confThreshold, boxes, confidences);
|
decodeBoundingBoxes(scores, geometry, confThreshold, boxes, confidences);
|
||||||
|
|
||||||
// Apply non-maximum suppression procedure.
|
// Apply non-maximum suppression procedure.
|
||||||
std::vector<int> indices;
|
std::vector<int> indices;
|
||||||
NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
|
NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
|
||||||
|
|
||||||
// Render detections.
|
|
||||||
Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
|
Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
|
||||||
|
|
||||||
|
// Render text.
|
||||||
for (size_t i = 0; i < indices.size(); ++i)
|
for (size_t i = 0; i < indices.size(); ++i)
|
||||||
{
|
{
|
||||||
RotatedRect& box = boxes[indices[i]];
|
RotatedRect& box = boxes[indices[i]];
|
||||||
|
|
||||||
Point2f vertices[4];
|
Point2f vertices[4];
|
||||||
box.points(vertices);
|
box.points(vertices);
|
||||||
|
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
{
|
{
|
||||||
vertices[j].x *= ratio.x;
|
vertices[j].x *= ratio.x;
|
||||||
vertices[j].y *= ratio.y;
|
vertices[j].y *= ratio.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!modelRecognition.empty())
|
||||||
|
{
|
||||||
|
Mat cropped;
|
||||||
|
fourPointsTransform(frame, vertices, cropped);
|
||||||
|
|
||||||
|
cvtColor(cropped, cropped, cv::COLOR_BGR2GRAY);
|
||||||
|
|
||||||
|
Mat blobCrop = blobFromImage(cropped, 1.0/127.5, Size(), Scalar::all(127.5));
|
||||||
|
recognizer.setInput(blobCrop);
|
||||||
|
|
||||||
|
tickMeter.start();
|
||||||
|
Mat result = recognizer.forward();
|
||||||
|
tickMeter.stop();
|
||||||
|
|
||||||
|
std::string wordRecognized = "";
|
||||||
|
decodeText(result, wordRecognized);
|
||||||
|
putText(frame, wordRecognized, vertices[1], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255));
|
||||||
|
}
|
||||||
|
|
||||||
for (int j = 0; j < 4; ++j)
|
for (int j = 0; j < 4; ++j)
|
||||||
line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1);
|
line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put efficiency information.
|
// Put efficiency information.
|
||||||
std::vector<double> layersTimes;
|
std::string label = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
|
||||||
double freq = getTickFrequency() / 1000;
|
|
||||||
double t = net.getPerfProfile(layersTimes) / freq;
|
|
||||||
std::string label = format("Inference time: %.2f ms", t);
|
|
||||||
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
|
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
|
||||||
|
|
||||||
imshow(kWinName, frame);
|
imshow(kWinName, frame);
|
||||||
|
|
||||||
|
tickMeter.reset();
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
|
void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
|
||||||
std::vector<RotatedRect>& detections, std::vector<float>& confidences)
|
std::vector<RotatedRect>& detections, std::vector<float>& confidences)
|
||||||
{
|
{
|
||||||
detections.clear();
|
detections.clear();
|
||||||
CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1);
|
CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1);
|
||||||
@ -159,3 +208,51 @@ void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result)
|
||||||
|
{
|
||||||
|
const Size outputSize = Size(100, 32);
|
||||||
|
|
||||||
|
Point2f targetVertices[4] = {Point(0, outputSize.height - 1),
|
||||||
|
Point(0, 0), Point(outputSize.width - 1, 0),
|
||||||
|
Point(outputSize.width - 1, outputSize.height - 1),
|
||||||
|
};
|
||||||
|
Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
|
||||||
|
|
||||||
|
warpPerspective(frame, result, rotationMatrix, outputSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
void decodeText(const Mat& scores, std::string& text)
|
||||||
|
{
|
||||||
|
static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz";
|
||||||
|
Mat scoresMat = scores.reshape(1, scores.size[0]);
|
||||||
|
|
||||||
|
std::vector<char> elements;
|
||||||
|
elements.reserve(scores.size[0]);
|
||||||
|
|
||||||
|
for (int rowIndex = 0; rowIndex < scoresMat.rows; ++rowIndex)
|
||||||
|
{
|
||||||
|
Point p;
|
||||||
|
minMaxLoc(scoresMat.row(rowIndex), 0, 0, 0, &p);
|
||||||
|
if (p.x > 0 && static_cast<size_t>(p.x) <= alphabet.size())
|
||||||
|
{
|
||||||
|
elements.push_back(alphabet[p.x - 1]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
elements.push_back('-');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (elements.size() > 0 && elements[0] != '-')
|
||||||
|
text += elements[0];
|
||||||
|
|
||||||
|
for (size_t elementIndex = 1; elementIndex < elements.size(); ++elementIndex)
|
||||||
|
{
|
||||||
|
if (elementIndex > 0 && elements[elementIndex] != '-' &&
|
||||||
|
elements[elementIndex - 1] != elements[elementIndex])
|
||||||
|
{
|
||||||
|
text += elements[elementIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user