mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
144 lines
4.9 KiB
144 lines
4.9 KiB
#include <iostream> |
|
#include <fstream> |
|
|
|
#include <opencv2/imgproc.hpp> |
|
#include <opencv2/highgui.hpp> |
|
#include <opencv2/dnn/dnn.hpp> |
|
|
|
using namespace cv; |
|
using namespace cv::dnn; |
|
|
|
String keys = |
|
"{ help h | | Print help message. }" |
|
"{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }" |
|
"{ modelPath mp | | Path to a binary .onnx file contains trained CRNN text recognition model. " |
|
"Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" |
|
"{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }" |
|
"{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }" |
|
"{ evalDataPath edp | | Path to benchmarks for evaluation. " |
|
"Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" |
|
"{ vocabularyPath vp | alphabet_36.txt | Path to recognition vocabulary. " |
|
"Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"; |
|
|
|
String convertForEval(String &input); |
|
|
|
int main(int argc, char** argv) |
|
{ |
|
// Parse arguments |
|
CommandLineParser parser(argc, argv, keys); |
|
parser.about("Use this script to run the PyTorch implementation of " |
|
"An End-to-End Trainable Neural Network for Image-based SequenceRecognition and Its Application to Scene Text Recognition " |
|
"(https://arxiv.org/abs/1507.05717)"); |
|
if (argc == 1 || parser.has("help")) |
|
{ |
|
parser.printMessage(); |
|
return 0; |
|
} |
|
|
|
String modelPath = parser.get<String>("modelPath"); |
|
String vocPath = parser.get<String>("vocabularyPath"); |
|
int imreadRGB = parser.get<int>("RGBInput"); |
|
|
|
if (!parser.check()) |
|
{ |
|
parser.printErrors(); |
|
return 1; |
|
} |
|
|
|
// Load the network |
|
CV_Assert(!modelPath.empty()); |
|
TextRecognitionModel recognizer(modelPath); |
|
|
|
// Load vocabulary |
|
CV_Assert(!vocPath.empty()); |
|
std::ifstream vocFile; |
|
vocFile.open(samples::findFile(vocPath)); |
|
CV_Assert(vocFile.is_open()); |
|
String vocLine; |
|
std::vector<String> vocabulary; |
|
while (std::getline(vocFile, vocLine)) { |
|
vocabulary.push_back(vocLine); |
|
} |
|
recognizer.setVocabulary(vocabulary); |
|
recognizer.setDecodeType("CTC-greedy"); |
|
|
|
// Set parameters |
|
double scale = 1.0 / 127.5; |
|
Scalar mean = Scalar(127.5, 127.5, 127.5); |
|
Size inputSize = Size(100, 32); |
|
recognizer.setInputParams(scale, inputSize, mean); |
|
|
|
if (parser.get<bool>("evaluate")) |
|
{ |
|
// For evaluation |
|
String evalDataPath = parser.get<String>("evalDataPath"); |
|
CV_Assert(!evalDataPath.empty()); |
|
String gtPath = evalDataPath + "/test_gts.txt"; |
|
std::ifstream evalGts; |
|
evalGts.open(gtPath); |
|
CV_Assert(evalGts.is_open()); |
|
|
|
String gtLine; |
|
int cntRight=0, cntAll=0; |
|
TickMeter timer; |
|
timer.reset(); |
|
|
|
while (std::getline(evalGts, gtLine)) { |
|
size_t splitLoc = gtLine.find_first_of(' '); |
|
String imgPath = evalDataPath + '/' + gtLine.substr(0, splitLoc); |
|
String gt = gtLine.substr(splitLoc+1); |
|
|
|
// Inference |
|
Mat frame = imread(samples::findFile(imgPath), imreadRGB); |
|
CV_Assert(!frame.empty()); |
|
timer.start(); |
|
std::string recognitionResult = recognizer.recognize(frame); |
|
timer.stop(); |
|
|
|
if (gt == convertForEval(recognitionResult)) |
|
cntRight++; |
|
|
|
cntAll++; |
|
} |
|
std::cout << "Accuracy(%): " << (double)(cntRight) / (double)(cntAll) << std::endl; |
|
std::cout << "Average Inference Time(ms): " << timer.getTimeMilli() / (double)(cntAll) << std::endl; |
|
} |
|
else |
|
{ |
|
// Create a window |
|
static const std::string winName = "Input Cropped Image"; |
|
|
|
// Open an image file |
|
CV_Assert(parser.has("inputImage")); |
|
Mat frame = imread(samples::findFile(parser.get<String>("inputImage")), imreadRGB); |
|
CV_Assert(!frame.empty()); |
|
|
|
// Recognition |
|
std::string recognitionResult = recognizer.recognize(frame); |
|
|
|
imshow(winName, frame); |
|
std::cout << "Predition: '" << recognitionResult << "'" << std::endl; |
|
waitKey(); |
|
} |
|
|
|
return 0; |
|
} |
|
|
|
// Convert the predictions to lower case, and remove other characters. |
|
// Only for Evaluation |
|
String convertForEval(String & input) |
|
{ |
|
String output; |
|
for (uint i = 0; i < input.length(); i++){ |
|
char ch = input[i]; |
|
if ((int)ch >= 97 && (int)ch <= 122) { |
|
output.push_back(ch); |
|
} else if ((int)ch >= 65 && (int)ch <= 90) { |
|
output.push_back((char)(ch + 32)); |
|
} else { |
|
continue; |
|
} |
|
} |
|
|
|
return output; |
|
}
|
|
|