'''
    Text detection model: https://github.com/argman/EAST
    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1

    CRNN text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
    How to convert the PyTorch (.pth) model to ONNX:
    Use the model classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py

    More converted ONNX text recognition models can be downloaded directly here:
    Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
    These models were taken from here: https://github.com/clovaai/deep-text-recognition-benchmark

    import torch
    from models.crnn import CRNN

    model = CRNN(32, 1, 37, 256)
    model.load_state_dict(torch.load('crnn.pth'))
    dummy_input = torch.randn(1, 1, 32, 100)
    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
'''
2018-12-18 18:40:04 +08:00
# Import required modules
2020-06-10 14:53:18 +08:00
import numpy as np
2018-12-18 18:40:04 +08:00
import cv2 as cv
import math
import argparse
############ Add argument parser for command line arguments ############
# Builds the CLI for the demo and parses it into the module-level `args`
# namespace that `main` reads.
parser = argparse.ArgumentParser(
    description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained by converting the pretrained CRNN model to .onnx format "
                "from the github repository https://github.com/meijieru/crnn.pytorch "
                "or you can download a trained OCR model directly from "
                "https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing")
parser.add_argument('--input',
                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', '-m', required=True,
                    help='Path to a binary .pb file containing the trained detector network.')
parser.add_argument('--ocr', default="crnn.onnx",
                    help="Path to a binary .pb or .onnx file containing the trained recognition network", )
parser.add_argument('--width', type=int, default=320,
                    help='Preprocess input image by resizing to a specific width. It should be a multiple of 32.')
parser.add_argument('--height', type=int, default=320,
                    help='Preprocess input image by resizing to a specific height. It should be a multiple of 32.')
parser.add_argument('--thr', type=float, default=0.5,
                    help='Confidence threshold.')
parser.add_argument('--nms', type=float, default=0.4,
                    help='Non-maximum suppression threshold.')
args = parser.parse_args()

# The help text requires multiples of 32 (presumably the detector's total
# stride); reject other sizes up front with a clear CLI error instead of
# failing later inside the network.
if args.width <= 0 or args.width % 32 != 0:
    parser.error('--width must be a positive multiple of 32')
if args.height <= 0 or args.height % 32 != 0:
    parser.error('--height must be a positive multiple of 32')
2020-06-10 14:53:18 +08:00
2018-12-18 18:40:04 +08:00
############ Utility functions ############
2020-06-10 14:53:18 +08:00
def fourPointsTransform(frame, vertices, outputSize=(100, 32)):
    """Warp a quadrilateral region of ``frame`` into an axis-aligned image.

    Args:
        frame: source image to sample from.
        vertices: four (x, y) points. vertices[0..3] are mapped to the output's
            bottom-left, top-left, top-right and bottom-right corners.
        outputSize: (width, height) of the result. Defaults to (100, 32), the
            size this script feeds to the recognition network.

    Returns:
        The perspective-warped crop, of size ``outputSize``.
    """
    # getPerspectiveTransform only accepts float32 point sets; coerce here so
    # callers may pass lists or float64/int arrays, not just cv.boxPoints output.
    vertices = np.asarray(vertices, dtype=np.float32)
    width, height = outputSize
    targetVertices = np.array([
        [0, height - 1],
        [0, 0],
        [width - 1, 0],
        [width - 1, height - 1]], dtype="float32")
    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    return cv.warpPerspective(frame, rotationMatrix, (width, height))
def decodeText(scores, alphabet="0123456789abcdefghijklmnopqrstuvwxyz"):
    """Greedy (best-path) CTC decoding of a text-recognizer output.

    Args:
        scores: recognizer output indexed as ``scores[t][0]`` -> class scores
            for time step ``t`` (presumably shape (T, batch, num_classes);
            only batch item 0 is decoded).
        alphabet: characters for class indices 1..len(alphabet); class 0 is
            the CTC blank. Defaults to digits followed by lowercase letters.

    Returns:
        The decoded string, with repeated labels collapsed and blanks dropped.
    """
    chars = []
    previous = 0  # treat the position before t=0 as blank
    for t in range(scores.shape[0]):
        label = int(np.argmax(scores[t][0]))
        # Collapse runs of the same label and drop blanks. Comparing labels
        # (not rendered characters) keeps this correct even if the alphabet
        # itself contains the character '-'.
        if label != 0 and label != previous:
            chars.append(alphabet[label - 1])
        previous = label
    return ''.join(chars)
def decodeBoundingBoxes(scores, geometry, scoreThresh):
    """Decode EAST score/geometry maps into rotated-rectangle candidates.

    Args:
        scores: score map of shape (1, 1, H, W).
        geometry: geometry map of shape (1, 5, H, W). Channels 0-3 are per-cell
            edge distances combined as h = ch0 + ch2 and w = ch1 + ch3
            (presumably top/right/bottom/left as in EAST); channel 4 is the
            rotation angle in radians.
        scoreThresh: cells whose score is below this value are skipped.

    Returns:
        ``[detections, confidences]``. ``detections`` is a list of
        ``((cx, cy), (w, h), angle_degrees)`` tuples in input-blob pixel
        coordinates (each map cell spans 4 pixels per axis; the angle is the
        negated channel-4 angle in degrees). ``confidences`` holds the matching
        scores as Python floats. Both are in row-major (y, then x) cell order.

    Raises:
        ValueError: if the map shapes are not as described above.
    """
    scores = np.asarray(scores)
    geometry = np.asarray(geometry)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if scores.ndim != 4:
        raise ValueError("Incorrect dimensions of scores")
    if geometry.ndim != 4:
        raise ValueError("Incorrect dimensions of geometry")
    if scores.shape[0] != 1 or scores.shape[1] != 1:
        raise ValueError("Invalid dimensions of scores")
    if geometry.shape[0] != 1 or geometry.shape[1] != 5:
        raise ValueError("Invalid dimensions of geometry")
    if scores.shape[2:] != geometry.shape[2:]:
        raise ValueError("Invalid dimensions of scores and geometry")

    # np.nonzero returns indices in row-major order, i.e. the same order as a
    # nested "for y: for x:" scan. NaN scores fail the comparison and are dropped.
    ys, xs = np.nonzero(scores[0, 0] >= scoreThresh)
    confidences = scores[0, 0, ys, xs].astype(np.float64).tolist()

    # Gather the five geometry channels for the kept cells, as float64.
    d0, d1, d2, d3, angle = geometry[0][:, ys, xs].astype(np.float64)
    cosA = np.cos(angle)
    sinA = np.sin(angle)
    h = d0 + d2
    w = d1 + d3

    # Anchor point: the cell origin (feature map is 4x downsampled) moved by
    # the ch1/ch2 distances along the rotated axes.
    offsetX = xs * 4.0 + cosA * d1 + sinA * d2
    offsetY = ys * 4.0 - sinA * d1 + cosA * d2

    # Two opposite corners of the rotated rectangle; the centre is their midpoint.
    p1x = offsetX - sinA * h
    p1y = offsetY - cosA * h
    p3x = offsetX - cosA * w
    p3y = offsetY + sinA * w
    centerX = 0.5 * (p1x + p3x)
    centerY = 0.5 * (p1y + p3y)
    angleDeg = -angle * 180.0 / math.pi

    detections = [
        ((cx, cy), (bw, bh), a)
        for cx, cy, bw, bh, a in zip(centerX.tolist(), centerY.tolist(),
                                     w.tolist(), h.tolist(), angleDeg.tolist())
    ]
    return [detections, confidences]
2020-06-10 14:53:18 +08:00
2018-12-18 18:40:04 +08:00
def main():
    """Run EAST text detection, with optional CRNN recognition, on a stream.

    Settings come from the module-level ``args`` namespace. Each frame is
    turned into a detector blob and decoded into rotated boxes, which are
    filtered with rotated NMS. The frame is then annotated with the box
    outlines and, when a recognition model is given, the recognized word.
    Frames are shown in a window until a key is pressed. When the stream runs
    out, the last frame stays on screen until a key is pressed.
    """
    # Read and store arguments
    confThreshold = args.thr
    nmsThreshold = args.nms
    inpWidth = args.width
    inpHeight = args.height
    modelDetector = args.model
    modelRecognition = args.ocr

    # Load networks. Recognition is optional: an empty --ocr disables it, so
    # only call readNet when a recognition model path was actually given.
    detector = cv.dnn.readNet(modelDetector)
    recognizer = cv.dnn.readNet(modelRecognition) if modelRecognition else None

    # Create a new named window
    kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
    cv.namedWindow(kWinName, cv.WINDOW_NORMAL)
    # Detector output layers; forward() returns them in this order (scores, geometry).
    outNames = ["feature_fusion/Conv_7/Sigmoid", "feature_fusion/concat_3"]

    # Open a video file or an image file or a camera stream
    cap = cv.VideoCapture(args.input if args.input else 0)
    tickmeter = cv.TickMeter()
    try:
        while cv.waitKey(1) < 0:
            # Read frame
            hasFrame, frame = cap.read()
            if not hasFrame:
                cv.waitKey()
                break

            # Ratios mapping detector-blob coordinates back to the original frame.
            height_, width_ = frame.shape[:2]
            rW = width_ / float(inpWidth)
            rH = height_ / float(inpHeight)

            # Create a 4D blob from frame.
            blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)

            # Run the detection model
            detector.setInput(blob)
            tickmeter.start()
            outs = detector.forward(outNames)
            tickmeter.stop()

            # Get scores and geometry
            scores = outs[0]
            geometry = outs[1]
            [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)

            # Apply NMS
            indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)

            # Crop text regions from an unannotated copy: the loop below draws
            # outlines and labels onto `frame`, which would otherwise leak into
            # the crops of later boxes.
            cleanFrame = frame.copy() if recognizer is not None else None

            # Flatten the NMS result so an Nx1 array, a flat array, or an empty
            # tuple (the returned shape has varied across OpenCV versions) all
            # iterate as plain indices.
            for i in np.asarray(indices, dtype=int).reshape(-1):
                # get 4 corners of the rotated rect
                vertices = cv.boxPoints(boxes[int(i)])
                # scale the bounding box coordinates based on the respective ratios
                vertices[:, 0] *= rW
                vertices[:, 1] *= rH

                if recognizer is not None:
                    # get cropped image using perspective transform
                    cropped = fourPointsTransform(cleanFrame, vertices)
                    cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)
                    # Create a 4D blob from cropped image
                    blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
                    recognizer.setInput(blob)
                    # Run the recognition model
                    tickmeter.start()
                    result = recognizer.forward()
                    tickmeter.stop()
                    # decode the result into text
                    wordRecognized = decodeText(result)
                    cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])),
                               cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0))

                # Draw the box outline
                for j in range(4):
                    p1 = (int(vertices[j][0]), int(vertices[j][1]))
                    p2 = (int(vertices[(j + 1) % 4][0]), int(vertices[(j + 1) % 4][1]))
                    cv.line(frame, p1, p2, (0, 255, 0), 1)

            # Put efficiency information (detection + recognition time for this frame)
            label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
            cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

            # Display the frame
            cv.imshow(kWinName, frame)
            tickmeter.reset()
    finally:
        # Release the capture device/file even if processing raises.
        cap.release()
        cv.destroyAllWindows()
2018-12-18 18:40:04 +08:00
# Start the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()