mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Merge pull request #16955 from themechanicalcoder:text_recognition
* add text recognition sample * fix pylint warning * made changes according to the c++ example * fix errors * add text recognition sample * update text detection sample
This commit is contained in:
parent
0fb3b8db72
commit
1b336bb602
@ -1,25 +1,81 @@
|
||||
'''
|
||||
Text detection model: https://github.com/argman/EAST
|
||||
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
|
||||
Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
|
||||
How to convert from pb to onnx:
|
||||
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
|
||||
import torch
|
||||
import models.crnn as CRNN
|
||||
model = CRNN(32, 1, 37, 256)
|
||||
model.load_state_dict(torch.load('crnn.pth'))
|
||||
dummy_input = torch.randn(1, 1, 32, 100)
|
||||
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
|
||||
'''
|
||||
|
||||
|
||||
# Import required modules
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import math
|
||||
import argparse
|
||||
|
||||
############ Add argument parser for command line arguments ############
|
||||
parser = argparse.ArgumentParser(description='Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)')
|
||||
parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to a binary .pb file of model contains trained weights.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
|
||||
"EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)"
|
||||
"The OCR model can be obtained from converting the pretrained CRNN model to .onnx format from the github repository https://github.com/meijieru/crnn.pytorch")
|
||||
parser.add_argument('--input',
|
||||
help='Path to input image or video file. Skip this argument to capture frames from a camera.')
|
||||
parser.add_argument('--model', '-m', required=True,
|
||||
help='Path to a binary .pb file contains trained detector network.')
|
||||
parser.add_argument('--ocr', default="crnn.onnx",
|
||||
help="Path to a binary .pb or .onnx file contains trained recognition network", )
|
||||
parser.add_argument('--width', type=int, default=320,
|
||||
help='Preprocess input image by resizing to a specific width. It should be multiple by 32.')
|
||||
parser.add_argument('--height',type=int, default=320,
|
||||
parser.add_argument('--height', type=int, default=320,
|
||||
help='Preprocess input image by resizing to a specific height. It should be multiple by 32.')
|
||||
parser.add_argument('--thr',type=float, default=0.5,
|
||||
parser.add_argument('--thr', type=float, default=0.5,
|
||||
help='Confidence threshold.')
|
||||
parser.add_argument('--nms',type=float, default=0.4,
|
||||
parser.add_argument('--nms', type=float, default=0.4,
|
||||
help='Non-maximum suppression threshold.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
############ Utility functions ############
|
||||
def decode(scores, geometry, scoreThresh):
|
||||
|
||||
def fourPointsTransform(frame, vertices):
|
||||
vertices = np.asarray(vertices)
|
||||
outputSize = (100, 32)
|
||||
targetVertices = np.array([
|
||||
[0, outputSize[1] - 1],
|
||||
[0, 0],
|
||||
[outputSize[0] - 1, 0],
|
||||
[outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")
|
||||
|
||||
rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
|
||||
result = cv.warpPerspective(frame, rotationMatrix, outputSize)
|
||||
return result
|
||||
|
||||
|
||||
def decodeText(scores):
|
||||
text = ""
|
||||
alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
for i in range(scores.shape[0]):
|
||||
c = np.argmax(scores[i][0])
|
||||
if c != 0:
|
||||
text += alphabet[c - 1]
|
||||
else:
|
||||
text += '-'
|
||||
|
||||
# adjacent same letters as well as background text must be removed to get the final output
|
||||
char_list = []
|
||||
for i in range(len(text)):
|
||||
if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
|
||||
char_list.append(text[i])
|
||||
return ''.join(char_list)
|
||||
|
||||
|
||||
def decodeBoundingBoxes(scores, geometry, scoreThresh):
|
||||
detections = []
|
||||
confidences = []
|
||||
|
||||
@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
|
||||
score = scoresData[x]
|
||||
|
||||
# If score is lower than threshold score, move to next x
|
||||
if(score < scoreThresh):
|
||||
if (score < scoreThresh):
|
||||
continue
|
||||
|
||||
# Calculate offset
|
||||
@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
|
||||
|
||||
# Find points for rectangle
|
||||
p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
|
||||
p3 = (-cosA * w + offset[0], sinA * w + offset[1])
|
||||
center = (0.5*(p1[0]+p3[0]), 0.5*(p1[1]+p3[1]))
|
||||
detections.append((center, (w,h), -1*angle * 180.0 / math.pi))
|
||||
p3 = (-cosA * w + offset[0], sinA * w + offset[1])
|
||||
center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
|
||||
detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
|
||||
confidences.append(float(score))
|
||||
|
||||
# Return detections and confidences
|
||||
return [detections, confidences]
|
||||
|
||||
|
||||
def main():
|
||||
# Read and store arguments
|
||||
confThreshold = args.thr
|
||||
nmsThreshold = args.nms
|
||||
inpWidth = args.width
|
||||
inpHeight = args.height
|
||||
model = args.model
|
||||
modelDetector = args.model
|
||||
modelRecognition = args.ocr
|
||||
|
||||
# Load network
|
||||
net = cv.dnn.readNet(model)
|
||||
detector = cv.dnn.readNet(modelDetector)
|
||||
recognizer = cv.dnn.readNet(modelRecognition)
|
||||
|
||||
# Create a new named window
|
||||
kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
|
||||
@ -95,6 +154,7 @@ def main():
|
||||
# Open a video file or an image file or a camera stream
|
||||
cap = cv.VideoCapture(args.input if args.input else 0)
|
||||
|
||||
tickmeter = cv.TickMeter()
|
||||
while cv.waitKey(1) < 0:
|
||||
# Read frame
|
||||
hasFrame, frame = cap.read()
|
||||
@ -111,19 +171,20 @@ def main():
|
||||
# Create a 4D blob from frame.
|
||||
blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
|
||||
|
||||
# Run the model
|
||||
net.setInput(blob)
|
||||
outs = net.forward(outNames)
|
||||
t, _ = net.getPerfProfile()
|
||||
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
|
||||
# Run the detection model
|
||||
detector.setInput(blob)
|
||||
|
||||
tickmeter.start()
|
||||
outs = detector.forward(outNames)
|
||||
tickmeter.stop()
|
||||
|
||||
# Get scores and geometry
|
||||
scores = outs[0]
|
||||
geometry = outs[1]
|
||||
[boxes, confidences] = decode(scores, geometry, confThreshold)
|
||||
[boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)
|
||||
|
||||
# Apply NMS
|
||||
indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold,nmsThreshold)
|
||||
indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
|
||||
for i in indices:
|
||||
# get 4 corners of the rotated rect
|
||||
vertices = cv.boxPoints(boxes[i[0]])
|
||||
@ -131,16 +192,40 @@ def main():
|
||||
for j in range(4):
|
||||
vertices[j][0] *= rW
|
||||
vertices[j][1] *= rH
|
||||
|
||||
|
||||
# get cropped image using perspective transform
|
||||
if modelRecognition:
|
||||
cropped = fourPointsTransform(frame, vertices)
|
||||
cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)
|
||||
|
||||
# Create a 4D blob from cropped image
|
||||
blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
|
||||
recognizer.setInput(blob)
|
||||
|
||||
# Run the recognition model
|
||||
tickmeter.start()
|
||||
result = recognizer.forward()
|
||||
tickmeter.stop()
|
||||
|
||||
# decode the result into text
|
||||
wordRecognized = decodeText(result)
|
||||
cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
|
||||
0.5, (255, 0, 0))
|
||||
|
||||
for j in range(4):
|
||||
p1 = (vertices[j][0], vertices[j][1])
|
||||
p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
|
||||
cv.line(frame, p1, p2, (0, 255, 0), 1)
|
||||
|
||||
# Put efficiency information
|
||||
label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
|
||||
cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
|
||||
|
||||
# Display the frame
|
||||
cv.imshow(kWinName,frame)
|
||||
cv.imshow(kWinName, frame)
|
||||
tickmeter.reset()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user