|
|
|
@ -3,8 +3,10 @@ |
|
|
|
|
// of this distribution and at http://opencv.org/license.html.
|
|
|
|
|
|
|
|
|
|
#include "precomp.hpp" |
|
|
|
|
#include "math_utils.hpp" |
|
|
|
|
#include <algorithm> |
|
|
|
|
#include <utility> |
|
|
|
|
#include <unordered_map> |
|
|
|
|
#include <iterator> |
|
|
|
|
|
|
|
|
|
#include <opencv2/imgproc.hpp> |
|
|
|
@ -552,6 +554,9 @@ struct TextRecognitionModel_Impl : public Model::Impl |
|
|
|
|
// Name of the decoding strategy; decode() compares it against "CTC-greedy" and "CTC-prefix-beam-search".
std::string decodeType; |
|
|
|
|
// Output alphabet: a decoded token index t (t > 0) maps to vocabulary.at(t - 1); index 0 is the CTC blank.
std::vector<std::string> vocabulary; |
|
|
|
|
|
|
|
|
|
// Upper bound on the number of prefixes kept after each time step of CTC prefix beam search.
int beamSize = 10; |
|
|
|
|
// Number of best-scoring vocabulary entries examined per time step (passed to TopK); values <= 0 disable pruning.
int vocPruneSize = 0; |
|
|
|
|
|
|
|
|
|
TextRecognitionModel_Impl() |
|
|
|
|
{ |
|
|
|
|
CV_TRACE_FUNCTION(); |
|
|
|
@ -575,6 +580,13 @@ struct TextRecognitionModel_Impl : public Model::Impl |
|
|
|
|
decodeType = type; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
inline |
|
|
|
|
void setDecodeOptsCTCPrefixBeamSearch(int beam, int vocPrune) |
|
|
|
|
{ |
|
|
|
|
beamSize = beam; |
|
|
|
|
vocPruneSize = vocPrune; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual |
|
|
|
|
std::string decode(const Mat& prediction) |
|
|
|
|
{ |
|
|
|
@ -586,53 +598,213 @@ struct TextRecognitionModel_Impl : public Model::Impl |
|
|
|
|
CV_Error(Error::StsBadArg, "TextRecognitionModel: vocabulary is not specified"); |
|
|
|
|
|
|
|
|
|
std::string decodeSeq; |
|
|
|
|
if (decodeType == "CTC-greedy") |
|
|
|
|
if (decodeType == "CTC-greedy") { |
|
|
|
|
decodeSeq = ctcGreedyDecode(prediction); |
|
|
|
|
} else if (decodeType == "CTC-prefix-beam-search") { |
|
|
|
|
decodeSeq = ctcPrefixBeamSearchDecode(prediction); |
|
|
|
|
} else if (decodeType.length() == 0) { |
|
|
|
|
CV_Error(Error::StsBadArg, "Please set decodeType"); |
|
|
|
|
} else { |
|
|
|
|
CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str())); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return decodeSeq; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual |
|
|
|
|
std::string ctcGreedyDecode(const Mat& prediction) |
|
|
|
|
{ |
|
|
|
|
std::string decodeSeq; |
|
|
|
|
CV_CheckEQ(prediction.dims, 3, ""); |
|
|
|
|
CV_CheckType(prediction.type(), CV_32FC1, ""); |
|
|
|
|
const int vocLength = (int)(vocabulary.size()); |
|
|
|
|
CV_CheckLE(prediction.size[1], vocLength, ""); |
|
|
|
|
bool ctcFlag = true; |
|
|
|
|
int lastLoc = 0; |
|
|
|
|
for (int i = 0; i < prediction.size[0]; i++) |
|
|
|
|
{ |
|
|
|
|
CV_CheckEQ(prediction.dims, 3, ""); |
|
|
|
|
CV_CheckType(prediction.type(), CV_32FC1, ""); |
|
|
|
|
const int vocLength = (int)(vocabulary.size()); |
|
|
|
|
CV_CheckLE(prediction.size[1], vocLength, ""); |
|
|
|
|
bool ctcFlag = true; |
|
|
|
|
int lastLoc = 0; |
|
|
|
|
for (int i = 0; i < prediction.size[0]; i++) |
|
|
|
|
const float* pred = prediction.ptr<float>(i); |
|
|
|
|
int maxLoc = 0; |
|
|
|
|
float maxScore = pred[0]; |
|
|
|
|
for (int j = 1; j < vocLength + 1; j++) |
|
|
|
|
{ |
|
|
|
|
const float* pred = prediction.ptr<float>(i); |
|
|
|
|
int maxLoc = 0; |
|
|
|
|
float maxScore = pred[0]; |
|
|
|
|
for (int j = 1; j < vocLength + 1; j++) |
|
|
|
|
float score = pred[j]; |
|
|
|
|
if (maxScore < score) |
|
|
|
|
{ |
|
|
|
|
float score = pred[j]; |
|
|
|
|
if (maxScore < score) |
|
|
|
|
{ |
|
|
|
|
maxScore = score; |
|
|
|
|
maxLoc = j; |
|
|
|
|
} |
|
|
|
|
maxScore = score; |
|
|
|
|
maxLoc = j; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (maxLoc > 0) |
|
|
|
|
{ |
|
|
|
|
std::string currentChar = vocabulary.at(maxLoc - 1); |
|
|
|
|
if (maxLoc != lastLoc || ctcFlag) |
|
|
|
|
{ |
|
|
|
|
lastLoc = maxLoc; |
|
|
|
|
decodeSeq += currentChar; |
|
|
|
|
ctcFlag = false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
if (maxLoc > 0) |
|
|
|
|
{ |
|
|
|
|
std::string currentChar = vocabulary.at(maxLoc - 1); |
|
|
|
|
if (maxLoc != lastLoc || ctcFlag) |
|
|
|
|
{ |
|
|
|
|
ctcFlag = true; |
|
|
|
|
lastLoc = maxLoc; |
|
|
|
|
decodeSeq += currentChar; |
|
|
|
|
ctcFlag = false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} else if (decodeType.length() == 0) { |
|
|
|
|
CV_Error(Error::StsBadArg, "Please set decodeType"); |
|
|
|
|
} else { |
|
|
|
|
CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str())); |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
ctcFlag = true; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return decodeSeq; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
struct PrefixScore |
|
|
|
|
{ |
|
|
|
|
// blank ending score
|
|
|
|
|
float pB; |
|
|
|
|
// none blank ending score
|
|
|
|
|
float pNB; |
|
|
|
|
|
|
|
|
|
PrefixScore() : pB(kNegativeInfinity), pNB(kNegativeInfinity) |
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
PrefixScore(float pB, float pNB) : pB(pB), pNB(pNB) |
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
// Hash functor for token-index prefixes, used as the hasher of an unordered_map keyed by std::vector<int>.
struct PrefixHash
{
    // BKDR hash: fold every token into the accumulator with multiplier 131.
    size_t operator()(const std::vector<int>& prefix) const
    {
        const unsigned int multiplier = 131;
        size_t acc = 0;
        for (int token : prefix)
        {
            acc = acc * multiplier + token;
        }
        return acc;
    }
};
|
|
|
|
|
|
|
|
|
/**
 * Selects the k largest entries of a score array.
 *
 * @param predictions  pointer to `length` readable floats.
 * @param length       number of entries in `predictions`.
 * @param k            how many entries to keep; k <= 0 means "no pruning".
 * @return (score, index) pairs. With k <= 0 all entries are returned in index
 *         order; otherwise min(k, length) entries are returned in heap order
 *         (not sorted).
 *
 * k is clamped to `length`, so a prune size larger than the array never reads
 * past the end of `predictions`.
 */
static
std::vector<std::pair<float, int>> TopK(
        const float* predictions, int length, int k)
{
    using Scored = std::pair<float, int>;
    std::vector<Scored> results;

    // No prune.
    if (k <= 0)
    {
        if (length > 0)
            results.reserve(static_cast<size_t>(length));
        for (int i = 0; i < length; ++i)
        {
            results.emplace_back(predictions[i], i);
        }
        return results;
    }

    // Never seed the heap with more entries than actually exist.
    const int keep = std::min(k, length);
    if (keep > 0)
        results.reserve(static_cast<size_t>(keep));
    for (int i = 0; i < keep; ++i)
    {
        results.emplace_back(predictions[i], i);
    }

    // Min-heap on (score, index): front() is the weakest candidate kept so far.
    const std::greater<Scored> weakestFirst{};
    std::make_heap(results.begin(), results.end(), weakestFirst);

    for (int i = keep; i < length; ++i)
    {
        if (predictions[i] > results.front().first)
        {
            // Replace the current weakest candidate with the better one.
            std::pop_heap(results.begin(), results.end(), weakestFirst);
            results.pop_back();
            results.emplace_back(predictions[i], i);
            std::push_heap(results.begin(), results.end(), weakestFirst);
        }
    }
    return results;
}
|
|
|
|
|
|
|
|
|
static inline |
|
|
|
|
bool PrefixScoreCompare( |
|
|
|
|
const std::pair<std::vector<int>, PrefixScore>& a, |
|
|
|
|
const std::pair<std::vector<int>, PrefixScore>& b) |
|
|
|
|
{ |
|
|
|
|
float probA = LogAdd(a.second.pB, a.second.pNB); |
|
|
|
|
float probB = LogAdd(b.second.pB, b.second.pNB); |
|
|
|
|
return probA > probB; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual |
|
|
|
|
std::string ctcPrefixBeamSearchDecode(const Mat& prediction) { |
|
|
|
|
// CTC prefix beam seach decode.
|
|
|
|
|
// For more detail, refer to:
|
|
|
|
|
// https://distill.pub/2017/ctc/#inference
|
|
|
|
|
// https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0i
|
|
|
|
|
using Beam = std::vector<std::pair<std::vector<int>, PrefixScore>>; |
|
|
|
|
using BeamInDict = std::unordered_map<std::vector<int>, PrefixScore, PrefixHash>; |
|
|
|
|
|
|
|
|
|
CV_CheckType(prediction.type(), CV_32FC1, ""); |
|
|
|
|
CV_CheckEQ(prediction.dims, 3, ""); |
|
|
|
|
CV_CheckEQ(prediction.size[1], 1, ""); |
|
|
|
|
CV_CheckEQ(prediction.size[2], (int)vocabulary.size() + 1, ""); // Length add 1 for ctc blank
|
|
|
|
|
|
|
|
|
|
std::string decodeSeq; |
|
|
|
|
Beam beam = {std::make_pair(std::vector<int>(), PrefixScore(0.0, kNegativeInfinity))}; |
|
|
|
|
for (int i = 0; i < prediction.size[0]; i++) |
|
|
|
|
{ |
|
|
|
|
// Loop over time
|
|
|
|
|
BeamInDict nextBeam; |
|
|
|
|
const float* pred = prediction.ptr<float>(i); |
|
|
|
|
std::vector<std::pair<float, int>> topkPreds = |
|
|
|
|
TopK(pred, vocabulary.size() + 1, vocPruneSize); |
|
|
|
|
for (const auto& each : topkPreds) |
|
|
|
|
{ |
|
|
|
|
// Loop over vocabulary
|
|
|
|
|
float prob = each.first; |
|
|
|
|
int token = each.second; |
|
|
|
|
for (const auto& it : beam) |
|
|
|
|
{ |
|
|
|
|
const std::vector<int>& prefix = it.first; |
|
|
|
|
const PrefixScore& prefixScore = it.second; |
|
|
|
|
if (token == 0) // 0 stands for ctc blank
|
|
|
|
|
{ |
|
|
|
|
PrefixScore& nextScore = nextBeam[prefix]; |
|
|
|
|
nextScore.pB = LogAdd(nextScore.pB, |
|
|
|
|
LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob)); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
std::vector<int> nPrefix(prefix); |
|
|
|
|
nPrefix.push_back(token); |
|
|
|
|
PrefixScore& nextScore = nextBeam[nPrefix]; |
|
|
|
|
if (prefix.size() > 0 && token == prefix.back()) |
|
|
|
|
{ |
|
|
|
|
nextScore.pNB = LogAdd(nextScore.pNB, prefixScore.pB + prob); |
|
|
|
|
PrefixScore& mScore = nextBeam[prefix]; |
|
|
|
|
mScore.pNB = LogAdd(mScore.pNB, prefixScore.pNB + prob); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
nextScore.pNB = LogAdd(nextScore.pNB, |
|
|
|
|
LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
// Beam prune
|
|
|
|
|
Beam newBeam(nextBeam.begin(), nextBeam.end()); |
|
|
|
|
int newBeamSize = std::min(static_cast<int>(newBeam.size()), beamSize); |
|
|
|
|
std::nth_element(newBeam.begin(), newBeam.begin() + newBeamSize, |
|
|
|
|
newBeam.end(), PrefixScoreCompare); |
|
|
|
|
newBeam.resize(newBeamSize); |
|
|
|
|
std::sort(newBeam.begin(), newBeam.end(), PrefixScoreCompare); |
|
|
|
|
beam = std::move(newBeam); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
CV_Assert(!beam.empty()); |
|
|
|
|
for (int token : beam[0].first) |
|
|
|
|
{ |
|
|
|
|
CV_Check(token, token > 0 && token <= vocabulary.size(), ""); |
|
|
|
|
decodeSeq += vocabulary.at(token - 1); |
|
|
|
|
} |
|
|
|
|
return decodeSeq; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual |
|
|
|
|
std::string recognize(InputArray frame) |
|
|
|
|
{ |
|
|
|
@ -698,6 +870,12 @@ const std::string& TextRecognitionModel::getDecodeType() const |
|
|
|
|
return TextRecognitionModel_Impl::from(impl).decodeType; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Forwards the CTC prefix beam search options to the implementation object
// and returns *this so calls can be chained.
TextRecognitionModel& TextRecognitionModel::setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize)
{
    auto& self = TextRecognitionModel_Impl::from(impl);
    self.setDecodeOptsCTCPrefixBeamSearch(beamSize, vocPruneSize);
    return *this;
}
|
|
|
|
|
|
|
|
|
TextRecognitionModel& TextRecognitionModel::setVocabulary(const std::vector<std::string>& inputVoc) |
|
|
|
|
{ |
|
|
|
|
TextRecognitionModel_Impl::from(impl).setVocabulary(inputVoc); |
|
|
|
|