@ -1,25 +1,81 @@
'''
Text detection model : https : / / github . com / argman / EAST
Download link : https : / / www . dropbox . com / s / r2ingd0l3zt8hxs / frozen_east_text_detection . tar . gz ? dl = 1
Text recognition model taken from here : https : / / github . com / meijieru / crnn . pytorch
How to convert from pb to onnx :
Using classes from here : https : / / github . com / meijieru / crnn . pytorch / blob / master / models / crnn . py
import torch
import models . crnn as CRNN
model = CRNN ( 32 , 1 , 37 , 256 )
model . load_state_dict ( torch . load ( ' crnn.pth ' ) )
dummy_input = torch . randn ( 1 , 1 , 32 , 100 )
torch . onnx . export ( model , dummy_input , " crnn.onnx " , verbose = True )
'''
# Import required modules
# Import required modules
import numpy as np
import cv2 as cv
import cv2 as cv
import math
import math
import argparse
import argparse
############ Add argument parser for command line arguments ############
############ Add argument parser for command line arguments ############
parser = argparse . ArgumentParser ( description = ' Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2) ' )
parser = argparse . ArgumentParser (
parser . add_argument ( ' --input ' , help = ' Path to input image or video file. Skip this argument to capture frames from a camera. ' )
description = " Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
parser . add_argument ( ' --model ' , required = True ,
" EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2) "
help = ' Path to a binary .pb file of model contains trained weights. ' )
" The OCR model can be obtained from converting the pretrained CRNN model to .onnx format from the github repository https://github.com/meijieru/crnn.pytorch " )
parser . add_argument ( ' --input ' ,
help = ' Path to input image or video file. Skip this argument to capture frames from a camera. ' )
parser . add_argument ( ' --model ' , ' -m ' , required = True ,
help = ' Path to a binary .pb file contains trained detector network. ' )
parser . add_argument ( ' --ocr ' , default = " crnn.onnx " ,
help = " Path to a binary .pb or .onnx file contains trained recognition network " , )
parser . add_argument ( ' --width ' , type = int , default = 320 ,
parser . add_argument ( ' --width ' , type = int , default = 320 ,
help = ' Preprocess input image by resizing to a specific width. It should be multiple by 32. ' )
help = ' Preprocess input image by resizing to a specific width. It should be multiple by 32. ' )
parser . add_argument ( ' --height ' , type = int , default = 320 ,
parser . add_argument ( ' --height ' , type = int , default = 320 ,
help = ' Preprocess input image by resizing to a specific height. It should be multiple by 32. ' )
help = ' Preprocess input image by resizing to a specific height. It should be multiple by 32. ' )
parser . add_argument ( ' --thr ' , type = float , default = 0.5 ,
parser . add_argument ( ' --thr ' , type = float , default = 0.5 ,
help = ' Confidence threshold. ' )
help = ' Confidence threshold. ' )
parser . add_argument ( ' --nms ' , type = float , default = 0.4 ,
parser . add_argument ( ' --nms ' , type = float , default = 0.4 ,
help = ' Non-maximum suppression threshold. ' )
help = ' Non-maximum suppression threshold. ' )
args = parser . parse_args ( )
args = parser . parse_args ( )
############ Utility functions ############
############ Utility functions ############
def decode ( scores , geometry , scoreThresh ) :
def fourPointsTransform ( frame , vertices ) :
vertices = np . asarray ( vertices )
outputSize = ( 100 , 32 )
targetVertices = np . array ( [
[ 0 , outputSize [ 1 ] - 1 ] ,
[ 0 , 0 ] ,
[ outputSize [ 0 ] - 1 , 0 ] ,
[ outputSize [ 0 ] - 1 , outputSize [ 1 ] - 1 ] ] , dtype = " float32 " )
rotationMatrix = cv . getPerspectiveTransform ( vertices , targetVertices )
result = cv . warpPerspective ( frame , rotationMatrix , outputSize )
return result
def decodeText ( scores ) :
text = " "
alphabet = " 0123456789abcdefghijklmnopqrstuvwxyz "
for i in range ( scores . shape [ 0 ] ) :
c = np . argmax ( scores [ i ] [ 0 ] )
if c != 0 :
text + = alphabet [ c - 1 ]
else :
text + = ' - '
# adjacent same letters as well as background text must be removed to get the final output
char_list = [ ]
for i in range ( len ( text ) ) :
if text [ i ] != ' - ' and ( not ( i > 0 and text [ i ] == text [ i - 1 ] ) ) :
char_list . append ( text [ i ] )
return ' ' . join ( char_list )
def decodeBoundingBoxes ( scores , geometry , scoreThresh ) :
detections = [ ]
detections = [ ]
confidences = [ ]
confidences = [ ]
@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
score = scoresData [ x ]
score = scoresData [ x ]
# If score is lower than threshold score, move to next x
# If score is lower than threshold score, move to next x
if ( score < scoreThresh ) :
if ( score < scoreThresh ) :
continue
continue
# Calculate offset
# Calculate offset
@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
# Find points for rectangle
# Find points for rectangle
p1 = ( - sinA * h + offset [ 0 ] , - cosA * h + offset [ 1 ] )
p1 = ( - sinA * h + offset [ 0 ] , - cosA * h + offset [ 1 ] )
p3 = ( - cosA * w + offset [ 0 ] , sinA * w + offset [ 1 ] )
p3 = ( - cosA * w + offset [ 0 ] , sinA * w + offset [ 1 ] )
center = ( 0.5 * ( p1 [ 0 ] + p3 [ 0 ] ) , 0.5 * ( p1 [ 1 ] + p3 [ 1 ] ) )
center = ( 0.5 * ( p1 [ 0 ] + p3 [ 0 ] ) , 0.5 * ( p1 [ 1 ] + p3 [ 1 ] ) )
detections . append ( ( center , ( w , h ) , - 1 * angle * 180.0 / math . pi ) )
detections . append ( ( center , ( w , h ) , - 1 * angle * 180.0 / math . pi ) )
confidences . append ( float ( score ) )
confidences . append ( float ( score ) )
# Return detections and confidences
# Return detections and confidences
return [ detections , confidences ]
return [ detections , confidences ]
def main ( ) :
def main ( ) :
# Read and store arguments
# Read and store arguments
confThreshold = args . thr
confThreshold = args . thr
nmsThreshold = args . nms
nmsThreshold = args . nms
inpWidth = args . width
inpWidth = args . width
inpHeight = args . height
inpHeight = args . height
model = args . model
modelDetector = args . model
modelRecognition = args . ocr
# Load network
# Load network
net = cv . dnn . readNet ( model )
detector = cv . dnn . readNet ( modelDetector )
recognizer = cv . dnn . readNet ( modelRecognition )
# Create a new named window
# Create a new named window
kWinName = " EAST: An Efficient and Accurate Scene Text Detector "
kWinName = " EAST: An Efficient and Accurate Scene Text Detector "
@ -95,6 +154,7 @@ def main():
# Open a video file or an image file or a camera stream
# Open a video file or an image file or a camera stream
cap = cv . VideoCapture ( args . input if args . input else 0 )
cap = cv . VideoCapture ( args . input if args . input else 0 )
tickmeter = cv . TickMeter ( )
while cv . waitKey ( 1 ) < 0 :
while cv . waitKey ( 1 ) < 0 :
# Read frame
# Read frame
hasFrame , frame = cap . read ( )
hasFrame , frame = cap . read ( )
@ -111,19 +171,20 @@ def main():
# Create a 4D blob from frame.
# Create a 4D blob from frame.
blob = cv . dnn . blobFromImage ( frame , 1.0 , ( inpWidth , inpHeight ) , ( 123.68 , 116.78 , 103.94 ) , True , False )
blob = cv . dnn . blobFromImage ( frame , 1.0 , ( inpWidth , inpHeight ) , ( 123.68 , 116.78 , 103.94 ) , True , False )
# Run the model
# Run the detection model
net . setInput ( blob )
detector . setInput ( blob )
outs = net . forward ( outNames )
t , _ = net . getPerfProfile ( )
tickmeter . start ( )
label = ' Inference time: %.2f ms ' % ( t * 1000.0 / cv . getTickFrequency ( ) )
outs = detector . forward ( outNames )
tickmeter . stop ( )
# Get scores and geometry
# Get scores and geometry
scores = outs [ 0 ]
scores = outs [ 0 ]
geometry = outs [ 1 ]
geometry = outs [ 1 ]
[ boxes , confidences ] = decode ( scores , geometry , confThreshold )
[ boxes , confidences ] = decodeBoundingBoxes ( scores , geometry , confThreshold )
# Apply NMS
# Apply NMS
indices = cv . dnn . NMSBoxesRotated ( boxes , confidences , confThreshold , nmsThreshold )
indices = cv . dnn . NMSBoxesRotated ( boxes , confidences , confThreshold , nmsThreshold )
for i in indices :
for i in indices :
# get 4 corners of the rotated rect
# get 4 corners of the rotated rect
vertices = cv . boxPoints ( boxes [ i [ 0 ] ] )
vertices = cv . boxPoints ( boxes [ i [ 0 ] ] )
@ -131,16 +192,40 @@ def main():
for j in range ( 4 ) :
for j in range ( 4 ) :
vertices [ j ] [ 0 ] * = rW
vertices [ j ] [ 0 ] * = rW
vertices [ j ] [ 1 ] * = rH
vertices [ j ] [ 1 ] * = rH
# get cropped image using perspective transform
if modelRecognition :
cropped = fourPointsTransform ( frame , vertices )
cropped = cv . cvtColor ( cropped , cv . COLOR_BGR2GRAY )
# Create a 4D blob from cropped image
blob = cv . dnn . blobFromImage ( cropped , size = ( 100 , 32 ) , mean = 127.5 , scalefactor = 1 / 127.5 )
recognizer . setInput ( blob )
# Run the recognition model
tickmeter . start ( )
result = recognizer . forward ( )
tickmeter . stop ( )
# decode the result into text
wordRecognized = decodeText ( result )
cv . putText ( frame , wordRecognized , ( int ( vertices [ 1 ] [ 0 ] ) , int ( vertices [ 1 ] [ 1 ] ) ) , cv . FONT_HERSHEY_SIMPLEX ,
0.5 , ( 255 , 0 , 0 ) )
for j in range ( 4 ) :
for j in range ( 4 ) :
p1 = ( vertices [ j ] [ 0 ] , vertices [ j ] [ 1 ] )
p1 = ( vertices [ j ] [ 0 ] , vertices [ j ] [ 1 ] )
p2 = ( vertices [ ( j + 1 ) % 4 ] [ 0 ] , vertices [ ( j + 1 ) % 4 ] [ 1 ] )
p2 = ( vertices [ ( j + 1 ) % 4 ] [ 0 ] , vertices [ ( j + 1 ) % 4 ] [ 1 ] )
cv . line ( frame , p1 , p2 , ( 0 , 255 , 0 ) , 1 )
cv . line ( frame , p1 , p2 , ( 0 , 255 , 0 ) , 1 )
# Put efficiency information
# Put efficiency information
label = ' Inference time: %.2f ms ' % ( tickmeter . getTimeMilli ( ) )
cv . putText ( frame , label , ( 0 , 15 ) , cv . FONT_HERSHEY_SIMPLEX , 0.5 , ( 0 , 255 , 0 ) )
cv . putText ( frame , label , ( 0 , 15 ) , cv . FONT_HERSHEY_SIMPLEX , 0.5 , ( 0 , 255 , 0 ) )
# Display the frame
# Display the frame
cv . imshow ( kWinName , frame )
cv . imshow ( kWinName , frame )
tickmeter . reset ( )
if __name__ == " __main__ " :
if __name__ == " __main__ " :
main ( )
main ( )