diff --git a/samples/dnn/text_detection.py b/samples/dnn/text_detection.py
index 9ea4c10190..7014a80148 100644
--- a/samples/dnn/text_detection.py
+++ b/samples/dnn/text_detection.py
@@ -1,25 +1,81 @@
+'''
+    Text detection model: https://github.com/argman/EAST
+    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
+    Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
+    How to convert the CRNN checkpoint from .pth to .onnx,
+    using the classes from https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py:
+    import torch
+    from models.crnn import CRNN
+    model = CRNN(32, 1, 37, 256)
+    model.load_state_dict(torch.load('crnn.pth'))
+    dummy_input = torch.randn(1, 1, 32, 100)
+    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
+'''
+
+
 # Import required modules
+import numpy as np
 import cv2 as cv
 import math
 import argparse
 
 ############ Add argument parser for command line arguments ############
-parser = argparse.ArgumentParser(description='Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)')
-parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
-parser.add_argument('--model', required=True,
-                    help='Path to a binary .pb file of model contains trained weights.')
+parser = argparse.ArgumentParser(
+    description="Use this script to run the TensorFlow implementation (https://github.com/argman/EAST) of "
+                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
+                "The OCR model can be obtained by converting the pretrained CRNN model from the GitHub repository https://github.com/meijieru/crnn.pytorch to the .onnx format.")
+parser.add_argument('--input',
+                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
+parser.add_argument('--model', '-m', required=True,
+                    help='Path to a binary .pb file that contains the trained detector network.')
+parser.add_argument('--ocr', default="crnn.onnx",
+                    help="Path to a binary .pb or .onnx file that contains the trained recognition network.")
 parser.add_argument('--width', type=int, default=320,
                     help='Preprocess input image by resizing to a specific width. It should be multiple by 32.')
-parser.add_argument('--height',type=int, default=320,
+parser.add_argument('--height', type=int, default=320,
                     help='Preprocess input image by resizing to a specific height. It should be multiple by 32.')
-parser.add_argument('--thr',type=float, default=0.5,
+parser.add_argument('--thr', type=float, default=0.5,
                     help='Confidence threshold.')
-parser.add_argument('--nms',type=float, default=0.4,
+parser.add_argument('--nms', type=float, default=0.4,
                     help='Non-maximum suppression threshold.')
 args = parser.parse_args()
 
+
 ############ Utility functions ############
-def decode(scores, geometry, scoreThresh):
+
+def fourPointsTransform(frame, vertices):
+    vertices = np.asarray(vertices)
+    outputSize = (100, 32)
+    targetVertices = np.array([
+        [0, outputSize[1] - 1],
+        [0, 0],
+        [outputSize[0] - 1, 0],
+        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")
+
+    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
+    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
+    return result
+
+
+def decodeText(scores):
+    text = ""
+    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
+    for i in range(scores.shape[0]):
+        c = np.argmax(scores[i][0])
+        if c != 0:
+            text += alphabet[c - 1]
+        else:
+            text += '-'
+
+    # adjacent same letters as well as background text must be removed to get the final output
+    char_list = []
+    for i in range(len(text)):
+        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
+            char_list.append(text[i])
+    return ''.join(char_list)
+
+
+def decodeBoundingBoxes(scores, geometry, scoreThresh):
     detections = []
     confidences = []
@@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
             score = scoresData[x]
 
             # If score is lower than threshold score, move to next x
-            if(score < scoreThresh):
+            if (score < scoreThresh):
                 continue
 
             # Calculate offset
@@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
 
             # Find points for rectangle
             p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
-            p3 = (-cosA * w + offset[0],  sinA * w + offset[1])
-            center = (0.5*(p1[0]+p3[0]), 0.5*(p1[1]+p3[1]))
-            detections.append((center, (w,h), -1*angle * 180.0 / math.pi))
+            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
+            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
+            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
             confidences.append(float(score))
 
     # Return detections and confidences
     return [detections, confidences]
 
+
 def main():
     # Read and store arguments
     confThreshold = args.thr
     nmsThreshold = args.nms
     inpWidth = args.width
     inpHeight = args.height
-    model = args.model
+    modelDetector = args.model
+    modelRecognition = args.ocr
 
     # Load network
-    net = cv.dnn.readNet(model)
+    detector = cv.dnn.readNet(modelDetector)
+    recognizer = cv.dnn.readNet(modelRecognition)
 
     # Create a new named window
     kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
 
     # Open a video file or an image file or a camera stream
     cap = cv.VideoCapture(args.input if args.input else 0)
+    tickmeter = cv.TickMeter()
     while cv.waitKey(1) < 0:
         # Read frame
         hasFrame, frame = cap.read()
@@ -111,19 +171,20 @@ def main():
         # Create a 4D blob from frame.
         blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
 
-        # Run the model
-        net.setInput(blob)
-        outs = net.forward(outNames)
-        t, _ = net.getPerfProfile()
-        label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
+        # Run the detection model
+        detector.setInput(blob)
+
+        tickmeter.start()
+        outs = detector.forward(outNames)
+        tickmeter.stop()
 
         # Get scores and geometry
         scores = outs[0]
         geometry = outs[1]
-        [boxes, confidences] = decode(scores, geometry, confThreshold)
+        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)
 
         # Apply NMS
-        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold,nmsThreshold)
+        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
         for i in indices:
             # get 4 corners of the rotated rect
             vertices = cv.boxPoints(boxes[i[0]])
@@ -131,16 +192,40 @@ def main():
             for j in range(4):
                 vertices[j][0] *= rW
                 vertices[j][1] *= rH
+
+
+            # get cropped image using perspective transform
+            if modelRecognition:
+                cropped = fourPointsTransform(frame, vertices)
+                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)
+
+                # Create a 4D blob from cropped image
+                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
+                recognizer.setInput(blob)
+
+                # Run the recognition model
+                tickmeter.start()
+                result = recognizer.forward()
+                tickmeter.stop()
+
+                # decode the result into text
+                wordRecognized = decodeText(result)
+                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
+                           0.5, (255, 0, 0))
+
             for j in range(4):
                 p1 = (vertices[j][0], vertices[j][1])
                 p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
                 cv.line(frame, p1, p2, (0, 255, 0), 1)
 
         # Put efficiency information
+        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
         cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
 
         # Display the frame
-        cv.imshow(kWinName,frame)
+        cv.imshow(kWinName, frame)
+        tickmeter.reset()
+
 
 if __name__ == "__main__":
     main()
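
A note on decodeText(): it implements greedy CTC decoding. The best class is taken at each of the CRNN's time steps (class 0 is the CTC blank, rendered as '-'), adjacent repeats are collapsed, and blanks are dropped. Below is a minimal self-contained sketch of the same logic; the (9, 1, 37) score tensor and its per-frame labels are fabricated purely for illustration.

    # Standalone sketch of the greedy CTC decoding performed by decodeText().
    import numpy as np

    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"

    def ctc_greedy_decode(scores):
        # Pick the best class per time step; class 0 is the CTC blank ('-')
        raw = ''.join('-' if c == 0 else alphabet[c - 1]
                      for c in scores.argmax(axis=2)[:, 0])
        # Collapse adjacent repeats, then drop blanks: "hh-ell-lo" -> "hello"
        out = [ch for i, ch in enumerate(raw)
               if ch != '-' and (i == 0 or ch != raw[i - 1])]
        return ''.join(out)

    # Fabricated per-frame argmax labels spelling "hh-ell-lo"; the blank between
    # the two 'l' runs is what keeps the double letter from collapsing into one
    labels = [18, 18, 0, 15, 22, 22, 0, 22, 25]
    scores = np.zeros((len(labels), 1, 37), dtype=np.float32)
    for t, c in enumerate(labels):
        scores[t, 0, c] = 1.0

    print(ctc_greedy_decode(scores))  # prints: hello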
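
A note on decodeBoundingBoxes(): each cell of the 4x-downsampled score map carries four distances from the cell to the box edges plus a rotation angle, and the p1/p3/center arithmetic in the hunk above turns those into the (center, size, angle) rotated rect that cv.dnn.NMSBoxesRotated consumes. A worked example with fabricated values; the offset formula is taken from the unchanged lines surrounding the hunk, so treat it as an assumption if that context differs.

    # Worked example of the rotated-rect reconstruction in decodeBoundingBoxes().
    import math

    x, y = 10, 5                                     # cell on the 4x-downsampled feature map
    top, right, bottom, left = 6.0, 20.0, 8.0, 16.0  # fabricated distances to the box edges, in pixels
    angle = 0.1                                      # fabricated rotation, in radians

    cosA, sinA = math.cos(angle), math.sin(angle)
    h = top + bottom                                 # box height
    w = right + left                                 # box width

    # Offset of the box's anchor corner in input-image coordinates
    offset = (4.0 * x + cosA * right + sinA * bottom,
              4.0 * y - sinA * right + cosA * bottom)

    # Two opposite corners, then the center as their midpoint
    p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
    p3 = (-cosA * w + offset[0], sinA * w + offset[1])
    center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))

    # (center, size, angle in degrees), ready for cv.dnn.NMSBoxesRotated
    print((center, (w, h), -1 * angle * 180.0 / math.pi))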
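
In the recognition branch of main(), each surviving detection is cropped with fourPointsTransform() and packed into a 1x1x32x100 blob: mean=127.5 together with scalefactor=1/127.5 rescales pixels from [0, 255] to roughly [-1, 1], the normalization the pretrained CRNN expects. A sketch of that crop-and-preprocess step for a single box; 'sample.jpg' and the box geometry are hypothetical placeholders.

    # Sketch of the crop-and-preprocess step used before recognition.
    import cv2 as cv
    import numpy as np

    frame = cv.imread('sample.jpg')              # hypothetical input image
    # (center, (width, height), angle) as returned by decodeBoundingBoxes()
    box = ((120.0, 80.0), (96.0, 28.0), -12.0)
    vertices = cv.boxPoints(box)                 # 4x2 float32 corners of the rotated rect

    # Same mapping as fourPointsTransform(): warp the quad onto a 100x32 patch
    targetVertices = np.array([[0, 31], [0, 0], [99, 0], [99, 31]], dtype=np.float32)
    M = cv.getPerspectiveTransform(vertices, targetVertices)
    cropped = cv.warpPerspective(frame, M, (100, 32))
    gray = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)

    # mean=127.5 with scalefactor=1/127.5 maps [0, 255] to roughly [-1, 1]
    blob = cv.dnn.blobFromImage(gray, scalefactor=1 / 127.5, size=(100, 32), mean=127.5)
    print(blob.shape)  # (1, 1, 32, 100)

With both models in place, the sample can then be run as, for example:

    python text_detection.py --model frozen_east_text_detection.pb --ocr crnn.onnx --input image.jpg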