# opencv/samples/dnn/text_detection.py
'''
Text detection model (EAST): https://github.com/argman/EAST
Download link for the EAST model: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1

DB detector model:
https://drive.google.com/uc?export=download&id=17_ABp79PlFt9yPCxSaarVc_DKTmrSGGf

CRNN text recognition model sourced from: https://github.com/meijieru/crnn.pytorch
How to convert the CRNN model from PyTorch (.pth) to ONNX, using the classes
from: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py

    import torch
    from models.crnn import CRNN

    model = CRNN(32, 1, 37, 256)
    model.load_state_dict(torch.load('crnn.pth'))
    dummy_input = torch.randn(1, 1, 32, 100)
    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)

Additional converted ONNX text recognition models are available for direct download:
https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
These models are taken from: https://github.com/clovaai/deep-text-recognition-benchmark

Usage: python text_detection.py DB --ocr_model=<path to recognition model>
'''
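# Optional sanity check (assumes the export above produced crnn.onnx): the
# converted recognition model should load directly with OpenCV's DNN module:
#   net = cv2.dnn.readNetFromONNX('crnn.onnx')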
import os
import cv2
import argparse
import numpy as np
from common import *
def help():
    print(
        '''
        Use this script for Text Detection and Recognition using OpenCV.

        First, download the required models using `download_models.py` (if not already done).
        Set the environment variable OPENCV_DOWNLOAD_CACHE_DIR to specify where the models
        should be downloaded, and point OPENCV_SAMPLES_DATA_PATH to opencv/samples/data.
        Example: python download_models.py East
                 python download_models.py OCR

        To run:
        Example: python text_detection.py modelName (i.e. DB or East)

        The detection model path can also be specified with the --model argument,
        and the recognition model with --ocr_model.
        '''
    )
############ Add argument parser for command line arguments ############
def get_args_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--input', default='right.jpg',
                        help='Path to input image or video file. Skip this argument to capture frames from a camera.')
    parser.add_argument('--zoo', default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models.yml'),
                        help='An optional path to file with preprocessing parameters.')
    parser.add_argument('--thr', type=float, default=0.5,
                        help='Confidence threshold.')
    parser.add_argument('--nms', type=float, default=0.4,
                        help='Non-maximum suppression threshold.')
    parser.add_argument('--binary_threshold', type=float, default=0.3,
                        help='Confidence threshold for the binary map in DB detector.')
    parser.add_argument('--polygon_threshold', type=float, default=0.5,
                        help='Confidence threshold for polygons in DB detector.')
    parser.add_argument('--max_candidate', type=int, default=200,
                        help='Max candidates for polygons in DB detector.')
    parser.add_argument('--unclip_ratio', type=float, default=2.0,
                        help='Unclip ratio for DB detector.')
    parser.add_argument('--vocabulary_path', default='alphabet_36.txt',
                        help='Path to vocabulary file.')

    # Parse what is known so far to learn the zoo path, then pull the
    # model-specific preprocessing arguments from models.yml before building
    # the final parser (with help enabled) on top.
    args, _ = parser.parse_known_args()
    add_preproc_args(args.zoo, parser, 'text_detection', prefix="")
    add_preproc_args(args.zoo, parser, 'text_recognition', prefix="ocr_")
    parser = argparse.ArgumentParser(parents=[parser],
                                     description='Text Detection and Recognition using OpenCV.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    return parser.parse_args()
def fourPointsTransform(frame, vertices):
    # Warp the quadrilateral text region to a fixed-size, axis-aligned patch
    # matching the recognizer's expected input size (100x32).
    vertices = np.asarray(vertices, dtype=np.float32)
    outputSize = (100, 32)
    targetVertices = np.array([
        [0, outputSize[1] - 1],
        [0, 0],
        [outputSize[0] - 1, 0],
        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")

    rotationMatrix = cv2.getPerspectiveTransform(vertices, targetVertices)
    result = cv2.warpPerspective(frame, rotationMatrix, outputSize)
    return result
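
# Illustrative call (assumes `quad` holds one 4x2 detection quadrangle):
#   quad = np.asarray(quad, dtype=np.float32)
#   word_img = fourPointsTransform(frame, quad)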
def main():
    args = get_args_parser()
    if args.alias is None or hasattr(args, 'help'):
        help()
        exit(1)

    if args.download_sha is not None:
        args.model = findModel(args.model, args.download_sha)
    else:
        args.model = findModel(args.model, args.sha1)
    args.ocr_model = findModel(args.ocr_model, args.ocr_sha1)
    args.input = findFile(args.input)
    args.vocabulary_path = findFile(args.vocabulary_path)

    frame = cv2.imread(args.input)
    if frame is None:
        print("[ERROR] Could not read the input image:", args.input)
        exit(1)
    # White board of the same size as the input, used to render recognized text
    board = np.ones_like(frame) * 255

    # Scale the overlay font relative to a 512-pixel reference image so the
    # text stays readable for any input resolution
    stdSize = 0.8
    stdWeight = 2
    stdImgSize = 512
    imgWidth = min(frame.shape[:2])
    fontSize = (stdSize * imgWidth) / stdImgSize
    fontThickness = max(1, (stdWeight * imgWidth) // stdImgSize)
    if args.alias == "DB":
        # DB detector initialization
        detector = cv2.dnn_TextDetectionModel_DB(args.model)
        detector.setBinaryThreshold(args.binary_threshold)
        detector.setPolygonThreshold(args.polygon_threshold)
        detector.setUnclipRatio(args.unclip_ratio)
        detector.setMaxCandidates(args.max_candidate)
        # Setting input parameters specific to the DB model
        detector.setInputParams(scale=args.scale, size=(args.width, args.height), mean=args.mean)
        # Performing text detection
        detResults = detector.detect(frame)
    elif args.alias == "East":
        # EAST detector initialization
        detector = cv2.dnn_TextDetectionModel_EAST(args.model)
        detector.setConfidenceThreshold(args.thr)
        detector.setNMSThreshold(args.nms)
        # Setting input parameters specific to the EAST model
        detector.setInputParams(scale=args.scale, size=(args.width, args.height), mean=args.mean, swapRB=True)
        # Performing text detection
        detResults = detector.detect(frame)
    else:
        print("[ERROR] Unknown model alias:", args.alias, "(expected DB or East)")
        exit(1)
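
    # Both detectors return a pair from detect(): a list of 4-point quadrangles
    # (one per detected text region) and a list of matching confidence scores.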
    # Open the vocabulary file and read lines into a list
    with open(args.vocabulary_path, 'r') as voc_file:
        vocabulary = [line.strip() for line in voc_file]

    if args.ocr_model is None:
        print("[ERROR] Please pass the path to the ocr model using --ocr_model to run the sample")
        exit(1)
    # Initialize the text recognition model with the specified model path
    recognizer = cv2.dnn_TextRecognitionModel(args.ocr_model)
    # Set the vocabulary for the model
    recognizer.setVocabulary(vocabulary)
    # Set the decoding method to 'CTC-greedy'
    recognizer.setDecodeType("CTC-greedy")

    # scale=1/127.5 with mean=127.5 normalizes 8-bit pixels from [0, 255] to
    # [-1, 1]; (100, 32) is the input size the CRNN model expects
    recScale = 1.0 / 127.5
    recMean = (127.5, 127.5, 127.5)
    recInputSize = (100, 32)
    recognizer.setInputParams(scale=recScale, size=recInputSize, mean=recMean)
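
    # Standalone usage sketch (illustrative; 'word.png' stands for a hypothetical
    # pre-cropped word image): recognize() takes a single image and returns the
    # decoded string.
    #   word = cv2.imread('word.png')
    #   print(recognizer.recognize(word))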
    if len(detResults[0]) > 0:
        # Recognition runs on grayscale input unless --rgb is set
        recInput = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if not args.rgb else frame.copy()
        contours = []
        for i, (quadrangle, _) in enumerate(zip(detResults[0], detResults[1])):
            if isinstance(quadrangle, np.ndarray):
                quadrangle = quadrangle.astype(np.float32)
                if len(quadrangle) != 4:
                    print("Skipping a quadrangle with an incorrect number of points.")
                    continue
                contours.append(np.array(quadrangle, dtype=np.int32))
                # Rectify the detected region and run recognition on it
                cropped = fourPointsTransform(recInput, quadrangle)
                recognitionResult = recognizer.recognize(cropped)
                print(f"{i}: '{recognitionResult}'")
                try:
                    # Draw the recognized text on the board at the quadrangle's corner
                    text_origin = (int(quadrangle[1][0]), int(quadrangle[0][1]))
                    cv2.putText(board, recognitionResult, text_origin, cv2.FONT_HERSHEY_SIMPLEX, fontSize, (0, 0, 0), fontThickness)
                except Exception as e:
                    print("Failed to write text on the frame:", e)
            else:
                print("Skipping a detection with invalid format:", quadrangle)
        cv2.polylines(frame, contours, True, (0, 255, 0), 1)
        cv2.polylines(board, contours, True, (200, 255, 200), 1)
    else:
        print("No Text Detected.")
    stacked = cv2.hconcat([frame, board])
    cv2.imshow("Text Detection and Recognition", stacked)
    cv2.waitKey(0)
if __name__ == "__main__":
    main()