Implement ctc prefix beam search decode for TextRecognitionModel.

The algorithm is based on Hannun's paper: First-Pass Large Vocabulary
Continuous Speech Recognition using Bi-Directional Recurrent DNNs
pull/20524/head
JIANG Yichen 3 years ago
parent ea068dcc2c
commit 955cf35d5f
  1. 5
      doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown
  2. 13
      modules/dnn/include/opencv2/dnn/dnn.hpp
  3. 83
      modules/dnn/src/math_utils.hpp
  4. 248
      modules/dnn/src/model.cpp
  5. 19
      modules/dnn/test/test_model.cpp

@ -26,6 +26,11 @@ Before recognition, you should `setVocabulary` and `setDecodeType`.
- `T` is the sequence length
- `B` is the batch size (only support `B=1` in inference)
- and `Dim` is the length of vocabulary +1('Blank' of CTC is at the index=0 of Dim).
- "CTC-prefix-beam-search", the output of the text recognition model should be a probability matrix same with "CTC-greedy".
- The algorithm is proposed at Hannun's [paper](https://arxiv.org/abs/1408.2873).
- `setDecodeOptsCTCPrefixBeamSearch` could be used to control the beam size in search step.
- To futher optimize for big vocabulary, a new option `vocPruneSize` is introduced to avoid iterate the whole vocbulary
but only the number of `vocPruneSize` tokens with top probabilty.
@ref cv::dnn::TextRecognitionModel::recognize() is the main function for text recognition.
- The input image should be a cropped text image or an image with `roiRects`

@ -1373,7 +1373,9 @@ public:
/**
* @brief Set the decoding method of translating the network output into string
* @param[in] decodeType The decoding method of translating the network output into string: {'CTC-greedy': greedy decoding for the output of CTC-based methods}
* @param[in] decodeType The decoding method of translating the network output into string, currently supported type:
* - `"CTC-greedy"` greedy decoding for the output of CTC-based methods
* - `"CTC-prefix-beam-search"` Prefix beam search decoding for the output of CTC-based methods
*/
CV_WRAP
TextRecognitionModel& setDecodeType(const std::string& decodeType);
@ -1385,6 +1387,15 @@ public:
CV_WRAP
const std::string& getDecodeType() const;
/**
* @brief Set the decoding method options for `"CTC-prefix-beam-search"` decode usage
* @param[in] beamSize Beam size for search
* @param[in] vocPruneSize Parameter to optimize big vocabulary search,
* only take top @p vocPruneSize tokens in each search step, @p vocPruneSize <= 0 stands for disable this prune.
*/
CV_WRAP
TextRecognitionModel& setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize = 0);
/**
* @brief Set the vocabulary for recognition.
* @param[in] vocabulary the associated vocabulary of the network.

@ -0,0 +1,83 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Code is borrowed from https://github.com/kaldi-asr/kaldi/blob/master/src/base/kaldi-math.h
// base/kaldi-math.h
// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian;
// Jan Silovsky; Saarland University
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef __OPENCV_DNN_MATH_UTILS_HPP__
#define __OPENCV_DNN_MATH_UTILS_HPP__
#ifdef OS_QNX
#include <math.h>
#else
#include <cmath>
#endif
#include <limits>
#ifndef FLT_EPSILON
#define FLT_EPSILON 1.19209290e-7f
#endif
namespace cv { namespace dnn {
const float kNegativeInfinity = -std::numeric_limits<float>::infinity();
const float kMinLogDiffFloat = std::log(FLT_EPSILON);
#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
inline float Log1p(float x) { return log1pf(x); }
#else
inline float Log1p(float x) {
const float cutoff = 1.0e-07;
if (x < cutoff)
return x - 2 * x * x;
else
return Log(1.0 + x);
}
#endif
inline float Exp(float x) { return expf(x); }
inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffFloat) {
float res;
res = x + Log1p(Exp(diff));
return res;
} else {
return x; // return the larger one.
}
}
}} // namespace
#endif // __OPENCV_DNN_MATH_UTILS_HPP__

@ -3,8 +3,10 @@
// of this distribution and at http://opencv.org/license.html.
#include "precomp.hpp"
#include "math_utils.hpp"
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <iterator>
#include <opencv2/imgproc.hpp>
@ -552,6 +554,9 @@ struct TextRecognitionModel_Impl : public Model::Impl
std::string decodeType;
std::vector<std::string> vocabulary;
int beamSize = 10;
int vocPruneSize = 0;
TextRecognitionModel_Impl()
{
CV_TRACE_FUNCTION();
@ -575,6 +580,13 @@ struct TextRecognitionModel_Impl : public Model::Impl
decodeType = type;
}
inline
void setDecodeOptsCTCPrefixBeamSearch(int beam, int vocPrune)
{
beamSize = beam;
vocPruneSize = vocPrune;
}
virtual
std::string decode(const Mat& prediction)
{
@ -586,53 +598,213 @@ struct TextRecognitionModel_Impl : public Model::Impl
CV_Error(Error::StsBadArg, "TextRecognitionModel: vocabulary is not specified");
std::string decodeSeq;
if (decodeType == "CTC-greedy")
if (decodeType == "CTC-greedy") {
decodeSeq = ctcGreedyDecode(prediction);
} else if (decodeType == "CTC-prefix-beam-search") {
decodeSeq = ctcPrefixBeamSearchDecode(prediction);
} else if (decodeType.length() == 0) {
CV_Error(Error::StsBadArg, "Please set decodeType");
} else {
CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str()));
}
return decodeSeq;
}
virtual
std::string ctcGreedyDecode(const Mat& prediction)
{
std::string decodeSeq;
CV_CheckEQ(prediction.dims, 3, "");
CV_CheckType(prediction.type(), CV_32FC1, "");
const int vocLength = (int)(vocabulary.size());
CV_CheckLE(prediction.size[1], vocLength, "");
bool ctcFlag = true;
int lastLoc = 0;
for (int i = 0; i < prediction.size[0]; i++)
{
CV_CheckEQ(prediction.dims, 3, "");
CV_CheckType(prediction.type(), CV_32FC1, "");
const int vocLength = (int)(vocabulary.size());
CV_CheckLE(prediction.size[1], vocLength, "");
bool ctcFlag = true;
int lastLoc = 0;
for (int i = 0; i < prediction.size[0]; i++)
const float* pred = prediction.ptr<float>(i);
int maxLoc = 0;
float maxScore = pred[0];
for (int j = 1; j < vocLength + 1; j++)
{
const float* pred = prediction.ptr<float>(i);
int maxLoc = 0;
float maxScore = pred[0];
for (int j = 1; j < vocLength + 1; j++)
float score = pred[j];
if (maxScore < score)
{
float score = pred[j];
if (maxScore < score)
{
maxScore = score;
maxLoc = j;
}
maxScore = score;
maxLoc = j;
}
}
if (maxLoc > 0)
{
std::string currentChar = vocabulary.at(maxLoc - 1);
if (maxLoc != lastLoc || ctcFlag)
{
lastLoc = maxLoc;
decodeSeq += currentChar;
ctcFlag = false;
}
}
else
if (maxLoc > 0)
{
std::string currentChar = vocabulary.at(maxLoc - 1);
if (maxLoc != lastLoc || ctcFlag)
{
ctcFlag = true;
lastLoc = maxLoc;
decodeSeq += currentChar;
ctcFlag = false;
}
}
} else if (decodeType.length() == 0) {
CV_Error(Error::StsBadArg, "Please set decodeType");
} else {
CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str()));
else
{
ctcFlag = true;
}
}
return decodeSeq;
}
struct PrefixScore
{
// blank ending score
float pB;
// none blank ending score
float pNB;
PrefixScore() : pB(kNegativeInfinity), pNB(kNegativeInfinity)
{
}
PrefixScore(float pB, float pNB) : pB(pB), pNB(pNB)
{
}
};
struct PrefixHash
{
size_t operator()(const std::vector<int>& prefix) const
{
// BKDR hash
unsigned int seed = 131;
size_t hash = 0;
for (size_t i = 0; i < prefix.size(); i++)
{
hash = hash * seed + prefix[i];
}
return hash;
}
};
static
std::vector<std::pair<float, int>> TopK(
const float* predictions, int length, int k)
{
std::vector<std::pair<float, int>> results;
// No prune.
if (k <= 0)
{
for (int i = 0; i < length; ++i)
{
results.emplace_back(predictions[i], i);
}
return results;
}
for (int i = 0; i < k; ++i)
{
results.emplace_back(predictions[i], i);
}
std::make_heap(results.begin(), results.end(), std::greater<std::pair<float, int>>{});
for (int i = k; i < length; ++i)
{
if (predictions[i] > results.front().first)
{
std::pop_heap(results.begin(), results.end(), std::greater<std::pair<float, int>>{});
results.pop_back();
results.emplace_back(predictions[i], i);
std::push_heap(results.begin(), results.end(), std::greater<std::pair<float, int>>{});
}
}
return results;
}
static inline
bool PrefixScoreCompare(
const std::pair<std::vector<int>, PrefixScore>& a,
const std::pair<std::vector<int>, PrefixScore>& b)
{
float probA = LogAdd(a.second.pB, a.second.pNB);
float probB = LogAdd(b.second.pB, b.second.pNB);
return probA > probB;
}
virtual
std::string ctcPrefixBeamSearchDecode(const Mat& prediction) {
// CTC prefix beam seach decode.
// For more detail, refer to:
// https://distill.pub/2017/ctc/#inference
// https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0i
using Beam = std::vector<std::pair<std::vector<int>, PrefixScore>>;
using BeamInDict = std::unordered_map<std::vector<int>, PrefixScore, PrefixHash>;
CV_CheckType(prediction.type(), CV_32FC1, "");
CV_CheckEQ(prediction.dims, 3, "");
CV_CheckEQ(prediction.size[1], 1, "");
CV_CheckEQ(prediction.size[2], (int)vocabulary.size() + 1, ""); // Length add 1 for ctc blank
std::string decodeSeq;
Beam beam = {std::make_pair(std::vector<int>(), PrefixScore(0.0, kNegativeInfinity))};
for (int i = 0; i < prediction.size[0]; i++)
{
// Loop over time
BeamInDict nextBeam;
const float* pred = prediction.ptr<float>(i);
std::vector<std::pair<float, int>> topkPreds =
TopK(pred, vocabulary.size() + 1, vocPruneSize);
for (const auto& each : topkPreds)
{
// Loop over vocabulary
float prob = each.first;
int token = each.second;
for (const auto& it : beam)
{
const std::vector<int>& prefix = it.first;
const PrefixScore& prefixScore = it.second;
if (token == 0) // 0 stands for ctc blank
{
PrefixScore& nextScore = nextBeam[prefix];
nextScore.pB = LogAdd(nextScore.pB,
LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob));
continue;
}
std::vector<int> nPrefix(prefix);
nPrefix.push_back(token);
PrefixScore& nextScore = nextBeam[nPrefix];
if (prefix.size() > 0 && token == prefix.back())
{
nextScore.pNB = LogAdd(nextScore.pNB, prefixScore.pB + prob);
PrefixScore& mScore = nextBeam[prefix];
mScore.pNB = LogAdd(mScore.pNB, prefixScore.pNB + prob);
}
else
{
nextScore.pNB = LogAdd(nextScore.pNB,
LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob));
}
}
}
// Beam prune
Beam newBeam(nextBeam.begin(), nextBeam.end());
int newBeamSize = std::min(static_cast<int>(newBeam.size()), beamSize);
std::nth_element(newBeam.begin(), newBeam.begin() + newBeamSize,
newBeam.end(), PrefixScoreCompare);
newBeam.resize(newBeamSize);
std::sort(newBeam.begin(), newBeam.end(), PrefixScoreCompare);
beam = std::move(newBeam);
}
CV_Assert(!beam.empty());
for (int token : beam[0].first)
{
CV_Check(token, token > 0 && token <= vocabulary.size(), "");
decodeSeq += vocabulary.at(token - 1);
}
return decodeSeq;
}
virtual
std::string recognize(InputArray frame)
{
@ -698,6 +870,12 @@ const std::string& TextRecognitionModel::getDecodeType() const
return TextRecognitionModel_Impl::from(impl).decodeType;
}
TextRecognitionModel& TextRecognitionModel::setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize)
{
TextRecognitionModel_Impl::from(impl).setDecodeOptsCTCPrefixBeamSearch(beamSize, vocPruneSize);
return *this;
}
TextRecognitionModel& TextRecognitionModel::setVocabulary(const std::vector<std::string>& inputVoc)
{
TextRecognitionModel_Impl::from(impl).setVocabulary(inputVoc);

@ -615,6 +615,25 @@ TEST_P(Test_Model, TextRecognition)
testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale);
}
TEST_P(Test_Model, TextRecognitionWithCTCPrefixBeamSearch)
{
if (target == DNN_TARGET_OPENCL_FP16)
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
std::string imgPath = _tf("text_rec_test.png");
std::string weightPath = _tf("onnx/models/crnn.onnx", false);
std::string seq = "welcome";
Size size{100, 32};
double scale = 1.0 / 127.5;
Scalar mean = Scalar(127.5);
std::string decodeType = "CTC-prefix-beam-search";
std::vector<std::string> vocabulary = {"0","1","2","3","4","5","6","7","8","9",
"a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"};
testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale);
}
TEST_P(Test_Model, TextDetectionByDB)
{
if (target == DNN_TARGET_OPENCL_FP16)

Loading…
Cancel
Save