diff --git a/doc/opencv.bib b/doc/opencv.bib index 54396d6a10..6212ea5a55 100644 --- a/doc/opencv.bib +++ b/doc/opencv.bib @@ -1261,3 +1261,26 @@ pages={281--305}, year={1987} } +@inproceedings{liao2020real, + author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang}, + title={Real-time Scene Text Detection with Differentiable Binarization}, + booktitle={Proc. AAAI}, + year={2020} +} +@article{shi2016end, + title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition}, + author={Shi, Baoguang and Bai, Xiang and Yao, Cong}, + journal={IEEE transactions on pattern analysis and machine intelligence}, + volume={39}, + number={11}, + pages={2298--2304}, + year={2016}, + publisher={IEEE} +} +@inproceedings{zhou2017east, + title={East: an efficient and accurate scene text detector}, + author={Zhou, Xinyu and Yao, Cong and Wen, He and Wang, Yuzhi and Zhou, Shuchang and He, Weiran and Liang, Jiajun}, + booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition}, + pages={5551--5560}, + year={2017} +} diff --git a/doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown b/doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown index 43c86acaf0..ddf40c96a0 100644 --- a/doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown +++ b/doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown @@ -1,6 +1,7 @@ # How to run custom OCR model {#tutorial_dnn_OCR} @prev_tutorial{tutorial_dnn_custom_layers} +@next_tutorial{tutorial_dnn_text_spotting} ## Introduction @@ -43,4 +44,4 @@ The input of text recognition model is the output of the text detection model, w DenseNet_CTC has the smallest parameters and best FPS, and it is suitable for edge devices, which are very sensitive to the cost of calculation. If you have limited computing resources and want to achieve better accuracy, VGG_CTC is a good choice. -CRNN_VGG_BiLSTM_CTC is suitable for scenarios that require high recognition accuracy. \ No newline at end of file +CRNN_VGG_BiLSTM_CTC is suitable for scenarios that require high recognition accuracy. diff --git a/doc/tutorials/dnn/dnn_text_spotting/detect_test1.jpg b/doc/tutorials/dnn/dnn_text_spotting/detect_test1.jpg new file mode 100644 index 0000000000..b154dfc4ec Binary files /dev/null and b/doc/tutorials/dnn/dnn_text_spotting/detect_test1.jpg differ diff --git a/doc/tutorials/dnn/dnn_text_spotting/detect_test2.jpg b/doc/tutorials/dnn/dnn_text_spotting/detect_test2.jpg new file mode 100644 index 0000000000..a46dcc03a1 Binary files /dev/null and b/doc/tutorials/dnn/dnn_text_spotting/detect_test2.jpg differ diff --git a/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown b/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown new file mode 100644 index 0000000000..0aa66f9e61 --- /dev/null +++ b/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown @@ -0,0 +1,316 @@ +# High Level API: TextDetectionModel and TextRecognitionModel {#tutorial_dnn_text_spotting} + +@prev_tutorial{tutorial_dnn_OCR} + +## Introduction +In this tutorial, we will introduce the APIs for TextRecognitionModel and TextDetectionModel in detail. + +--- +#### TextRecognitionModel: + +In the current version, @ref cv::dnn::TextRecognitionModel only supports CNN+RNN+CTC based algorithms, +and the greedy decoding method for CTC is provided. +For more information, please refer to the [original paper](https://arxiv.org/abs/1507.05717) + +Before recognition, you should `setVocabulary` and `setDecodeType`. +- "CTC-greedy", the output of the text recognition model should be a probability matrix. + The shape should be `(T, B, Dim)`, where + - `T` is the sequence length + - `B` is the batch size (only support `B=1` in inference) + - and `Dim` is the length of vocabulary +1('Blank' of CTC is at the index=0 of Dim). + +@ref cv::dnn::TextRecognitionModel::recognize() is the main function for text recognition. +- The input image should be a cropped text image or an image with `roiRects` +- Other decoding methods may supported in the future + +--- + +#### TextDetectionModel: + +@ref cv::dnn::TextDetectionModel API provides these methods for text detection: +- cv::dnn::TextDetectionModel::detect() returns the results in std::vector> (4-points quadrangles) +- cv::dnn::TextDetectionModel::detectTextRectangles() returns the results in std::vector (RBOX-like) + +In the current version, @ref cv::dnn::TextDetectionModel supports these algorithms: +- use @ref cv::dnn::TextDetectionModel_DB with "DB" models +- and use @ref cv::dnn::TextDetectionModel_EAST with "EAST" models + +The following provided pretrained models are variants of DB (w/o deformable convolution), +and the performance can be referred to the Table.1 in the [paper]((https://arxiv.org/abs/1911.08947)). +For more information, please refer to the [official code](https://github.com/MhLiao/DB) + +--- + +You can train your own model with more data, and convert it into ONNX format. +We encourage you to add new algorithms to these APIs. + + +## Pretrained Models + +#### TextRecognitionModel: + +``` +crnn.onnx: +url: https://drive.google.com/uc?export=dowload&id=1ooaLR-rkTl8jdpGy1DoQs0-X0lQsB6Fj +sha: 270d92c9ccb670ada2459a25977e8deeaf8380d3, +alphabet_36.txt: https://drive.google.com/uc?export=dowload&id=1oPOYx5rQRp8L6XQciUwmwhMCfX0KyO4b +parameter setting: -rgb=0; +description: The classification number of this model is 36 (0~9 + a~z). + The training dataset is MJSynth. + +crnn_cs.onnx: +url: https://drive.google.com/uc?export=dowload&id=12diBsVJrS9ZEl6BNUiRp9s0xPALBS7kt +sha: a641e9c57a5147546f7a2dbea4fd322b47197cd5 +alphabet_94.txt: https://drive.google.com/uc?export=dowload&id=1oKXxXKusquimp7XY1mFvj9nwLzldVgBR +parameter setting: -rgb=1; +description: The classification number of this model is 94 (0~9 + a~z + A~Z + punctuations). + The training datasets are MJsynth and SynthText. + +crnn_cs_CN.onnx: +url: https://drive.google.com/uc?export=dowload&id=1is4eYEUKH7HR7Gl37Sw4WPXx6Ir8oQEG +sha: 3940942b85761c7f240494cf662dcbf05dc00d14 +alphabet_3944.txt: https://drive.google.com/uc?export=dowload&id=18IZUUdNzJ44heWTndDO6NNfIpJMmN-ul +parameter setting: -rgb=1; +description: The classification number of this model is 3944 (0~9 + a~z + A~Z + Chinese characters + special characters). + The training dataset is ReCTS (https://rrc.cvc.uab.es/?ch=12). +``` + +More models can be found in [here](https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing), +which are taken from [clovaai](https://github.com/clovaai/deep-text-recognition-benchmark). +You can train more models by [CRNN](https://github.com/meijieru/crnn.pytorch), and convert models by `torch.onnx.export`. + +#### TextDetectionModel: + +``` +- DB_IC15_resnet50.onnx: +url: https://drive.google.com/uc?export=dowload&id=17_ABp79PlFt9yPCxSaarVc_DKTmrSGGf +sha: bef233c28947ef6ec8c663d20a2b326302421fa3 +recommended parameter setting: -inputHeight=736, -inputWidth=1280; +description: This model is trained on ICDAR2015, so it can only detect English text instances. + +- DB_IC15_resnet18.onnx: +url: https://drive.google.com/uc?export=dowload&id=1sZszH3pEt8hliyBlTmB-iulxHP1dCQWV +sha: 19543ce09b2efd35f49705c235cc46d0e22df30b +recommended parameter setting: -inputHeight=736, -inputWidth=1280; +description: This model is trained on ICDAR2015, so it can only detect English text instances. + +- DB_TD500_resnet50.onnx: +url: https://drive.google.com/uc?export=dowload&id=19YWhArrNccaoSza0CfkXlA8im4-lAGsR +sha: 1b4dd21a6baa5e3523156776970895bd3db6960a +recommended parameter setting: -inputHeight=736, -inputWidth=736; +description: This model is trained on MSRA-TD500, so it can detect both English and Chinese text instances. + +- DB_TD500_resnet18.onnx: +url: https://drive.google.com/uc?export=dowload&id=1vY_KsDZZZb_svd5RT6pjyI8BS1nPbBSX +sha: 8a3700bdc13e00336a815fc7afff5dcc1ce08546 +recommended parameter setting: -inputHeight=736, -inputWidth=736; +description: This model is trained on MSRA-TD500, so it can detect both English and Chinese text instances. + +``` + +We will release more models of DB [here](https://drive.google.com/drive/folders/1qzNCHfUJOS0NEUOIKn69eCtxdlNPpWbq?usp=sharing) in the future. + +``` +- EAST: +Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1 +This model is based on https://github.com/argman/EAST +``` + +## Images for Testing + +``` +Text Recognition: +url: https://drive.google.com/uc?export=dowload&id=1nMcEy68zDNpIlqAn6xCk_kYcUTIeSOtN +sha: 89205612ce8dd2251effa16609342b69bff67ca3 + +Text Detection: +url: https://drive.google.com/uc?export=dowload&id=149tAhIcvfCYeyufRoZ9tmc2mZDKE_XrF +sha: ced3c03fb7f8d9608169a913acf7e7b93e07109b +``` + +## Example for Text Recognition + +Step1. Loading images and models with a vocabulary + +```cpp + // Load a cropped text line image + // you can find cropped images for testing in "Images for Testing" + int rgb = IMREAD_COLOR; // This should be changed according to the model input requirement. + Mat image = imread("path/to/text_rec_test.png", rgb); + + // Load models weights + TextRecognitionModel model("path/to/crnn_cs.onnx"); + + // The decoding method + // more methods will be supported in future + model.setDecodeType("CTC-greedy"); + + // Load vocabulary + // vocabulary should be changed according to the text recognition model + std::ifstream vocFile; + vocFile.open("path/to/alphabet_94.txt"); + CV_Assert(vocFile.is_open()); + String vocLine; + std::vector vocabulary; + while (std::getline(vocFile, vocLine)) { + vocabulary.push_back(vocLine); + } + model.setVocabulary(vocabulary); +``` + +Step2. Setting Parameters + +```cpp + // Normalization parameters + double scale = 1.0 / 127.5; + Scalar mean = Scalar(127.5, 127.5, 127.5); + + // The input shape + Size inputSize = Size(100, 32); + + model.setInputParams(scale, inputSize, mean); +``` +Step3. Inference +```cpp + std::string recognitionResult = recognizer.recognize(image); + std::cout << "'" << recognitionResult << "'" << std::endl; +``` + +Input image: + +![Picture example](text_rec_test.png) + +Output: +``` +'welcome' +``` + + +## Example for Text Detection + +Step1. Loading images and models +```cpp + // Load an image + // you can find some images for testing in "Images for Testing" + Mat frame = imread("/path/to/text_det_test.png"); +``` + +Step2.a Setting Parameters (DB) +```cpp + // Load model weights + TextDetectionModel_DB model("/path/to/DB_TD500_resnet50.onnx"); + + // Post-processing parameters + float binThresh = 0.3; + float polyThresh = 0.5; + uint maxCandidates = 200; + double unclipRatio = 2.0; + model.setBinaryThreshold(binThresh) + .setPolygonThreshold(polyThresh) + .setMaxCandidates(maxCandidates) + .setUnclipRatio(unclipRatio) + ; + + // Normalization parameters + double scale = 1.0 / 255.0; + Scalar mean = Scalar(122.67891434, 116.66876762, 104.00698793); + + // The input shape + Size inputSize = Size(736, 736); + + model.setInputParams(scale, inputSize, mean); +``` + +Step2.b Setting Parameters (EAST) +```cpp + TextDetectionModel_EAST model("EAST.pb"); + + float confThreshold = 0.5; + float nmsThreshold = 0.4; + model.setConfidenceThreshold(confThresh) + .setNMSThreshold(nmsThresh) + ; + + double detScale = 1.0; + Size detInputSize = Size(320, 320); + Scalar detMean = Scalar(123.68, 116.78, 103.94); + bool swapRB = true; + model.setInputParams(detScale, detInputSize, detMean, swapRB); +``` + + +Step3. Inference +```cpp + std::vector> detResults; + model.detect(detResults); + + // Visualization + polylines(frame, results, true, Scalar(0, 255, 0), 2); + imshow("Text Detection", image); + waitKey(); +``` + +Output: + +![Picture example](text_det_test_results.jpg) + +## Example for Text Spotting + +After following the steps above, it is easy to get the detection results of an input image. +Then, you can do transformation and crop text images for recognition. +For more information, please refer to **Detailed Sample** +```cpp + // Transform and Crop + Mat cropped; + fourPointsTransform(recInput, vertices, cropped); + + String recResult = recognizer.recognize(cropped); +``` + +Output Examples: + +![Picture example](detect_test1.jpg) + +![Picture example](detect_test2.jpg) + +## Source Code +The [source code](https://github.com/opencv/opencv/blob/master/modules/dnn/src/model.cpp) +of these APIs can be found in the DNN module. + +## Detailed Sample +For more information, please refer to: +- [samples/dnn/scene_text_recognition.cpp](https://github.com/opencv/opencv/blob/master/samples/dnn/scene_text_recognition.cpp) +- [samples/dnn/scene_text_detection.cpp](https://github.com/opencv/opencv/blob/master/samples/dnn/scene_text_detection.cpp) +- [samples/dnn/text_detection.cpp](https://github.com/opencv/opencv/blob/master/samples/dnn/text_detection.cpp) +- [samples/dnn/scene_text_spotting.cpp](https://github.com/opencv/opencv/blob/master/samples/dnn/scene_text_spotting.cpp) + +#### Test with an image +Examples: +```bash +example_dnn_scene_text_recognition -mp=path/to/crnn_cs.onnx -i=path/to/an/image -rgb=1 -vp=/path/to/alphabet_94.txt +example_dnn_scene_text_detection -mp=path/to/DB_TD500_resnet50.onnx -i=path/to/an/image -ih=736 -iw=736 +example_dnn_scene_text_spotting -dmp=path/to/DB_IC15_resnet50.onnx -rmp=path/to/crnn_cs.onnx -i=path/to/an/image -iw=1280 -ih=736 -rgb=1 -vp=/path/to/alphabet_94.txt +example_dnn_text_detection -dmp=path/to/EAST.pb -rmp=path/to/crnn_cs.onnx -i=path/to/an/image -rgb=1 -vp=path/to/alphabet_94.txt +``` + +#### Test on public datasets +Text Recognition: + +The download link for testing images can be found in the **Images for Testing** + + +Examples: +```bash +example_dnn_scene_text_recognition -mp=path/to/crnn.onnx -e=true -edp=path/to/evaluation_data_rec -vp=/path/to/alphabet_36.txt -rgb=0 +example_dnn_scene_text_recognition -mp=path/to/crnn_cs.onnx -e=true -edp=path/to/evaluation_data_rec -vp=/path/to/alphabet_94.txt -rgb=1 +``` + +Text Detection: + +The download links for testing images can be found in the **Images for Testing** + +Examples: +```bash +example_dnn_scene_text_detection -mp=path/to/DB_TD500_resnet50.onnx -e=true -edp=path/to/evaluation_data_det/TD500 -ih=736 -iw=736 +example_dnn_scene_text_detection -mp=path/to/DB_IC15_resnet50.onnx -e=true -edp=path/to/evaluation_data_det/IC15 -ih=736 -iw=1280 +``` diff --git a/doc/tutorials/dnn/dnn_text_spotting/text_det_test_results.jpg b/doc/tutorials/dnn/dnn_text_spotting/text_det_test_results.jpg new file mode 100644 index 0000000000..173840f729 Binary files /dev/null and b/doc/tutorials/dnn/dnn_text_spotting/text_det_test_results.jpg differ diff --git a/doc/tutorials/dnn/dnn_text_spotting/text_rec_test.png b/doc/tutorials/dnn/dnn_text_spotting/text_rec_test.png new file mode 100644 index 0000000000..c3226376e4 Binary files /dev/null and b/doc/tutorials/dnn/dnn_text_spotting/text_rec_test.png differ diff --git a/doc/tutorials/dnn/table_of_content_dnn.markdown b/doc/tutorials/dnn/table_of_content_dnn.markdown index 0a66d04ee4..0ed97749fb 100644 --- a/doc/tutorials/dnn/table_of_content_dnn.markdown +++ b/doc/tutorials/dnn/table_of_content_dnn.markdown @@ -79,4 +79,14 @@ Deep Neural Networks (dnn module) {#tutorial_table_of_content_dnn} *Author:* Zihao Mu - In this tutorial you will learn how to use opencv_dnn module using custom OCR models. \ No newline at end of file + In this tutorial you will learn how to use opencv_dnn module using custom OCR models. + +- @subpage tutorial_dnn_text_spotting + + *Languages:* C++ + + *Compatibility:* \> OpenCV 4.5 + + *Author:* Wenqing Zhang + + In these tutorial, we'll introduce how to use the high-level APIs for text recognition and text detection diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index 5467c989ac..3ece129455 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -1326,6 +1326,255 @@ CV__DNN_INLINE_NS_BEGIN float confThreshold = 0.5f, float nmsThreshold = 0.0f); }; + +/** @brief This class represents high-level API for text recognition networks. + * + * TextRecognitionModel allows to set params for preprocessing input image. + * TextRecognitionModel creates net from file with trained weights and config, + * sets preprocessing input, runs forward pass and return recognition result. + * For TextRecognitionModel, CRNN-CTC is supported. + */ +class CV_EXPORTS_W_SIMPLE TextRecognitionModel : public Model +{ +public: + CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first) + TextRecognitionModel(); + + /** + * @brief Create Text Recognition model from deep learning network + * Call setDecodeType() and setVocabulary() after constructor to initialize the decoding method + * @param[in] network Net object + */ + CV_WRAP TextRecognitionModel(const Net& network); + + /** + * @brief Create text recognition model from network represented in one of the supported formats + * Call setDecodeType() and setVocabulary() after constructor to initialize the decoding method + * @param[in] model Binary file contains trained weights + * @param[in] config Text file contains network configuration + */ + CV_WRAP inline + TextRecognitionModel(const std::string& model, const std::string& config = "") + : TextRecognitionModel(readNet(model, config)) { /* nothing */ } + + /** + * @brief Set the decoding method of translating the network output into string + * @param[in] decodeType The decoding method of translating the network output into string: {'CTC-greedy': greedy decoding for the output of CTC-based methods} + */ + CV_WRAP + TextRecognitionModel& setDecodeType(const std::string& decodeType); + + /** + * @brief Get the decoding method + * @return the decoding method + */ + CV_WRAP + const std::string& getDecodeType() const; + + /** + * @brief Set the vocabulary for recognition. + * @param[in] vocabulary the associated vocabulary of the network. + */ + CV_WRAP + TextRecognitionModel& setVocabulary(const std::vector& vocabulary); + + /** + * @brief Get the vocabulary for recognition. + * @return vocabulary the associated vocabulary + */ + CV_WRAP + const std::vector& getVocabulary() const; + + /** + * @brief Given the @p input frame, create input blob, run net and return recognition result + * @param[in] frame The input image + * @return The text recognition result + */ + CV_WRAP + std::string recognize(InputArray frame) const; + + /** + * @brief Given the @p input frame, create input blob, run net and return recognition result + * @param[in] frame The input image + * @param[in] roiRects List of text detection regions of interest (cv::Rect, CV_32SC4). ROIs is be cropped as the network inputs + * @param[out] results A set of text recognition results. + */ + CV_WRAP + void recognize(InputArray frame, InputArrayOfArrays roiRects, CV_OUT std::vector& results) const; +}; + + +/** @brief Base class for text detection networks + */ +class CV_EXPORTS_W TextDetectionModel : public Model +{ +protected: + CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first) + TextDetectionModel(); + +public: + + /** @brief Performs detection + * + * Given the input @p frame, prepare network input, run network inference, post-process network output and return result detections. + * + * Each result is quadrangle's 4 points in this order: + * - bottom-left + * - top-left + * - top-right + * - bottom-right + * + * Use cv::getPerspectiveTransform function to retrive image region without perspective transformations. + * + * @note If DL model doesn't support that kind of output then result may be derived from detectTextRectangles() output. + * + * @param[in] frame The input image + * @param[out] detections array with detections' quadrangles (4 points per result) + * @param[out] confidences array with detection confidences + */ + CV_WRAP + void detect( + InputArray frame, + CV_OUT std::vector< std::vector >& detections, + CV_OUT std::vector& confidences + ) const; + + /** @overload */ + CV_WRAP + void detect( + InputArray frame, + CV_OUT std::vector< std::vector >& detections + ) const; + + /** @brief Performs detection + * + * Given the input @p frame, prepare network input, run network inference, post-process network output and return result detections. + * + * Each result is rotated rectangle. + * + * @note Result may be inaccurate in case of strong perspective transformations. + * + * @param[in] frame the input image + * @param[out] detections array with detections' RotationRect results + * @param[out] confidences array with detection confidences + */ + CV_WRAP + void detectTextRectangles( + InputArray frame, + CV_OUT std::vector& detections, + CV_OUT std::vector& confidences + ) const; + + /** @overload */ + CV_WRAP + void detectTextRectangles( + InputArray frame, + CV_OUT std::vector& detections + ) const; +}; + +/** @brief This class represents high-level API for text detection DL networks compatible with EAST model. + * + * Configurable parameters: + * - (float) confThreshold - used to filter boxes by confidences, default: 0.5f + * - (float) nmsThreshold - used in non maximum suppression, default: 0.0f + */ +class CV_EXPORTS_W_SIMPLE TextDetectionModel_EAST : public TextDetectionModel +{ +public: + CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first) + TextDetectionModel_EAST(); + + /** + * @brief Create text detection algorithm from deep learning network + * @param[in] network Net object + */ + CV_WRAP TextDetectionModel_EAST(const Net& network); + + /** + * @brief Create text detection model from network represented in one of the supported formats. + * An order of @p model and @p config arguments does not matter. + * @param[in] model Binary file contains trained weights. + * @param[in] config Text file contains network configuration. + */ + CV_WRAP inline + TextDetectionModel_EAST(const std::string& model, const std::string& config = "") + : TextDetectionModel_EAST(readNet(model, config)) { /* nothing */ } + + /** + * @brief Set the detection confidence threshold + * @param[in] confThreshold A threshold used to filter boxes by confidences + */ + CV_WRAP + TextDetectionModel_EAST& setConfidenceThreshold(float confThreshold); + + /** + * @brief Get the detection confidence threshold + */ + CV_WRAP + float getConfidenceThreshold() const; + + /** + * @brief Set the detection NMS filter threshold + * @param[in] nmsThreshold A threshold used in non maximum suppression + */ + CV_WRAP + TextDetectionModel_EAST& setNMSThreshold(float nmsThreshold); + + /** + * @brief Get the detection confidence threshold + */ + CV_WRAP + float getNMSThreshold() const; +}; + +/** @brief This class represents high-level API for text detection DL networks compatible with DB model. + * + * Related publications: @cite liao2020real + * Paper: https://arxiv.org/abs/1911.08947 + * For more information about the hyper-parameters setting, please refer to https://github.com/MhLiao/DB + * + * Configurable parameters: + * - (float) binaryThreshold - The threshold of the binary map. It is usually set to 0.3. + * - (float) polygonThreshold - The threshold of text polygons. It is usually set to 0.5, 0.6, and 0.7. Default is 0.5f + * - (double) unclipRatio - The unclip ratio of the detected text region, which determines the output size. It is usually set to 2.0. + * - (int) maxCandidates - The max number of the output results. + */ +class CV_EXPORTS_W_SIMPLE TextDetectionModel_DB : public TextDetectionModel +{ +public: + CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first) + TextDetectionModel_DB(); + + /** + * @brief Create text detection algorithm from deep learning network. + * @param[in] network Net object. + */ + CV_WRAP TextDetectionModel_DB(const Net& network); + + /** + * @brief Create text detection model from network represented in one of the supported formats. + * An order of @p model and @p config arguments does not matter. + * @param[in] model Binary file contains trained weights. + * @param[in] config Text file contains network configuration. + */ + CV_WRAP inline + TextDetectionModel_DB(const std::string& model, const std::string& config = "") + : TextDetectionModel_DB(readNet(model, config)) { /* nothing */ } + + CV_WRAP TextDetectionModel_DB& setBinaryThreshold(float binaryThreshold); + CV_WRAP float getBinaryThreshold() const; + + CV_WRAP TextDetectionModel_DB& setPolygonThreshold(float polygonThreshold); + CV_WRAP float getPolygonThreshold() const; + + CV_WRAP TextDetectionModel_DB& setUnclipRatio(double unclipRatio); + CV_WRAP double getUnclipRatio() const; + + CV_WRAP TextDetectionModel_DB& setMaxCandidates(int maxCandidates); + CV_WRAP int getMaxCandidates() const; +}; + //! @} CV__DNN_INLINE_NS_END } diff --git a/modules/dnn/src/model.cpp b/modules/dnn/src/model.cpp index 16f7d31a25..fae235a3b5 100644 --- a/modules/dnn/src/model.cpp +++ b/modules/dnn/src/model.cpp @@ -4,7 +4,6 @@ #include "precomp.hpp" #include -#include #include #include @@ -37,9 +36,10 @@ public: virtual void setPreferableBackend(Backend backendId) { net.setPreferableBackend(backendId); } virtual void setPreferableTarget(Target targetId) { net.setPreferableTarget(targetId); } - /*virtual*/ + virtual void initNet(const Net& network) { + CV_TRACE_FUNCTION(); net = network; outNames = net.getUnconnectedOutLayersNames(); @@ -91,6 +91,7 @@ public: /*virtual*/ void processFrame(InputArray frame, OutputArrayOfArrays outs) { + CV_TRACE_FUNCTION(); if (size.empty()) CV_Error(Error::StsBadSize, "Input size not specified"); @@ -103,6 +104,7 @@ public: Mat imInfo(Matx13f(size.height, size.width, 1.6f)); net.setInput(imInfo, "im_info"); } + net.forward(outs, outNames); } }; @@ -545,4 +547,778 @@ void DetectionModel::detect(InputArray frame, CV_OUT std::vector& classIds, CV_Error(Error::StsNotImplemented, "Unknown output layer type: \"" + lastLayer->type + "\""); } +struct TextRecognitionModel_Impl : public Model::Impl +{ + std::string decodeType; + std::vector vocabulary; + + TextRecognitionModel_Impl() + { + CV_TRACE_FUNCTION(); + } + + TextRecognitionModel_Impl(const Net& network) + { + CV_TRACE_FUNCTION(); + initNet(network); + } + + inline + void setVocabulary(const std::vector& inputVoc) + { + vocabulary = inputVoc; + } + + inline + void setDecodeType(const std::string& type) + { + decodeType = type; + } + + virtual + std::string decode(const Mat& prediction) + { + CV_TRACE_FUNCTION(); + CV_Assert(!prediction.empty()); + if (decodeType.empty()) + CV_Error(Error::StsBadArg, "TextRecognitionModel: decodeType is not specified"); + if (vocabulary.empty()) + CV_Error(Error::StsBadArg, "TextRecognitionModel: vocabulary is not specified"); + + std::string decodeSeq; + if (decodeType == "CTC-greedy") + { + CV_CheckEQ(prediction.dims, 3, ""); + CV_CheckType(prediction.type(), CV_32FC1, ""); + const int vocLength = (int)(vocabulary.size()); + CV_CheckLE(prediction.size[1], vocLength, ""); + bool ctcFlag = true; + int lastLoc = 0; + for (int i = 0; i < prediction.size[0]; i++) + { + const float* pred = prediction.ptr(i); + int maxLoc = 0; + float maxScore = pred[0]; + for (int j = 1; j < vocLength + 1; j++) + { + float score = pred[j]; + if (maxScore < score) + { + maxScore = score; + maxLoc = j; + } + } + + if (maxLoc > 0) + { + std::string currentChar = vocabulary.at(maxLoc - 1); + if (maxLoc != lastLoc || ctcFlag) + { + lastLoc = maxLoc; + decodeSeq += currentChar; + ctcFlag = false; + } + } + else + { + ctcFlag = true; + } + } + } else if (decodeType.length() == 0) { + CV_Error(Error::StsBadArg, "Please set decodeType"); + } else { + CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str())); + } + + return decodeSeq; + } + + virtual + std::string recognize(InputArray frame) + { + CV_TRACE_FUNCTION(); + std::vector outs; + processFrame(frame, outs); + CV_CheckEQ(outs.size(), (size_t)1, ""); + return decode(outs[0]); + } + + virtual + void recognize(InputArray frame, InputArrayOfArrays roiRects, CV_OUT std::vector& results) + { + CV_TRACE_FUNCTION(); + results.clear(); + if (roiRects.empty()) + { + auto s = recognize(frame); + results.push_back(s); + return; + } + + std::vector rects; + roiRects.copyTo(rects); + + // Predict for each RoI + Mat input = frame.getMat(); + for (size_t i = 0; i < rects.size(); i++) + { + Rect roiRect = rects[i]; + Mat roi = input(roiRect); + auto s = recognize(roi); + results.push_back(s); + } + } + + static inline + TextRecognitionModel_Impl& from(const std::shared_ptr& ptr) + { + CV_Assert(ptr); + return *((TextRecognitionModel_Impl*)ptr.get()); + } +}; + +TextRecognitionModel::TextRecognitionModel() +{ + impl = std::static_pointer_cast(makePtr()); +} + +TextRecognitionModel::TextRecognitionModel(const Net& network) +{ + impl = std::static_pointer_cast(std::make_shared(network)); +} + +TextRecognitionModel& TextRecognitionModel::setDecodeType(const std::string& decodeType) +{ + TextRecognitionModel_Impl::from(impl).setDecodeType(decodeType); + return *this; +} + +const std::string& TextRecognitionModel::getDecodeType() const +{ + return TextRecognitionModel_Impl::from(impl).decodeType; +} + +TextRecognitionModel& TextRecognitionModel::setVocabulary(const std::vector& inputVoc) +{ + TextRecognitionModel_Impl::from(impl).setVocabulary(inputVoc); + return *this; +} + +const std::vector& TextRecognitionModel::getVocabulary() const +{ + return TextRecognitionModel_Impl::from(impl).vocabulary; +} + +std::string TextRecognitionModel::recognize(InputArray frame) const +{ + return TextRecognitionModel_Impl::from(impl).recognize(frame); +} + +void TextRecognitionModel::recognize(InputArray frame, InputArrayOfArrays roiRects, CV_OUT std::vector& results) const +{ + TextRecognitionModel_Impl::from(impl).recognize(frame, roiRects, results); +} + + +///////////////////////////////////////// Text Detection ///////////////////////////////////////// + +struct TextDetectionModel_Impl : public Model::Impl +{ + TextDetectionModel_Impl() {} + + TextDetectionModel_Impl(const Net& network) + { + CV_TRACE_FUNCTION(); + initNet(network); + } + + virtual + std::vector< std::vector > detect(InputArray frame, CV_OUT std::vector& confidences) + { + CV_TRACE_FUNCTION(); + std::vector rects = detectTextRectangles(frame, confidences); + std::vector< std::vector > results; + for (const RotatedRect& rect : rects) + { + Point2f vertices[4] = {}; + rect.points(vertices); + std::vector result = { vertices[0], vertices[1], vertices[2], vertices[3] }; + results.emplace_back(result); + } + return results; + } + + virtual + std::vector< std::vector > detect(InputArray frame) + { + CV_TRACE_FUNCTION(); + std::vector confidences; + return detect(frame, confidences); + } + + virtual + std::vector detectTextRectangles(InputArray frame, CV_OUT std::vector& confidences) + { + CV_Error(Error::StsNotImplemented, ""); + } + + virtual + std::vector detectTextRectangles(InputArray frame) + { + CV_TRACE_FUNCTION(); + std::vector confidences; + return detectTextRectangles(frame, confidences); + } + + static inline + TextDetectionModel_Impl& from(const std::shared_ptr& ptr) + { + CV_Assert(ptr); + return *((TextDetectionModel_Impl*)ptr.get()); + } +}; + + +TextDetectionModel::TextDetectionModel() + : Model() +{ + // nothing +} + +static +void to32s( + const std::vector< std::vector >& detections_f, + CV_OUT std::vector< std::vector >& detections +) +{ + detections.resize(detections_f.size()); + for (size_t i = 0; i < detections_f.size(); i++) + { + const auto& contour_f = detections_f[i]; + std::vector contour(contour_f.size()); + for (size_t j = 0; j < contour_f.size(); j++) + { + contour[j].x = cvRound(contour_f[j].x); + contour[j].y = cvRound(contour_f[j].y); + } + swap(detections[i], contour); + } +} + +void TextDetectionModel::detect( + InputArray frame, + CV_OUT std::vector< std::vector >& detections, + CV_OUT std::vector& confidences +) const +{ + std::vector< std::vector > detections_f = TextDetectionModel_Impl::from(impl).detect(frame, confidences); + to32s(detections_f, detections); + return; +} + +void TextDetectionModel::detect( + InputArray frame, + CV_OUT std::vector< std::vector >& detections +) const +{ + std::vector< std::vector > detections_f = TextDetectionModel_Impl::from(impl).detect(frame); + to32s(detections_f, detections); + return; +} + +void TextDetectionModel::detectTextRectangles( + InputArray frame, + CV_OUT std::vector& detections, + CV_OUT std::vector& confidences +) const +{ + detections = TextDetectionModel_Impl::from(impl).detectTextRectangles(frame, confidences); + return; +} + +void TextDetectionModel::detectTextRectangles( + InputArray frame, + CV_OUT std::vector& detections +) const +{ + detections = TextDetectionModel_Impl::from(impl).detectTextRectangles(frame); + return; +} + + +struct TextDetectionModel_EAST_Impl : public TextDetectionModel_Impl +{ + float confThreshold; + float nmsThreshold; + + TextDetectionModel_EAST_Impl() + : confThreshold(0.5f) + , nmsThreshold(0.0f) + { + CV_TRACE_FUNCTION(); + } + + TextDetectionModel_EAST_Impl(const Net& network) + : TextDetectionModel_EAST_Impl() + { + CV_TRACE_FUNCTION(); + initNet(network); + } + + void setConfidenceThreshold(float confThreshold_) { confThreshold = confThreshold_; } + float getConfidenceThreshold() const { return confThreshold; } + + void setNMSThreshold(float nmsThreshold_) { nmsThreshold = nmsThreshold_; } + float getNMSThreshold() const { return nmsThreshold; } + + // TODO: According to article EAST supports quadrangles output: https://arxiv.org/pdf/1704.03155.pdf +#if 0 + virtual + std::vector< std::vector > detect(InputArray frame, CV_OUT std::vector& confidences) CV_OVERRIDE +#endif + + virtual + std::vector detectTextRectangles(InputArray frame, CV_OUT std::vector& confidences) CV_OVERRIDE + { + CV_TRACE_FUNCTION(); + std::vector results; + + std::vector outs; + processFrame(frame, outs); + CV_CheckEQ(outs.size(), (size_t)2, ""); + Mat geometry = outs[0]; + Mat scoreMap = outs[1]; + + CV_CheckEQ(scoreMap.dims, 4, ""); + CV_CheckEQ(geometry.dims, 4, ""); + CV_CheckEQ(scoreMap.size[0], 1, ""); + CV_CheckEQ(geometry.size[0], 1, ""); + CV_CheckEQ(scoreMap.size[1], 1, ""); + CV_CheckEQ(geometry.size[1], 5, ""); + CV_CheckEQ(scoreMap.size[2], geometry.size[2], ""); + CV_CheckEQ(scoreMap.size[3], geometry.size[3], ""); + + CV_CheckType(scoreMap.type(), CV_32FC1, ""); + CV_CheckType(geometry.type(), CV_32FC1, ""); + + std::vector boxes; + std::vector scores; + const int height = scoreMap.size[2]; + const int width = scoreMap.size[3]; + for (int y = 0; y < height; ++y) + { + const float* scoresData = scoreMap.ptr(0, 0, y); + const float* x0_data = geometry.ptr(0, 0, y); + const float* x1_data = geometry.ptr(0, 1, y); + const float* x2_data = geometry.ptr(0, 2, y); + const float* x3_data = geometry.ptr(0, 3, y); + const float* anglesData = geometry.ptr(0, 4, y); + for (int x = 0; x < width; ++x) + { + float score = scoresData[x]; + if (score < confThreshold) + continue; + + float offsetX = x * 4.0f, offsetY = y * 4.0f; + float angle = anglesData[x]; + float cosA = std::cos(angle); + float sinA = std::sin(angle); + float h = x0_data[x] + x2_data[x]; + float w = x1_data[x] + x3_data[x]; + + Point2f offset(offsetX + cosA * x1_data[x] + sinA * x2_data[x], + offsetY - sinA * x1_data[x] + cosA * x2_data[x]); + Point2f p1 = Point2f(-sinA * h, -cosA * h) + offset; + Point2f p3 = Point2f(-cosA * w, sinA * w) + offset; + boxes.push_back(RotatedRect(0.5f * (p1 + p3), Size2f(w, h), -angle * 180.0f / (float)CV_PI)); + scores.push_back(score); + } + } + + // Apply non-maximum suppression procedure. + std::vector indices; + NMSBoxes(boxes, scores, confThreshold, nmsThreshold, indices); + + confidences.clear(); + confidences.reserve(indices.size()); + + // Re-scale + Point2f ratio((float)frame.cols() / size.width, (float)frame.rows() / size.height); + bool isUniformRatio = std::fabs(ratio.x - ratio.y) <= 0.01f; + for (uint i = 0; i < indices.size(); i++) + { + auto idx = indices[i]; + + auto conf = scores[idx]; + confidences.push_back(conf); + + RotatedRect& box0 = boxes[idx]; + + if (isUniformRatio) + { + RotatedRect box = box0; + box.center.x *= ratio.x; + box.center.y *= ratio.y; + box.size.width *= ratio.x; + box.size.height *= ratio.y; + results.emplace_back(box); + } + else + { + Point2f vertices[4] = {}; + box0.points(vertices); + for (int j = 0; j < 4; j++) + { + vertices[j].x *= ratio.x; + vertices[j].y *= ratio.y; + } + RotatedRect box = minAreaRect(Mat(4, 1, CV_32FC2, (void*)vertices)); + + // minArea() rect is not normalized, it may return rectangles rotated by +90/-90 + float angle_diff = std::fabs(box.angle - box0.angle); + while (angle_diff >= (90 + 45)) + { + box.angle += (box.angle < box0.angle) ? 180 : -180; + angle_diff = std::fabs(box.angle - box0.angle); + } + if (angle_diff > 45) // avoid ~90 degree turns + { + std::swap(box.size.width, box.size.height); + if (box.angle < box0.angle) + box.angle += 90; + else if (box.angle > box0.angle) + box.angle -= 90; + } + // CV_DbgAssert(std::fabs(box.angle - box0.angle) <= 45); + + results.emplace_back(box); + } + } + + return results; + } + + static inline + TextDetectionModel_EAST_Impl& from(const std::shared_ptr& ptr) + { + CV_Assert(ptr); + return *((TextDetectionModel_EAST_Impl*)ptr.get()); + } +}; + + +TextDetectionModel_EAST::TextDetectionModel_EAST() + : TextDetectionModel() +{ + impl = std::static_pointer_cast(makePtr()); +} + +TextDetectionModel_EAST::TextDetectionModel_EAST(const Net& network) + : TextDetectionModel() +{ + impl = std::static_pointer_cast(makePtr(network)); +} + +TextDetectionModel_EAST& TextDetectionModel_EAST::setConfidenceThreshold(float confThreshold) +{ + TextDetectionModel_EAST_Impl::from(impl).setConfidenceThreshold(confThreshold); + return *this; +} +float TextDetectionModel_EAST::getConfidenceThreshold() const +{ + return TextDetectionModel_EAST_Impl::from(impl).getConfidenceThreshold(); +} + +TextDetectionModel_EAST& TextDetectionModel_EAST::setNMSThreshold(float nmsThreshold) +{ + TextDetectionModel_EAST_Impl::from(impl).setNMSThreshold(nmsThreshold); + return *this; +} +float TextDetectionModel_EAST::getNMSThreshold() const +{ + return TextDetectionModel_EAST_Impl::from(impl).getNMSThreshold(); +} + + + +struct TextDetectionModel_DB_Impl : public TextDetectionModel_Impl +{ + float binaryThreshold; + float polygonThreshold; + double unclipRatio; + int maxCandidates; + + TextDetectionModel_DB_Impl() + : binaryThreshold(0.3f) + , polygonThreshold(0.5f) + , unclipRatio(2.0f) + , maxCandidates(0) + { + CV_TRACE_FUNCTION(); + } + + TextDetectionModel_DB_Impl(const Net& network) + : TextDetectionModel_DB_Impl() + { + CV_TRACE_FUNCTION(); + initNet(network); + } + + void setBinaryThreshold(float binaryThreshold_) { binaryThreshold = binaryThreshold_; } + float getBinaryThreshold() const { return binaryThreshold; } + + void setPolygonThreshold(float polygonThreshold_) { polygonThreshold = polygonThreshold_; } + float getPolygonThreshold() const { return polygonThreshold; } + + void setUnclipRatio(double unclipRatio_) { unclipRatio = unclipRatio_; } + double getUnclipRatio() const { return unclipRatio; } + + void setMaxCandidates(int maxCandidates_) { maxCandidates = maxCandidates_; } + int getMaxCandidates() const { return maxCandidates; } + + + virtual + std::vector detectTextRectangles(InputArray frame, CV_OUT std::vector& confidences) CV_OVERRIDE + { + CV_TRACE_FUNCTION(); + std::vector< std::vector > contours = detect(frame, confidences); + std::vector results; results.reserve(contours.size()); + for (size_t i = 0; i < contours.size(); i++) + { + auto& contour = contours[i]; + RotatedRect box = minAreaRect(contour); + + // minArea() rect is not normalized, it may return rectangles with angle=-90 or height < width + const float angle_threshold = 60; // do not expect vertical text, TODO detection algo property + bool swap_size = false; + if (box.size.width < box.size.height) // horizontal-wide text area is expected + swap_size = true; + else if (std::fabs(box.angle) >= angle_threshold) // don't work with vertical rectangles + swap_size = true; + if (swap_size) + { + std::swap(box.size.width, box.size.height); + if (box.angle < 0) + box.angle += 90; + else if (box.angle > 0) + box.angle -= 90; + } + + results.push_back(box); + } + return results; + } + + std::vector< std::vector > detect(InputArray frame, CV_OUT std::vector& confidences) CV_OVERRIDE + { + CV_TRACE_FUNCTION(); + std::vector< std::vector > results; + + std::vector outs; + processFrame(frame, outs); + CV_Assert(outs.size() == 1); + Mat binary = outs[0]; + + // Threshold + Mat bitmap; + threshold(binary, bitmap, binaryThreshold, 255, THRESH_BINARY); + + // Scale ratio + float scaleHeight = (float)(frame.rows()) / (float)(binary.size[0]); + float scaleWidth = (float)(frame.cols()) / (float)(binary.size[1]); + + // Find contours + std::vector< std::vector > contours; + bitmap.convertTo(bitmap, CV_8UC1); + findContours(bitmap, contours, RETR_LIST, CHAIN_APPROX_SIMPLE); + + // Candidate number limitation + size_t numCandidate = std::min(contours.size(), (size_t)(maxCandidates > 0 ? maxCandidates : INT_MAX)); + + for (size_t i = 0; i < numCandidate; i++) + { + std::vector& contour = contours[i]; + + // Calculate text contour score + if (contourScore(binary, contour) < polygonThreshold) + continue; + + // Rescale + std::vector contourScaled; contourScaled.reserve(contour.size()); + for (size_t j = 0; j < contour.size(); j++) + { + contourScaled.push_back(Point(int(contour[j].x * scaleWidth), + int(contour[j].y * scaleHeight))); + } + + // Unclip + RotatedRect box = minAreaRect(contourScaled); + + // minArea() rect is not normalized, it may return rectangles with angle=-90 or height < width + const float angle_threshold = 60; // do not expect vertical text, TODO detection algo property + bool swap_size = false; + if (box.size.width < box.size.height) // horizontal-wide text area is expected + swap_size = true; + else if (std::fabs(box.angle) >= angle_threshold) // don't work with vertical rectangles + swap_size = true; + if (swap_size) + { + std::swap(box.size.width, box.size.height); + if (box.angle < 0) + box.angle += 90; + else if (box.angle > 0) + box.angle -= 90; + } + + Point2f vertex[4]; + box.points(vertex); // order: bl, tl, tr, br + std::vector approx; + for (int j = 0; j < 4; j++) + approx.emplace_back(vertex[j]); + std::vector polygon; + unclip(approx, polygon, unclipRatio); + results.push_back(polygon); + } + + confidences = std::vector(contours.size(), 1.0f); + return results; + } + + // According to https://github.com/MhLiao/DB/blob/master/structure/representers/seg_detector_representer.py (2020-10) + static double contourScore(const Mat& binary, const std::vector& contour) + { + Rect rect = boundingRect(contour); + int xmin = std::max(rect.x, 0); + int xmax = std::min(rect.x + rect.width, binary.cols - 1); + int ymin = std::max(rect.y, 0); + int ymax = std::min(rect.y + rect.height, binary.rows - 1); + + Mat binROI = binary(Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1)); + + Mat mask = Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8U); + std::vector roiContour; + for (size_t i = 0; i < contour.size(); i++) { + Point pt = Point(contour[i].x - xmin, contour[i].y - ymin); + roiContour.push_back(pt); + } + std::vector> roiContours = {roiContour}; + fillPoly(mask, roiContours, Scalar(1)); + double score = cv::mean(binROI, mask).val[0]; + + return score; + } + + // According to https://github.com/MhLiao/DB/blob/master/structure/representers/seg_detector_representer.py (2020-10) + static void unclip(const std::vector& inPoly, std::vector &outPoly, const double unclipRatio) + { + double area = contourArea(inPoly); + double length = arcLength(inPoly, true); + double distance = area * unclipRatio / length; + + size_t numPoints = inPoly.size(); + std::vector> newLines; + for (size_t i = 0; i < numPoints; i++) { + std::vector newLine; + Point pt1 = inPoly[i]; + Point pt2 = inPoly[(i - 1) % numPoints]; + Point vec = pt1 - pt2; + float unclipDis = (float)(distance / norm(vec)); + Point2f rotateVec = Point2f(vec.y * unclipDis, -vec.x * unclipDis); + newLine.push_back(Point2f(pt1.x + rotateVec.x, pt1.y + rotateVec.y)); + newLine.push_back(Point2f(pt2.x + rotateVec.x, pt2.y + rotateVec.y)); + newLines.push_back(newLine); + } + + size_t numLines = newLines.size(); + for (size_t i = 0; i < numLines; i++) { + Point2f a = newLines[i][0]; + Point2f b = newLines[i][1]; + Point2f c = newLines[(i + 1) % numLines][0]; + Point2f d = newLines[(i + 1) % numLines][1]; + Point2f pt; + Point2f v1 = b - a; + Point2f v2 = d - c; + double cosAngle = (v1.x * v2.x + v1.y * v2.y) / (norm(v1) * norm(v2)); + + if( fabs(cosAngle) > 0.7 ) { + pt.x = (b.x + c.x) * 0.5; + pt.y = (b.y + c.y) * 0.5; + } else { + double denom = a.x * (double)(d.y - c.y) + b.x * (double)(c.y - d.y) + + d.x * (double)(b.y - a.y) + c.x * (double)(a.y - b.y); + double num = a.x * (double)(d.y - c.y) + c.x * (double)(a.y - d.y) + d.x * (double)(c.y - a.y); + double s = num / denom; + + pt.x = a.x + s*(b.x - a.x); + pt.y = a.y + s*(b.y - a.y); + } + + + outPoly.push_back(pt); + } + } + + + static inline + TextDetectionModel_DB_Impl& from(const std::shared_ptr& ptr) + { + CV_Assert(ptr); + return *((TextDetectionModel_DB_Impl*)ptr.get()); + } +}; + + +TextDetectionModel_DB::TextDetectionModel_DB() + : TextDetectionModel() +{ + impl = std::static_pointer_cast(makePtr()); +} + +TextDetectionModel_DB::TextDetectionModel_DB(const Net& network) + : TextDetectionModel() +{ + impl = std::static_pointer_cast(makePtr(network)); +} + +TextDetectionModel_DB& TextDetectionModel_DB::setBinaryThreshold(float binaryThreshold) +{ + TextDetectionModel_DB_Impl::from(impl).setBinaryThreshold(binaryThreshold); + return *this; +} +float TextDetectionModel_DB::getBinaryThreshold() const +{ + return TextDetectionModel_DB_Impl::from(impl).getBinaryThreshold(); +} + +TextDetectionModel_DB& TextDetectionModel_DB::setPolygonThreshold(float polygonThreshold) +{ + TextDetectionModel_DB_Impl::from(impl).setPolygonThreshold(polygonThreshold); + return *this; +} +float TextDetectionModel_DB::getPolygonThreshold() const +{ + return TextDetectionModel_DB_Impl::from(impl).getPolygonThreshold(); +} + +TextDetectionModel_DB& TextDetectionModel_DB::setUnclipRatio(double unclipRatio) +{ + TextDetectionModel_DB_Impl::from(impl).setUnclipRatio(unclipRatio); + return *this; +} +double TextDetectionModel_DB::getUnclipRatio() const +{ + return TextDetectionModel_DB_Impl::from(impl).getUnclipRatio(); +} + +TextDetectionModel_DB& TextDetectionModel_DB::setMaxCandidates(int maxCandidates) +{ + TextDetectionModel_DB_Impl::from(impl).setMaxCandidates(maxCandidates); + return *this; +} +int TextDetectionModel_DB::getMaxCandidates() const +{ + return TextDetectionModel_DB_Impl::from(impl).getMaxCandidates(); +} + + }} // namespace diff --git a/modules/dnn/test/test_common.hpp b/modules/dnn/test/test_common.hpp index 3bc8fc3a89..ea6b3bde92 100644 --- a/modules/dnn/test/test_common.hpp +++ b/modules/dnn/test/test_common.hpp @@ -113,6 +113,14 @@ void normAssertDetections( double confThreshold = 0.0, double scores_diff = 1e-5, double boxes_iou_diff = 1e-4); +// For text detection networks +// Curved text polygon is not supported in the current version. +// (concave polygon is invalid input to intersectConvexConvex) +void normAssertTextDetections( + const std::vector>& gtPolys, + const std::vector>& testPolys, + const char *comment = "", double boxes_iou_diff = 1e-4); + void readFileContent(const std::string& filename, CV_OUT std::vector& content); #ifdef HAVE_INF_ENGINE diff --git a/modules/dnn/test/test_common.impl.hpp b/modules/dnn/test/test_common.impl.hpp index cf1b558391..4627e94e6e 100644 --- a/modules/dnn/test/test_common.impl.hpp +++ b/modules/dnn/test/test_common.impl.hpp @@ -177,6 +177,52 @@ void normAssertDetections( testBoxes, comment, confThreshold, scores_diff, boxes_iou_diff); } +// For text detection networks +// Curved text polygon is not supported in the current version. +// (concave polygon is invalid input to intersectConvexConvex) +void normAssertTextDetections( + const std::vector>& gtPolys, + const std::vector>& testPolys, + const char *comment /*= ""*/, double boxes_iou_diff /*= 1e-4*/) +{ + std::vector matchedRefBoxes(gtPolys.size(), false); + for (uint i = 0; i < testPolys.size(); ++i) + { + const std::vector& testPoly = testPolys[i]; + bool matched = false; + double topIoU = 0; + for (uint j = 0; j < gtPolys.size() && !matched; ++j) + { + if (!matchedRefBoxes[j]) + { + std::vector intersectionPolygon; + float intersectArea = intersectConvexConvex(testPoly, gtPolys[j], intersectionPolygon, true); + double iou = intersectArea / (contourArea(testPoly) + contourArea(gtPolys[j]) - intersectArea); + topIoU = std::max(topIoU, iou); + if (1.0 - iou < boxes_iou_diff) + { + matched = true; + matchedRefBoxes[j] = true; + } + } + } + if (!matched) { + std::cout << cv::format("Unmatched-det:") << testPoly << std::endl; + std::cout << "Highest IoU: " << topIoU << std::endl; + } + EXPECT_TRUE(matched) << comment; + } + + // Check unmatched groundtruth. + for (uint i = 0; i < gtPolys.size(); ++i) + { + if (!matchedRefBoxes[i]) { + std::cout << cv::format("Unmatched-gt:") << gtPolys[i] << std::endl; + } + EXPECT_TRUE(matchedRefBoxes[i]); + } +} + void readFileContent(const std::string& filename, CV_OUT std::vector& content) { const std::ios::openmode mode = std::ios::in | std::ios::binary; diff --git a/modules/dnn/test/test_model.cpp b/modules/dnn/test/test_model.cpp index 58a881488a..852ae0040e 100644 --- a/modules/dnn/test/test_model.cpp +++ b/modules/dnn/test/test_model.cpp @@ -113,6 +113,155 @@ public: model.segment(frame, mask); normAssert(mask, exp, "", norm, norm); } + + void testTextRecognitionModel(const std::string& weights, const std::string& cfg, + const std::string& imgPath, const std::string& seq, + const std::string& decodeType, const std::vector& vocabulary, + const Size& size = {-1, -1}, Scalar mean = Scalar(), + double scale = 1.0, bool swapRB = false, bool crop = false) + { + checkBackend(); + + Mat frame = imread(imgPath, IMREAD_GRAYSCALE); + + TextRecognitionModel model(weights, cfg); + model.setDecodeType(decodeType) + .setVocabulary(vocabulary) + .setInputSize(size).setInputMean(mean).setInputScale(scale) + .setInputSwapRB(swapRB).setInputCrop(crop); + + model.setPreferableBackend(backend); + model.setPreferableTarget(target); + + std::string result = model.recognize(frame); + EXPECT_EQ(result, seq) << "Full frame: " << imgPath; + + std::vector rois; + rois.push_back(Rect(0, 0, frame.cols, frame.rows)); + rois.push_back(Rect(0, 0, frame.cols, frame.rows)); // twice + std::vector results; + model.recognize(frame, rois, results); + EXPECT_EQ((size_t)2u, results.size()) << "ROI: " << imgPath; + EXPECT_EQ(results[0], seq) << "ROI[0]: " << imgPath; + EXPECT_EQ(results[1], seq) << "ROI[1]: " << imgPath; + } + + void testTextDetectionModelByDB(const std::string& weights, const std::string& cfg, + const std::string& imgPath, const std::vector>& gt, + float binThresh, float polyThresh, + uint maxCandidates, double unclipRatio, + const Size& size = {-1, -1}, Scalar mean = Scalar(), + double scale = 1.0, bool swapRB = false, bool crop = false) + { + checkBackend(); + + Mat frame = imread(imgPath); + + TextDetectionModel_DB model(weights, cfg); + model.setBinaryThreshold(binThresh) + .setPolygonThreshold(polyThresh) + .setUnclipRatio(unclipRatio) + .setMaxCandidates(maxCandidates) + .setInputSize(size).setInputMean(mean).setInputScale(scale) + .setInputSwapRB(swapRB).setInputCrop(crop); + + model.setPreferableBackend(backend); + model.setPreferableTarget(target); + + // 1. Check common TextDetectionModel API through RotatedRect + std::vector results; + model.detectTextRectangles(frame, results); + + EXPECT_GT(results.size(), (size_t)0); + + std::vector< std::vector > contours; + for (size_t i = 0; i < results.size(); i++) + { + const RotatedRect& box = results[i]; + Mat contour; + boxPoints(box, contour); + std::vector contour2i(4); + for (int i = 0; i < 4; i++) + { + contour2i[i].x = cvRound(contour.at(i, 0)); + contour2i[i].y = cvRound(contour.at(i, 1)); + } + contours.push_back(contour2i); + } +#if 0 // test debug + Mat result = frame.clone(); + drawContours(result, contours, -1, Scalar(0, 0, 255), 1); + imshow("result", result); // imwrite("result.png", result); + waitKey(0); +#endif + normAssertTextDetections(gt, contours, "", 0.05f); + + // 2. Check quadrangle-based API + // std::vector< std::vector > contours; + model.detect(frame, contours); + +#if 0 // test debug + Mat result = frame.clone(); + drawContours(result, contours, -1, Scalar(0, 0, 255), 1); + imshow("result_contours", result); // imwrite("result_contours.png", result); + waitKey(0); +#endif + normAssertTextDetections(gt, contours, "", 0.05f); + } + + void testTextDetectionModelByEAST(const std::string& weights, const std::string& cfg, + const std::string& imgPath, const std::vector& gt, + float confThresh, float nmsThresh, + const Size& size = {-1, -1}, Scalar mean = Scalar(), + double scale = 1.0, bool swapRB = false, bool crop = false) + { + const double EPS_PIXELS = 3; + + checkBackend(); + + Mat frame = imread(imgPath); + + TextDetectionModel_EAST model(weights, cfg); + model.setConfidenceThreshold(confThresh) + .setNMSThreshold(nmsThresh) + .setInputSize(size).setInputMean(mean).setInputScale(scale) + .setInputSwapRB(swapRB).setInputCrop(crop); + + model.setPreferableBackend(backend); + model.setPreferableTarget(target); + + std::vector results; + model.detectTextRectangles(frame, results); + + EXPECT_EQ(results.size(), (size_t)1); + for (size_t i = 0; i < results.size(); i++) + { + const RotatedRect& box = results[i]; +#if 0 // test debug + Mat contour; + boxPoints(box, contour); + std::vector contour2i(4); + for (int i = 0; i < 4; i++) + { + contour2i[i].x = cvRound(contour.at(i, 0)); + contour2i[i].y = cvRound(contour.at(i, 1)); + } + std::vector< std::vector > contours; + contours.push_back(contour2i); + + Mat result = frame.clone(); + drawContours(result, contours, -1, Scalar(0, 0, 255), 1); + imshow("result", result); //imwrite("result.png", result); + waitKey(0); +#endif + const RotatedRect& gtBox = gt[i]; + EXPECT_NEAR(box.center.x, gtBox.center.x, EPS_PIXELS); + EXPECT_NEAR(box.center.y, gtBox.center.y, EPS_PIXELS); + EXPECT_NEAR(box.size.width, gtBox.size.width, EPS_PIXELS); + EXPECT_NEAR(box.size.height, gtBox.size.height, EPS_PIXELS); + EXPECT_NEAR(box.angle, gtBox.angle, 1); + } + } }; TEST_P(Test_Model, Classify) @@ -446,6 +595,77 @@ TEST_P(Test_Model, Segmentation) testSegmentationModel(weights_file, config_file, inp, exp, norm, size, mean, scale, swapRB); } +TEST_P(Test_Model, TextRecognition) +{ + if (target == DNN_TARGET_OPENCL_FP16) + applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16); + + std::string imgPath = _tf("text_rec_test.png"); + std::string weightPath = _tf("onnx/models/crnn.onnx", false); + std::string seq = "welcome"; + + Size size{100, 32}; + double scale = 1.0 / 127.5; + Scalar mean = Scalar(127.5); + std::string decodeType = "CTC-greedy"; + std::vector vocabulary = {"0","1","2","3","4","5","6","7","8","9", + "a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"}; + + testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale); +} + +TEST_P(Test_Model, TextDetectionByDB) +{ + if (target == DNN_TARGET_OPENCL_FP16) + applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16); + + std::string imgPath = _tf("text_det_test1.png"); + std::string weightPath = _tf("onnx/models/DB_TD500_resnet50.onnx", false); + + // GroundTruth + std::vector> gt = { + { Point(142, 193), Point(136, 164), Point(213, 150), Point(219, 178) }, + { Point(136, 165), Point(122, 114), Point(319, 71), Point(330, 122) } + }; + + Size size{736, 736}; + double scale = 1.0 / 255.0; + Scalar mean = Scalar(122.67891434, 116.66876762, 104.00698793); + + float binThresh = 0.3; + float polyThresh = 0.5; + uint maxCandidates = 200; + double unclipRatio = 2.0; + + testTextDetectionModelByDB(weightPath, "", imgPath, gt, binThresh, polyThresh, maxCandidates, unclipRatio, size, mean, scale); +} + +TEST_P(Test_Model, TextDetectionByEAST) +{ + if (target == DNN_TARGET_OPENCL_FP16) + applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16); + + std::string imgPath = _tf("text_det_test2.jpg"); + std::string weightPath = _tf("frozen_east_text_detection.pb", false); + + // GroundTruth + std::vector gt = { + RotatedRect(Point2f(657.55f, 409.5f), Size2f(316.84f, 62.45f), -4.79) + }; + + // Model parameters + Size size{320, 320}; + double scale = 1.0; + Scalar mean = Scalar(123.68, 116.78, 103.94); + bool swapRB = true; + + // Detection algorithm parameters + float confThresh = 0.5; + float nmsThresh = 0.4; + + testTextDetectionModelByEAST(weightPath, "", imgPath, gt, confThresh, nmsThresh, size, mean, scale, swapRB); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_Model, dnnBackendsAndTargets()); }} // namespace diff --git a/samples/data/alphabet_36.txt b/samples/data/alphabet_36.txt new file mode 100644 index 0000000000..7104368905 --- /dev/null +++ b/samples/data/alphabet_36.txt @@ -0,0 +1,36 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z diff --git a/samples/data/alphabet_94.txt b/samples/data/alphabet_94.txt new file mode 100644 index 0000000000..87c6d67850 --- /dev/null +++ b/samples/data/alphabet_94.txt @@ -0,0 +1,94 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +: +; +< += +> +? +@ +[ +\ +] +^ +_ +` +{ +| +} +~ diff --git a/samples/dnn/scene_text_detection.cpp b/samples/dnn/scene_text_detection.cpp new file mode 100644 index 0000000000..5b8626caad --- /dev/null +++ b/samples/dnn/scene_text_detection.cpp @@ -0,0 +1,151 @@ +#include +#include +#include + +#include +#include +#include + +using namespace cv; +using namespace cv::dnn; + +std::string keys = + "{ help h | | Print help message. }" + "{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }" + "{ modelPath mp | | Path to a binary .onnx file contains trained DB detector model. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" + "{ inputHeight ih |736| image height of the model input. It should be multiple by 32.}" + "{ inputWidth iw |736| image width of the model input. It should be multiple by 32.}" + "{ binaryThreshold bt |0.3| Confidence threshold of the binary map. }" + "{ polygonThreshold pt |0.5| Confidence threshold of polygons. }" + "{ maxCandidate max |200| Max candidates of polygons. }" + "{ unclipRatio ratio |2.0| unclip ratio. }" + "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }" + "{ evalDataPath edp | | Path to benchmarks for evaluation. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"; + +int main(int argc, char** argv) +{ + // Parse arguments + CommandLineParser parser(argc, argv, keys); + parser.about("Use this script to run the official PyTorch implementation (https://github.com/MhLiao/DB) of " + "Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947)\n" + "The current version of this script is a variant of the original network without deformable convolution"); + if (argc == 1 || parser.has("help")) + { + parser.printMessage(); + return 0; + } + + float binThresh = parser.get("binaryThreshold"); + float polyThresh = parser.get("polygonThreshold"); + uint maxCandidates = parser.get("maxCandidate"); + String modelPath = parser.get("modelPath"); + double unclipRatio = parser.get("unclipRatio"); + int height = parser.get("inputHeight"); + int width = parser.get("inputWidth"); + + if (!parser.check()) + { + parser.printErrors(); + return 1; + } + + // Load the network + CV_Assert(!modelPath.empty()); + TextDetectionModel_DB detector(modelPath); + detector.setBinaryThreshold(binThresh) + .setPolygonThreshold(polyThresh) + .setUnclipRatio(unclipRatio) + .setMaxCandidates(maxCandidates); + + double scale = 1.0 / 255.0; + Size inputSize = Size(width, height); + Scalar mean = Scalar(122.67891434, 116.66876762, 104.00698793); + detector.setInputParams(scale, inputSize, mean); + + // Create a window + static const std::string winName = "TextDetectionModel"; + + if (parser.get("evaluate")) { + // for evaluation + String evalDataPath = parser.get("evalDataPath"); + CV_Assert(!evalDataPath.empty()); + String testListPath = evalDataPath + "/test_list.txt"; + std::ifstream testList; + testList.open(testListPath); + CV_Assert(testList.is_open()); + + // Create a window for showing groundtruth + static const std::string winNameGT = "GT"; + + String testImgPath; + while (std::getline(testList, testImgPath)) { + String imgPath = evalDataPath + "/test_images/" + testImgPath; + std::cout << "Image Path: " << imgPath << std::endl; + + Mat frame = imread(samples::findFile(imgPath), IMREAD_COLOR); + CV_Assert(!frame.empty()); + Mat src = frame.clone(); + + // Inference + std::vector> results; + detector.detect(frame, results); + + polylines(frame, results, true, Scalar(0, 255, 0), 2); + imshow(winName, frame); + + // load groundtruth + String imgName = testImgPath.substr(0, testImgPath.length() - 4); + String gtPath = evalDataPath + "/test_gts/" + imgName + ".txt"; + // std::cout << gtPath << std::endl; + std::ifstream gtFile; + gtFile.open(gtPath); + CV_Assert(gtFile.is_open()); + + std::vector> gts; + String gtLine; + while (std::getline(gtFile, gtLine)) { + size_t splitLoc = gtLine.find_last_of(','); + String text = gtLine.substr(splitLoc+1); + if ( text == "###\r" || text == "1") { + // ignore difficult instances + continue; + } + gtLine = gtLine.substr(0, splitLoc); + + std::regex delimiter(","); + std::vector v(std::sregex_token_iterator(gtLine.begin(), gtLine.end(), delimiter, -1), + std::sregex_token_iterator()); + std::vector loc; + std::vector pts; + for (auto && s : v) { + loc.push_back(atoi(s.c_str())); + } + for (size_t i = 0; i < loc.size() / 2; i++) { + pts.push_back(Point(loc[2 * i], loc[2 * i + 1])); + } + gts.push_back(pts); + } + polylines(src, gts, true, Scalar(0, 255, 0), 2); + imshow(winNameGT, src); + + waitKey(); + } + } else { + // Open an image file + CV_Assert(parser.has("inputImage")); + Mat frame = imread(samples::findFile(parser.get("inputImage"))); + CV_Assert(!frame.empty()); + + // Detect + std::vector> results; + detector.detect(frame, results); + + polylines(frame, results, true, Scalar(0, 255, 0), 2); + imshow(winName, frame); + waitKey(); + } + + return 0; +} diff --git a/samples/dnn/scene_text_recognition.cpp b/samples/dnn/scene_text_recognition.cpp new file mode 100644 index 0000000000..29b14441dd --- /dev/null +++ b/samples/dnn/scene_text_recognition.cpp @@ -0,0 +1,144 @@ +#include +#include + +#include +#include +#include + +using namespace cv; +using namespace cv::dnn; + +String keys = + "{ help h | | Print help message. }" + "{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }" + "{ modelPath mp | | Path to a binary .onnx file contains trained CRNN text recognition model. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" + "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }" + "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }" + "{ evalDataPath edp | | Path to benchmarks for evaluation. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" + "{ vocabularyPath vp | alphabet_36.txt | Path to recognition vocabulary. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"; + +String convertForEval(String &input); + +int main(int argc, char** argv) +{ + // Parse arguments + CommandLineParser parser(argc, argv, keys); + parser.about("Use this script to run the PyTorch implementation of " + "An End-to-End Trainable Neural Network for Image-based SequenceRecognition and Its Application to Scene Text Recognition " + "(https://arxiv.org/abs/1507.05717)"); + if (argc == 1 || parser.has("help")) + { + parser.printMessage(); + return 0; + } + + String modelPath = parser.get("modelPath"); + String vocPath = parser.get("vocabularyPath"); + int imreadRGB = parser.get("RGBInput"); + + if (!parser.check()) + { + parser.printErrors(); + return 1; + } + + // Load the network + CV_Assert(!modelPath.empty()); + TextRecognitionModel recognizer(modelPath); + + // Load vocabulary + CV_Assert(!vocPath.empty()); + std::ifstream vocFile; + vocFile.open(samples::findFile(vocPath)); + CV_Assert(vocFile.is_open()); + String vocLine; + std::vector vocabulary; + while (std::getline(vocFile, vocLine)) { + vocabulary.push_back(vocLine); + } + recognizer.setVocabulary(vocabulary); + recognizer.setDecodeType("CTC-greedy"); + + // Set parameters + double scale = 1.0 / 127.5; + Scalar mean = Scalar(127.5, 127.5, 127.5); + Size inputSize = Size(100, 32); + recognizer.setInputParams(scale, inputSize, mean); + + if (parser.get("evaluate")) + { + // For evaluation + String evalDataPath = parser.get("evalDataPath"); + CV_Assert(!evalDataPath.empty()); + String gtPath = evalDataPath + "/test_gts.txt"; + std::ifstream evalGts; + evalGts.open(gtPath); + CV_Assert(evalGts.is_open()); + + String gtLine; + int cntRight=0, cntAll=0; + TickMeter timer; + timer.reset(); + + while (std::getline(evalGts, gtLine)) { + size_t splitLoc = gtLine.find_first_of(' '); + String imgPath = evalDataPath + '/' + gtLine.substr(0, splitLoc); + String gt = gtLine.substr(splitLoc+1); + + // Inference + Mat frame = imread(samples::findFile(imgPath), imreadRGB); + CV_Assert(!frame.empty()); + timer.start(); + std::string recognitionResult = recognizer.recognize(frame); + timer.stop(); + + if (gt == convertForEval(recognitionResult)) + cntRight++; + + cntAll++; + } + std::cout << "Accuracy(%): " << (double)(cntRight) / (double)(cntAll) << std::endl; + std::cout << "Average Inference Time(ms): " << timer.getTimeMilli() / (double)(cntAll) << std::endl; + } + else + { + // Create a window + static const std::string winName = "Input Cropped Image"; + + // Open an image file + CV_Assert(parser.has("inputImage")); + Mat frame = imread(samples::findFile(parser.get("inputImage")), imreadRGB); + CV_Assert(!frame.empty()); + + // Recognition + std::string recognitionResult = recognizer.recognize(frame); + + imshow(winName, frame); + std::cout << "Predition: '" << recognitionResult << "'" << std::endl; + waitKey(); + } + + return 0; +} + +// Convert the predictions to lower case, and remove other characters. +// Only for Evaluation +String convertForEval(String & input) +{ + String output; + for (uint i = 0; i < input.length(); i++){ + char ch = input[i]; + if ((int)ch >= 97 && (int)ch <= 122) { + output.push_back(ch); + } else if ((int)ch >= 65 && (int)ch <= 90) { + output.push_back((char)(ch + 32)); + } else { + continue; + } + } + + return output; +} diff --git a/samples/dnn/scene_text_spotting.cpp b/samples/dnn/scene_text_spotting.cpp new file mode 100644 index 0000000000..548289d0e9 --- /dev/null +++ b/samples/dnn/scene_text_spotting.cpp @@ -0,0 +1,169 @@ +#include +#include + +#include +#include +#include + +using namespace cv; +using namespace cv::dnn; + +std::string keys = + "{ help h | | Print help message. }" + "{ inputImage i | | Path to an input image. Skip this argument to capture frames from a camera. }" + "{ detModelPath dmp | | Path to a binary .onnx model for detection. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" + "{ recModelPath rmp | | Path to a binary .onnx model for recognition. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" + "{ inputHeight ih |736| image height of the model input. It should be multiple by 32.}" + "{ inputWidth iw |736| image width of the model input. It should be multiple by 32.}" + "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }" + "{ binaryThreshold bt |0.3| Confidence threshold of the binary map. }" + "{ polygonThreshold pt |0.5| Confidence threshold of polygons. }" + "{ maxCandidate max |200| Max candidates of polygons. }" + "{ unclipRatio ratio |2.0| unclip ratio. }" + "{ vocabularyPath vp | alphabet_36.txt | Path to benchmarks for evaluation. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"; + +void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result); +bool sortPts(const Point& p1, const Point& p2); + +int main(int argc, char** argv) +{ + // Parse arguments + CommandLineParser parser(argc, argv, keys); + parser.about("Use this script to run an end-to-end inference sample of textDetectionModel and textRecognitionModel APIs\n" + "Use -h for more information"); + if (argc == 1 || parser.has("help")) + { + parser.printMessage(); + return 0; + } + + float binThresh = parser.get("binaryThreshold"); + float polyThresh = parser.get("polygonThreshold"); + uint maxCandidates = parser.get("maxCandidate"); + String detModelPath = parser.get("detModelPath"); + String recModelPath = parser.get("recModelPath"); + String vocPath = parser.get("vocabularyPath"); + double unclipRatio = parser.get("unclipRatio"); + int height = parser.get("inputHeight"); + int width = parser.get("inputWidth"); + int imreadRGB = parser.get("RGBInput"); + + if (!parser.check()) + { + parser.printErrors(); + return 1; + } + + // Load networks + CV_Assert(!detModelPath.empty()); + TextDetectionModel_DB detector(detModelPath); + detector.setBinaryThreshold(binThresh) + .setPolygonThreshold(polyThresh) + .setUnclipRatio(unclipRatio) + .setMaxCandidates(maxCandidates); + + CV_Assert(!recModelPath.empty()); + TextRecognitionModel recognizer(recModelPath); + + // Load vocabulary + CV_Assert(!vocPath.empty()); + std::ifstream vocFile; + vocFile.open(samples::findFile(vocPath)); + CV_Assert(vocFile.is_open()); + String vocLine; + std::vector vocabulary; + while (std::getline(vocFile, vocLine)) { + vocabulary.push_back(vocLine); + } + recognizer.setVocabulary(vocabulary); + recognizer.setDecodeType("CTC-greedy"); + + // Parameters for Detection + double detScale = 1.0 / 255.0; + Size detInputSize = Size(width, height); + Scalar detMean = Scalar(122.67891434, 116.66876762, 104.00698793); + detector.setInputParams(detScale, detInputSize, detMean); + + // Parameters for Recognition + double recScale = 1.0 / 127.5; + Scalar recMean = Scalar(127.5); + Size recInputSize = Size(100, 32); + recognizer.setInputParams(recScale, recInputSize, recMean); + + // Create a window + static const std::string winName = "Text_Spotting"; + + // Input data + Mat frame = imread(samples::findFile(parser.get("inputImage"))); + std::cout << frame.size << std::endl; + + // Inference + std::vector< std::vector > detResults; + detector.detect(frame, detResults); + + if (detResults.size() > 0) { + // Text Recognition + Mat recInput; + if (!imreadRGB) { + cvtColor(frame, recInput, cv::COLOR_BGR2GRAY); + } else { + recInput = frame; + } + std::vector< std::vector > contours; + for (uint i = 0; i < detResults.size(); i++) + { + const auto& quadrangle = detResults[i]; + CV_CheckEQ(quadrangle.size(), (size_t)4, ""); + + contours.emplace_back(quadrangle); + + std::vector quadrangle_2f; + for (int j = 0; j < 4; j++) + quadrangle_2f.emplace_back(quadrangle[j]); + + // Transform and Crop + Mat cropped; + fourPointsTransform(recInput, &quadrangle_2f[0], cropped); + + std::string recognitionResult = recognizer.recognize(cropped); + std::cout << i << ": '" << recognitionResult << "'" << std::endl; + + putText(frame, recognitionResult, quadrangle[3], FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 0, 255), 2); + } + polylines(frame, contours, true, Scalar(0, 255, 0), 2); + } else { + std::cout << "No Text Detected." << std::endl; + } + imshow(winName, frame); + waitKey(); + + return 0; +} + +void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result) +{ + const Size outputSize = Size(100, 32); + + Point2f targetVertices[4] = { + Point(0, outputSize.height - 1), + Point(0, 0), + Point(outputSize.width - 1, 0), + Point(outputSize.width - 1, outputSize.height - 1) + }; + Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices); + + warpPerspective(frame, result, rotationMatrix, outputSize); + +#if 0 + imshow("roi", result); + waitKey(); +#endif +} + +bool sortPts(const Point& p1, const Point& p2) +{ + return p1.x < p2.x; +} diff --git a/samples/dnn/text_detection.cpp b/samples/dnn/text_detection.cpp index e1314a7de2..76989dcdc2 100644 --- a/samples/dnn/text_detection.cpp +++ b/samples/dnn/text_detection.cpp @@ -2,22 +2,23 @@ Text detection model: https://github.com/argman/EAST Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1 - CRNN Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch - How to convert from pb to onnx: - Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py - - More converted onnx text recognition models can be downloaded directly here: + Text recognition models can be downloaded directly here: Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing - And these models taken from here:https://github.com/clovaai/deep-text-recognition-benchmark + and doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown + How to convert from pb to onnx: + Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py import torch from models.crnn import CRNN - model = CRNN(32, 1, 37, 256) model.load_state_dict(torch.load('crnn.pth')) dummy_input = torch.randn(1, 1, 32, 100) torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True) + + For more information, please refer to doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown and doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown */ +#include +#include #include #include @@ -27,21 +28,20 @@ using namespace cv; using namespace cv::dnn; const char* keys = - "{ help h | | Print help message. }" - "{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}" - "{ model m | | Path to a binary .pb file contains trained detector network.}" - "{ ocr | | Path to a binary .pb or .onnx file contains trained recognition network.}" - "{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }" - "{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }" - "{ thr | 0.5 | Confidence threshold. }" - "{ nms | 0.4 | Non-maximum suppression threshold. }"; - -void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh, - std::vector& detections, std::vector& confidences); - -void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result); - -void decodeText(const Mat& scores, std::string& text); + "{ help h | | Print help message. }" + "{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}" + "{ detModel dmp | | Path to a binary .pb file contains trained detector network.}" + "{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }" + "{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }" + "{ thr | 0.5 | Confidence threshold. }" + "{ nms | 0.4 | Non-maximum suppression threshold. }" + "{ recModel rmp | | Path to a binary .onnx file contains trained CRNN text recognition model. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}" + "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }" + "{ vocabularyPath vp | alphabet_36.txt | Path to benchmarks for evaluation. " + "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"; + +void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result); int main(int argc, char** argv) { @@ -57,10 +57,12 @@ int main(int argc, char** argv) float confThreshold = parser.get("thr"); float nmsThreshold = parser.get("nms"); - int inpWidth = parser.get("width"); - int inpHeight = parser.get("height"); - String modelDecoder = parser.get("model"); - String modelRecognition = parser.get("ocr"); + int width = parser.get("width"); + int height = parser.get("height"); + int imreadRGB = parser.get("RGBInput"); + String detModelPath = parser.get("detModel"); + String recModelPath = parser.get("recModel"); + String vocPath = parser.get("vocabularyPath"); if (!parser.check()) { @@ -68,14 +70,39 @@ int main(int argc, char** argv) return 1; } - CV_Assert(!modelDecoder.empty()); - // Load networks. - Net detector = readNet(modelDecoder); - Net recognizer; - - if (!modelRecognition.empty()) - recognizer = readNet(modelRecognition); + CV_Assert(!detModelPath.empty() && !recModelPath.empty()); + TextDetectionModel_EAST detector(detModelPath); + detector.setConfidenceThreshold(confThreshold) + .setNMSThreshold(nmsThreshold); + + TextRecognitionModel recognizer(recModelPath); + + // Load vocabulary + CV_Assert(!vocPath.empty()); + std::ifstream vocFile; + vocFile.open(samples::findFile(vocPath)); + CV_Assert(vocFile.is_open()); + String vocLine; + std::vector vocabulary; + while (std::getline(vocFile, vocLine)) { + vocabulary.push_back(vocLine); + } + recognizer.setVocabulary(vocabulary); + recognizer.setDecodeType("CTC-greedy"); + + // Parameters for Recognition + double recScale = 1.0 / 127.5; + Scalar recMean = Scalar(127.5, 127.5, 127.5); + Size recInputSize = Size(100, 32); + recognizer.setInputParams(recScale, recInputSize, recMean); + + // Parameters for Detection + double detScale = 1.0; + Size detInputSize = Size(width, height); + Scalar detMean = Scalar(123.68, 116.78, 103.94); + bool swapRB = true; + detector.setInputParams(detScale, detInputSize, detMean, swapRB); // Open a video file or an image file or a camera stream. VideoCapture cap; @@ -83,15 +110,8 @@ int main(int argc, char** argv) CV_Assert(openSuccess); static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector"; - namedWindow(kWinName, WINDOW_NORMAL); - - std::vector outs; - std::vector outNames(2); - outNames[0] = "feature_fusion/Conv_7/Sigmoid"; - outNames[1] = "feature_fusion/concat_3"; - Mat frame, blob; - TickMeter tickMeter; + Mat frame; while (waitKey(1) < 0) { cap >> frame; @@ -101,162 +121,57 @@ int main(int argc, char** argv) break; } - blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false); - detector.setInput(blob); - tickMeter.start(); - detector.forward(outs, outNames); - tickMeter.stop(); + std::cout << frame.size << std::endl; - Mat scores = outs[0]; - Mat geometry = outs[1]; + // Detection + std::vector< std::vector > detResults; + detector.detect(frame, detResults); - // Decode predicted bounding boxes. - std::vector boxes; - std::vector confidences; - decodeBoundingBoxes(scores, geometry, confThreshold, boxes, confidences); - - // Apply non-maximum suppression procedure. - std::vector indices; - NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices); - - Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight); - - // Render text. - for (size_t i = 0; i < indices.size(); ++i) - { - RotatedRect& box = boxes[indices[i]]; - - Point2f vertices[4]; - box.points(vertices); - - for (int j = 0; j < 4; ++j) - { - vertices[j].x *= ratio.x; - vertices[j].y *= ratio.y; + if (detResults.size() > 0) { + // Text Recognition + Mat recInput; + if (!imreadRGB) { + cvtColor(frame, recInput, cv::COLOR_BGR2GRAY); + } else { + recInput = frame; } - - if (!modelRecognition.empty()) + std::vector< std::vector > contours; + for (uint i = 0; i < detResults.size(); i++) { - Mat cropped; - fourPointsTransform(frame, vertices, cropped); + const auto& quadrangle = detResults[i]; + CV_CheckEQ(quadrangle.size(), (size_t)4, ""); - cvtColor(cropped, cropped, cv::COLOR_BGR2GRAY); + contours.emplace_back(quadrangle); - Mat blobCrop = blobFromImage(cropped, 1.0/127.5, Size(), Scalar::all(127.5)); - recognizer.setInput(blobCrop); + std::vector quadrangle_2f; + for (int j = 0; j < 4; j++) + quadrangle_2f.emplace_back(quadrangle[j]); - tickMeter.start(); - Mat result = recognizer.forward(); - tickMeter.stop(); + Mat cropped; + fourPointsTransform(recInput, &quadrangle_2f[0], cropped); - std::string wordRecognized = ""; - decodeText(result, wordRecognized); - putText(frame, wordRecognized, vertices[1], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255)); - } + std::string recognitionResult = recognizer.recognize(cropped); + std::cout << i << ": '" << recognitionResult << "'" << std::endl; - for (int j = 0; j < 4; ++j) - line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1); + putText(frame, recognitionResult, quadrangle[3], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255), 2); + } + polylines(frame, contours, true, Scalar(0, 255, 0), 2); } - - // Put efficiency information. - std::string label = format("Inference time: %.2f ms", tickMeter.getTimeMilli()); - putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0)); - imshow(kWinName, frame); - - tickMeter.reset(); } return 0; } -void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh, - std::vector& detections, std::vector& confidences) -{ - detections.clear(); - CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1); - CV_Assert(geometry.size[0] == 1); CV_Assert(scores.size[1] == 1); CV_Assert(geometry.size[1] == 5); - CV_Assert(scores.size[2] == geometry.size[2]); CV_Assert(scores.size[3] == geometry.size[3]); - - const int height = scores.size[2]; - const int width = scores.size[3]; - for (int y = 0; y < height; ++y) - { - const float* scoresData = scores.ptr(0, 0, y); - const float* x0_data = geometry.ptr(0, 0, y); - const float* x1_data = geometry.ptr(0, 1, y); - const float* x2_data = geometry.ptr(0, 2, y); - const float* x3_data = geometry.ptr(0, 3, y); - const float* anglesData = geometry.ptr(0, 4, y); - for (int x = 0; x < width; ++x) - { - float score = scoresData[x]; - if (score < scoreThresh) - continue; - - // Decode a prediction. - // Multiple by 4 because feature maps are 4 time less than input image. - float offsetX = x * 4.0f, offsetY = y * 4.0f; - float angle = anglesData[x]; - float cosA = std::cos(angle); - float sinA = std::sin(angle); - float h = x0_data[x] + x2_data[x]; - float w = x1_data[x] + x3_data[x]; - - Point2f offset(offsetX + cosA * x1_data[x] + sinA * x2_data[x], - offsetY - sinA * x1_data[x] + cosA * x2_data[x]); - Point2f p1 = Point2f(-sinA * h, -cosA * h) + offset; - Point2f p3 = Point2f(-cosA * w, sinA * w) + offset; - RotatedRect r(0.5f * (p1 + p3), Size2f(w, h), -angle * 180.0f / (float)CV_PI); - detections.push_back(r); - confidences.push_back(score); - } - } -} - -void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result) +void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result) { const Size outputSize = Size(100, 32); - Point2f targetVertices[4] = {Point(0, outputSize.height - 1), - Point(0, 0), Point(outputSize.width - 1, 0), - Point(outputSize.width - 1, outputSize.height - 1), - }; + Point2f targetVertices[4] = { + Point(0, outputSize.height - 1), + Point(0, 0), Point(outputSize.width - 1, 0), + Point(outputSize.width - 1, outputSize.height - 1) + }; Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices); warpPerspective(frame, result, rotationMatrix, outputSize); } - -void decodeText(const Mat& scores, std::string& text) -{ - static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"; - Mat scoresMat = scores.reshape(1, scores.size[0]); - - std::vector elements; - elements.reserve(scores.size[0]); - - for (int rowIndex = 0; rowIndex < scoresMat.rows; ++rowIndex) - { - Point p; - minMaxLoc(scoresMat.row(rowIndex), 0, 0, 0, &p); - if (p.x > 0 && static_cast(p.x) <= alphabet.size()) - { - elements.push_back(alphabet[p.x - 1]); - } - else - { - elements.push_back('-'); - } - } - - if (elements.size() > 0 && elements[0] != '-') - text += elements[0]; - - for (size_t elementIndex = 1; elementIndex < elements.size(); ++elementIndex) - { - if (elementIndex > 0 && elements[elementIndex] != '-' && - elements[elementIndex - 1] != elements[elementIndex]) - { - text += elements[elementIndex]; - } - } -} \ No newline at end of file