Merge pull request #16955 from themechanicalcoder:text_recognition

* add text recognition sample

* fix pylint warning

* made changes according to the c++ example

* fix errors

* add text recognition sample

* update text detection sample
Gourav Roy authored 5 years ago, committed by GitHub
parent 0fb3b8db72
commit 1b336bb602
1 changed file: samples/dnn/text_detection.py (111 lines changed)

@@ -1,13 +1,35 @@
'''
Text detection model: https://github.com/argman/EAST
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
How to convert the pretrained .pth model to .onnx:
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
import torch
from models.crnn import CRNN
model = CRNN(32, 1, 37, 256)
model.load_state_dict(torch.load('crnn.pth'))
dummy_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
'''
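Before wiring the converted model into the sample, it can be worth confirming that OpenCV's dnn module accepts the export; a minimal sketch, assuming the snippet above produced crnn.onnx in the working directory:

import cv2 as cv

# Sanity check (hypothetical): confirm the exported ONNX file loads in OpenCV
net = cv.dnn.readNet("crnn.onnx")
print("CRNN loaded; layer count:", len(net.getLayerNames()))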
# Import required modules
import numpy as np
import cv2 as cv
import math
import argparse
############ Add argument parser for command line arguments ############
parser = argparse.ArgumentParser(
    description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained by converting the pretrained CRNN model from "
                "https://github.com/meijieru/crnn.pytorch to the .onnx format.")
parser.add_argument('--input',
                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', '-m', required=True,
                    help='Path to a binary .pb file containing the trained detector network.')
parser.add_argument('--ocr', default="crnn.onnx",
                    help='Path to a binary .pb or .onnx file containing the trained recognition network.')
parser.add_argument('--width', type=int, default=320,
                    help='Preprocess input image by resizing to a specific width. It should be a multiple of 32.')
parser.add_argument('--height', type=int, default=320,
@@ -18,8 +40,42 @@ parser.add_argument('--nms', type=float, default=0.4,
                    help='Non-maximum suppression threshold.')
args = parser.parse_args()
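Put together, a typical invocation of the updated sample looks like the line below; the file names are placeholders (the EAST .pb is the one unpacked from the Dropbox tarball in the header, crnn.onnx is the converted recognizer):

python text_detection.py --model frozen_east_text_detection.pb --ocr crnn.onnx --input image.jpg --width 320 --height 320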
############ Utility functions ############
def fourPointsTransform(frame, vertices):
    vertices = np.asarray(vertices)
    outputSize = (100, 32)
    # Target corners in the order produced by cv.boxPoints:
    # bottom-left, top-left, top-right, bottom-right
    targetVertices = np.array([
        [0, outputSize[1] - 1],
        [0, 0],
        [outputSize[0] - 1, 0],
        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")

    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
    return result
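For orientation: cv.boxPoints returns the rotated rectangle's corners in the order bottom-left, top-left, top-right, bottom-right, which is exactly the order targetVertices encodes. A hypothetical call on one detection (rect values and frame are made up):

# Hypothetical usage: rectify one detected rotated box into a 100x32 patch
rect = ((120.0, 60.0), (100.0, 32.0), 15.0)   # ((cx, cy), (w, h), angle) from the detector
corners = cv.boxPoints(rect)                  # bottom-left, top-left, top-right, bottom-right
patch = fourPointsTransform(frame, corners)   # 100x32 crop ready for the recognizer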
def decodeText(scores):
    text = ""
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    for i in range(scores.shape[0]):
        c = np.argmax(scores[i][0])
        if c != 0:
            text += alphabet[c - 1]
        else:
            text += '-'

    # CTC-style decoding: drop the blank ('-') and collapse adjacent repeated characters
    char_list = []
    for i in range(len(text)):
        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
            char_list.append(text[i])
    return ''.join(char_list)
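This is greedy CTC decoding: class 0 is the blank, class c > 0 maps to alphabet[c - 1], and blanks plus adjacent repeats are collapsed. A toy demonstration with made-up scores:

# Toy example (made-up scores): per-step argmaxes spell "-cc-aat",
# which collapses to "cat" after removing blanks and adjacent repeats
demo = np.zeros((7, 1, 37), dtype=np.float32)
for t, c in enumerate([0, 13, 13, 0, 11, 11, 30]):   # 13 -> 'c', 11 -> 'a', 30 -> 't'
    demo[t, 0, c] = 1.0
print(decodeText(demo))   # prints: cat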
def decodeBoundingBoxes(scores, geometry, scoreThresh):
    detections = []
    confidences = []
@@ -74,16 +130,19 @@ def decode(scores, geometry, scoreThresh):
    # Return detections and confidences
    return [detections, confidences]
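The loop body elided between the hunks above is unchanged by this commit; for reference, a sketch of its per-cell step following the RBOX geometry from the EAST paper, wrapped in a hypothetical helper (decodeOneCell is not in the sample; scoresData, x0_data..x3_data and anglesData are row slices of the score and geometry blobs):

import math

def decodeOneCell(x, y, scoresData, x0_data, x1_data, x2_data, x3_data, anglesData,
                  detections, confidences):
    # One cell of the decoding loop; the output maps have a stride of 4 pixels
    offsetX, offsetY = x * 4.0, y * 4.0
    angle = anglesData[x]
    cosA, sinA = math.cos(angle), math.sin(angle)
    h = x0_data[x] + x2_data[x]        # box height: distances to top + bottom edges
    w = x1_data[x] + x3_data[x]        # box width: distances to right + left edges
    # Rotate the corner offsets back into image coordinates, then recover the center
    offset = (offsetX + cosA * x1_data[x] + sinA * x2_data[x],
              offsetY - sinA * x1_data[x] + cosA * x2_data[x])
    p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
    p3 = (-cosA * w + offset[0], sinA * w + offset[1])
    center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
    detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
    confidences.append(float(scoresData[x]))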
def main():
    # Read and store arguments
    confThreshold = args.thr
    nmsThreshold = args.nms
    inpWidth = args.width
    inpHeight = args.height
    modelDetector = args.model
    modelRecognition = args.ocr

    # Load network
    detector = cv.dnn.readNet(modelDetector)
    recognizer = cv.dnn.readNet(modelRecognition)

    # Create a new named window
    kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
    # Open a video file or an image file or a camera stream
    cap = cv.VideoCapture(args.input if args.input else 0)

    tickmeter = cv.TickMeter()
    while cv.waitKey(1) < 0:
        # Read frame
        hasFrame, frame = cap.read()
@@ -111,16 +171,17 @@
        # Create a 4D blob from frame.
        blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)

        # Run the detection model
        detector.setInput(blob)

        tickmeter.start()
        outs = detector.forward(outNames)
        tickmeter.stop()

        # Get scores and geometry
        scores = outs[0]
        geometry = outs[1]
        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)

        # Apply NMS
        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
@@ -131,16 +192,40 @@ def main():
            for j in range(4):
                vertices[j][0] *= rW
                vertices[j][1] *= rH
            # Get cropped image using perspective transform
            if modelRecognition:
                cropped = fourPointsTransform(frame, vertices)
                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)

                # Create a 4D blob from cropped image; (pixel - 127.5) / 127.5
                # maps the grayscale patch to [-1, 1], matching CRNN's preprocessing
                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
                recognizer.setInput(blob)

                # Run the recognition model
                tickmeter.start()
                result = recognizer.forward()
                tickmeter.stop()

                # Decode the result into text
                wordRecognized = decodeText(result)
                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
                           0.5, (255, 0, 0))
            for j in range(4):
                p1 = (vertices[j][0], vertices[j][1])
                p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
                cv.line(frame, p1, p2, (0, 255, 0), 1)
        # Put efficiency information
        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
        cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
        # Display the frame
        cv.imshow(kWinName, frame)
        tickmeter.reset()
if __name__ == "__main__":
    main()
