diff --git a/modules/dnn/samples/alexnet.cpp b/modules/dnn/samples/alexnet.cpp index 45bdba1bc..cc5f89f9c 100644 --- a/modules/dnn/samples/alexnet.cpp +++ b/modules/dnn/samples/alexnet.cpp @@ -2,9 +2,45 @@ #include #include #include +#include +#include using namespace cv; using namespace cv::dnn; +typedef std::pair ClassProb; + +ClassProb getMaxClass(Blob &probBlob, int sampleNum = 0) +{ + int numClasses = (int)probBlob.total(1); + Mat probMat(1, numClasses, CV_32F, probBlob.ptr(sampleNum)); + + double prob; + Point probLoc; + minMaxLoc(probMat, NULL, &prob, NULL, &probLoc); + + return std::make_pair(probLoc.x, prob); +} + +std::vector CLASES_NAMES; + +void initClassesNames() +{ + std::ifstream fp("ILSVRC2012_synsets.txt"); + CV_Assert(fp.is_open()); + + std::string name; + while (!fp.eof()) + { + std::getline(fp, name); + CLASES_NAMES.push_back(name); + } + + CV_Assert(CLASES_NAMES.size() == 1000); + + fp.close(); +} + + int main(void) { Net net; @@ -13,16 +49,28 @@ int main(void) importer->populateNet(net); } - Mat img = imread("alexnet.png"); + Mat img = imread("zebra.jpg"); CV_Assert(!img.empty()); - img.convertTo(img, CV_32F, 1.0 / 255); + cvtColor(img, img, COLOR_BGR2RGB); + img.convertTo(img, CV_32F); + subtract(img, cv::mean(img), img); Blob imgBlob(img); net.setBlob("data", imgBlob); - net.forward(); - Blob res = net.getBlob("prob"); + Blob probBlob = net.getBlob("prob"); + ClassProb bc = getMaxClass(probBlob); + + initClassesNames(); + std::string className = (bc.first < (int)CLASES_NAMES.size()) ? CLASES_NAMES[bc.first] : "unnamed"; + + std::cout << "Best class:"; + std::cout << " #" << bc.first; + std::cout << " (from " << probBlob.total(1) << ")"; + std::cout << " \"" + className << "\""; + std::cout << std::endl; + std::cout << "Prob: " << bc.second * 100 << "%" << std::endl; return 0; } \ No newline at end of file