From fd2e37da56e945f741ee7296ef8745473a9f7b64 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Mon, 30 Oct 2017 15:33:12 +0300 Subject: [PATCH] text: improve DL-based samples --- .../include/opencv2/text/textDetector.hpp | 10 ++- modules/text/samples/text_recognition_cnn.cpp | 66 +++++++++++-------- modules/text/samples/textbox_demo.cpp | 39 ++++++----- modules/text/src/text_detectorCNN.cpp | 24 +++---- 4 files changed, 81 insertions(+), 58 deletions(-) diff --git a/modules/text/include/opencv2/text/textDetector.hpp b/modules/text/include/opencv2/text/textDetector.hpp index 9c780ae31..fdb92fdfb 100644 --- a/modules/text/include/opencv2/text/textDetector.hpp +++ b/modules/text/include/opencv2/text/textDetector.hpp @@ -54,9 +54,15 @@ public: @param modelArchFilename the relative or absolute path to the prototxt file describing the classifiers architecture. @param modelWeightsFilename the relative or absolute path to the file containing the pretrained weights of the model in caffe-binary form. - @param detectMultiscale if true, multiple scales of the input image will be used as network input + @param detectionSizes a list of sizes for multiscale detection. The values`[(300,300),(700,500),(700,300),(700,700),(1600,1600)]` are + recommended in @cite LiaoSBWL17 to achieve the best quality. */ - CV_WRAP static Ptr create(const String& modelArchFilename, const String& modelWeightsFilename, bool detectMultiscale = false); + static Ptr create(const String& modelArchFilename, const String& modelWeightsFilename, + std::vector detectionSizes); + /** + @overload + */ + CV_WRAP static Ptr create(const String& modelArchFilename, const String& modelWeightsFilename); }; //! @} diff --git a/modules/text/samples/text_recognition_cnn.cpp b/modules/text/samples/text_recognition_cnn.cpp index d7a95398b..84df57d29 100644 --- a/modules/text/samples/text_recognition_cnn.cpp +++ b/modules/text/samples/text_recognition_cnn.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -29,22 +30,27 @@ bool fileExists (const string& filename) return f.good(); } -void textbox_draw(Mat src, vector& groups, vector& probs, float thres) +void textbox_draw(Mat src, std::vector& groups, std::vector& probs, std::vector& indexes) { - for (size_t i = 0; i < groups.size(); i++) + for (size_t i = 0; i < indexes.size(); i++) { - if(probs[i] > thres) + if (src.type() == CV_8UC3) { - if (src.type() == CV_8UC3) - { - rectangle(src, groups[i], Scalar( 0, 255, 255 ), 2, LINE_AA); - String label = format("%.2f", probs[i]); - cout << "text box: " << groups[i] << " confidence: " << probs[i] << "\n"; - putText(src, label, groups.at(i).tl(), FONT_HERSHEY_PLAIN, 1, Scalar( 0,0,255 ), 1, LINE_AA); - } - else - rectangle(src, groups[i], Scalar( 255 ), 3, 8 ); + Rect currrentBox = groups[indexes[i]]; + rectangle(src, currrentBox, Scalar( 0, 255, 255 ), 2, LINE_AA); + String label = format("%.2f", probs[indexes[i]]); + std::cout << "text box: " << currrentBox << " confidence: " << probs[indexes[i]] << "\n"; + + int baseLine = 0; + Size labelSize = getTextSize(label, FONT_HERSHEY_PLAIN, 1, 1, &baseLine); + int yLeftBottom = std::max(currrentBox.y, labelSize.height); + rectangle(src, Point(currrentBox.x, yLeftBottom - labelSize.height), + Point(currrentBox.x + labelSize.width, yLeftBottom + baseLine), Scalar( 255, 255, 255 ), FILLED); + + putText(src, label, Point(currrentBox.x, yLeftBottom), FONT_HERSHEY_PLAIN, 1, Scalar( 0,0,0 ), 1, LINE_AA); } + else + rectangle(src, groups[i], Scalar( 255 ), 3, 8 ); } } @@ -73,33 +79,41 @@ int main(int argc, const char * argv[]) cout << "Starting Text Box Demo" << endl; Ptr textSpotter = - text::TextDetectorCNN::create(modelArch, moddelWeights, false); + text::TextDetectorCNN::create(modelArch, moddelWeights); vector bbox; vector outProbabillities; textSpotter->detect(image, bbox, outProbabillities); + std::vector indexes; + cv::dnn::NMSBoxes(bbox, outProbabillities, 0.4f, 0.5f, indexes); - float prob_threshold = 0.6f; Mat image_copy = image.clone(); - textbox_draw(image_copy, bbox, outProbabillities, prob_threshold); + textbox_draw(image_copy, bbox, outProbabillities, indexes); imshow("Text detection", image_copy); image_copy = image.clone(); Ptr wordSpotter = text::OCRHolisticWordRecognizer::create("dictnet_vgg_deploy.prototxt", "dictnet_vgg.caffemodel", "dictnet_vgg_labels.txt"); - for(size_t i = 0; i < bbox.size(); i++) + for(size_t i = 0; i < indexes.size(); i++) { - if(outProbabillities[i] > prob_threshold) - { - Mat wordImg; - cvtColor(image(bbox[i]), wordImg, COLOR_BGR2GRAY); - string word; - vector confs; - wordSpotter->run(wordImg, word, NULL, NULL, &confs); - rectangle(image_copy, bbox[i], Scalar(0, 255, 255), 1, LINE_AA); - putText(image_copy, word, bbox[i].tl(), FONT_HERSHEY_PLAIN, 1, Scalar(0, 0, 255), 1, LINE_AA); - } + Mat wordImg; + cvtColor(image(bbox[indexes[i]]), wordImg, COLOR_BGR2GRAY); + string word; + vector confs; + wordSpotter->run(wordImg, word, NULL, NULL, &confs); + + Rect currrentBox = bbox[indexes[i]]; + rectangle(image_copy, currrentBox, Scalar( 0, 255, 255 ), 2, LINE_AA); + + int baseLine = 0; + Size labelSize = getTextSize(word, FONT_HERSHEY_PLAIN, 1, 1, &baseLine); + int yLeftBottom = std::max(currrentBox.y, labelSize.height); + rectangle(image_copy, Point(currrentBox.x, yLeftBottom - labelSize.height), + Point(currrentBox.x + labelSize.width, yLeftBottom + baseLine), Scalar( 255, 255, 255 ), FILLED); + + putText(image_copy, word, Point(currrentBox.x, yLeftBottom), FONT_HERSHEY_PLAIN, 1, Scalar( 0,0,0 ), 1, LINE_AA); + } imshow("Text recognition", image_copy); cout << "Recognition finished. Press any key to exit.\n"; diff --git a/modules/text/samples/textbox_demo.cpp b/modules/text/samples/textbox_demo.cpp index e6412f9f5..1cf9a9aab 100644 --- a/modules/text/samples/textbox_demo.cpp +++ b/modules/text/samples/textbox_demo.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -27,22 +28,27 @@ bool fileExists (const std::string& filename) return f.good(); } -void textbox_draw(Mat src, std::vector& groups, std::vector& probs, float thres) +void textbox_draw(Mat src, std::vector& groups, std::vector& probs, std::vector& indexes) { - for (size_t i = 0; i < groups.size(); i++) + for (size_t i = 0; i < indexes.size(); i++) { - if(probs[i] > thres) + if (src.type() == CV_8UC3) { - if (src.type() == CV_8UC3) - { - rectangle(src, groups[i], Scalar( 0, 255, 255 ), 2, LINE_AA); - String label = format("%.2f", probs[i]); - std::cout << "text box: " << groups[i] << " confidence: " << probs[i] << "\n"; - putText(src, label, groups.at(i).tl(), FONT_HERSHEY_PLAIN, 1, Scalar( 0,0,255 ), 1, LINE_AA); - } - else - rectangle(src, groups[i], Scalar( 255 ), 3, 8 ); + Rect currrentBox = groups[indexes[i]]; + rectangle(src, currrentBox, Scalar( 0, 255, 255 ), 2, LINE_AA); + String label = format("%.2f", probs[indexes[i]]); + std::cout << "text box: " << currrentBox << " confidence: " << probs[indexes[i]] << "\n"; + + int baseLine = 0; + Size labelSize = getTextSize(label, FONT_HERSHEY_PLAIN, 1, 1, &baseLine); + int yLeftBottom = std::max(currrentBox.y, labelSize.height); + rectangle(src, Point(currrentBox.x, yLeftBottom - labelSize.height), + Point(currrentBox.x + labelSize.width, yLeftBottom + baseLine), Scalar( 255, 255, 255 ), FILLED); + + putText(src, label, Point(currrentBox.x, yLeftBottom), FONT_HERSHEY_PLAIN, 1, Scalar( 0,0,0 ), 1, LINE_AA); } + else + rectangle(src, groups[i], Scalar( 255 ), 3, 8 ); } } @@ -62,7 +68,7 @@ int main(int argc, const char * argv[]) if (!fileExists(modelArch) || !fileExists(moddelWeights)) { - std::cout< textSpotter = - text::TextDetectorCNN::create(modelArch, moddelWeights, false); + text::TextDetectorCNN::create(modelArch, moddelWeights); std::vector bbox; std::vector outProbabillities; textSpotter->detect(image, bbox, outProbabillities); - textbox_draw(image, bbox, outProbabillities, 0.5f); + std::vector indexes; + cv::dnn::NMSBoxes(bbox, outProbabillities, 0.3f, 0.4f, indexes); + + textbox_draw(image, bbox, outProbabillities, indexes); imshow("TextBox Demo",image); std::cout << "Done!" << std::endl << std::endl; diff --git a/modules/text/src/text_detectorCNN.cpp b/modules/text/src/text_detectorCNN.cpp index e74594bac..84f769b42 100644 --- a/modules/text/src/text_detectorCNN.cpp +++ b/modules/text/src/text_detectorCNN.cpp @@ -23,8 +23,6 @@ protected: Net net_; std::vector sizes_; int inputChannelCount_; - bool detectMultiscale_; - void getOutputs(const float* buffer,int nbrTextBoxes,int nCol, std::vector& Bbox, std::vector& confidence, Size inputShape) @@ -54,21 +52,12 @@ protected: } public: - TextDetectorCNNImpl(const String& modelArchFilename, const String& modelWeightsFilename, bool detectMultiscale) : - detectMultiscale_(detectMultiscale) + TextDetectorCNNImpl(const String& modelArchFilename, const String& modelWeightsFilename, std::vector detectionSizes) : + sizes_(detectionSizes) { net_ = readNetFromCaffe(modelArchFilename, modelWeightsFilename); CV_Assert(!net_.empty()); inputChannelCount_ = 3; - sizes_.push_back(Size(700, 700)); - - if(detectMultiscale_) - { - sizes_.push_back(Size(300, 300)); - sizes_.push_back(Size(700,500)); - sizes_.push_back(Size(700,300)); - sizes_.push_back(Size(1600,1600)); - } } void detect(InputArray inputImage_, std::vector& Bbox, std::vector& confidence) @@ -92,9 +81,14 @@ public: } }; -Ptr TextDetectorCNN::create(const String &modelArchFilename, const String &modelWeightsFilename, bool detectMultiscale) +Ptr TextDetectorCNN::create(const String &modelArchFilename, const String &modelWeightsFilename, std::vector detectionSizes) +{ + return makePtr(modelArchFilename, modelWeightsFilename, detectionSizes); +} + +Ptr TextDetectorCNN::create(const String &modelArchFilename, const String &modelWeightsFilename) { - return makePtr(modelArchFilename, modelWeightsFilename, detectMultiscale); + return create(modelArchFilename, modelWeightsFilename, std::vector(1, Size(300, 300))); } } //namespace text } //namespace cv