Merge pull request #16941 from rngtna:examples_dnn_text_decoder

pull/17231/head^2
Alexander Alekhin 5 years ago committed by GitHub
commit 4f1ba5e69e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 135
      samples/dnn/text_detection.cpp

@ -1,3 +1,20 @@
/*
Text detection model: https://github.com/argman/EAST
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
How to convert from pb to onnx:
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
import torch
import models.crnn as crnn
model = CRNN(32, 1, 37, 256)
model.load_state_dict(torch.load('crnn.pth'))
dummy_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
*/
#include <opencv2/imgproc.hpp> #include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp> #include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp> #include <opencv2/dnn.hpp>
@ -8,15 +25,20 @@ using namespace cv::dnn;
const char* keys = const char* keys =
"{ help h | | Print help message. }" "{ help h | | Print help message. }"
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}" "{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
"{ model m | | Path to a binary .pb file contains trained network.}" "{ model m | | Path to a binary .pb file contains trained detector network.}"
"{ ocr | | Path to a binary .pb or .onnx file contains trained recognition network.}"
"{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }" "{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
"{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }" "{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
"{ thr | 0.5 | Confidence threshold. }" "{ thr | 0.5 | Confidence threshold. }"
"{ nms | 0.4 | Non-maximum suppression threshold. }"; "{ nms | 0.4 | Non-maximum suppression threshold. }";
void decode(const Mat& scores, const Mat& geometry, float scoreThresh, void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
std::vector<RotatedRect>& detections, std::vector<float>& confidences); std::vector<RotatedRect>& detections, std::vector<float>& confidences);
void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result);
void decodeText(const Mat& scores, std::string& text);
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
// Parse command line arguments. // Parse command line arguments.
@ -33,7 +55,8 @@ int main(int argc, char** argv)
float nmsThreshold = parser.get<float>("nms"); float nmsThreshold = parser.get<float>("nms");
int inpWidth = parser.get<int>("width"); int inpWidth = parser.get<int>("width");
int inpHeight = parser.get<int>("height"); int inpHeight = parser.get<int>("height");
String model = parser.get<String>("model"); String modelDecoder = parser.get<String>("model");
String modelRecognition = parser.get<String>("ocr");
if (!parser.check()) if (!parser.check())
{ {
@ -41,17 +64,19 @@ int main(int argc, char** argv)
return 1; return 1;
} }
CV_Assert(!model.empty()); CV_Assert(!modelDecoder.empty());
// Load network. // Load networks.
Net net = readNet(model); Net detector = readNet(modelDecoder);
Net recognizer;
if (!modelRecognition.empty())
recognizer = readNet(modelRecognition);
// Open a video file or an image file or a camera stream. // Open a video file or an image file or a camera stream.
VideoCapture cap; VideoCapture cap;
if (parser.has("input")) bool openSuccess = parser.has("input") ? cap.open(parser.get<String>("input")) : cap.open(0);
cap.open(parser.get<String>("input")); CV_Assert(openSuccess);
else
cap.open(0);
static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector"; static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
namedWindow(kWinName, WINDOW_NORMAL); namedWindow(kWinName, WINDOW_NORMAL);
@ -62,6 +87,7 @@ int main(int argc, char** argv)
outNames[1] = "feature_fusion/concat_3"; outNames[1] = "feature_fusion/concat_3";
Mat frame, blob; Mat frame, blob;
TickMeter tickMeter;
while (waitKey(1) < 0) while (waitKey(1) < 0)
{ {
cap >> frame; cap >> frame;
@ -72,8 +98,10 @@ int main(int argc, char** argv)
} }
blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false); blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
net.setInput(blob); detector.setInput(blob);
net.forward(outs, outNames); tickMeter.start();
detector.forward(outs, outNames);
tickMeter.stop();
Mat scores = outs[0]; Mat scores = outs[0];
Mat geometry = outs[1]; Mat geometry = outs[1];
@ -81,42 +109,63 @@ int main(int argc, char** argv)
// Decode predicted bounding boxes. // Decode predicted bounding boxes.
std::vector<RotatedRect> boxes; std::vector<RotatedRect> boxes;
std::vector<float> confidences; std::vector<float> confidences;
decode(scores, geometry, confThreshold, boxes, confidences); decodeBoundingBoxes(scores, geometry, confThreshold, boxes, confidences);
// Apply non-maximum suppression procedure. // Apply non-maximum suppression procedure.
std::vector<int> indices; std::vector<int> indices;
NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices); NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
// Render detections.
Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight); Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
// Render text.
for (size_t i = 0; i < indices.size(); ++i) for (size_t i = 0; i < indices.size(); ++i)
{ {
RotatedRect& box = boxes[indices[i]]; RotatedRect& box = boxes[indices[i]];
Point2f vertices[4]; Point2f vertices[4];
box.points(vertices); box.points(vertices);
for (int j = 0; j < 4; ++j) for (int j = 0; j < 4; ++j)
{ {
vertices[j].x *= ratio.x; vertices[j].x *= ratio.x;
vertices[j].y *= ratio.y; vertices[j].y *= ratio.y;
} }
if (!modelRecognition.empty())
{
Mat cropped;
fourPointsTransform(frame, vertices, cropped);
cvtColor(cropped, cropped, cv::COLOR_BGR2GRAY);
Mat blobCrop = blobFromImage(cropped, 1.0/127.5, Size(), Scalar::all(127.5));
recognizer.setInput(blobCrop);
tickMeter.start();
Mat result = recognizer.forward();
tickMeter.stop();
std::string wordRecognized = "";
decodeText(result, wordRecognized);
putText(frame, wordRecognized, vertices[1], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255));
}
for (int j = 0; j < 4; ++j) for (int j = 0; j < 4; ++j)
line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1); line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1);
} }
// Put efficiency information. // Put efficiency information.
std::vector<double> layersTimes; std::string label = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
double freq = getTickFrequency() / 1000;
double t = net.getPerfProfile(layersTimes) / freq;
std::string label = format("Inference time: %.2f ms", t);
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0)); putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
imshow(kWinName, frame); imshow(kWinName, frame);
tickMeter.reset();
} }
return 0; return 0;
} }
void decode(const Mat& scores, const Mat& geometry, float scoreThresh, void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
std::vector<RotatedRect>& detections, std::vector<float>& confidences) std::vector<RotatedRect>& detections, std::vector<float>& confidences)
{ {
detections.clear(); detections.clear();
@ -159,3 +208,51 @@ void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
} }
} }
} }
void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result)
{
const Size outputSize = Size(100, 32);
Point2f targetVertices[4] = {Point(0, outputSize.height - 1),
Point(0, 0), Point(outputSize.width - 1, 0),
Point(outputSize.width - 1, outputSize.height - 1),
};
Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
warpPerspective(frame, result, rotationMatrix, outputSize);
}
void decodeText(const Mat& scores, std::string& text)
{
static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz";
Mat scoresMat = scores.reshape(1, scores.size[0]);
std::vector<char> elements;
elements.reserve(scores.size[0]);
for (int rowIndex = 0; rowIndex < scoresMat.rows; ++rowIndex)
{
Point p;
minMaxLoc(scoresMat.row(rowIndex), 0, 0, 0, &p);
if (p.x > 0 && static_cast<size_t>(p.x) <= alphabet.size())
{
elements.push_back(alphabet[p.x - 1]);
}
else
{
elements.push_back('-');
}
}
if (elements.size() > 0 && elements[0] != '-')
text += elements[0];
for (size_t elementIndex = 1; elementIndex < elements.size(); ++elementIndex)
{
if (elementIndex > 0 && elements[elementIndex] != '-' &&
elements[elementIndex - 1] != elements[elementIndex])
{
text += elements[elementIndex];
}
}
}
Loading…
Cancel
Save