diff --git a/modules/dnn/samples/classify_with_googlenet.cpp b/modules/dnn/samples/classify_with_googlenet.cpp index c39263204..72c1db174 100644 --- a/modules/dnn/samples/classify_with_googlenet.cpp +++ b/modules/dnn/samples/classify_with_googlenet.cpp @@ -1,70 +1,156 @@ -#include +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2013, OpenCV Foundation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + #include #include #include +using namespace cv; +using namespace cv::dnn; + #include #include #include -using namespace cv; -using namespace cv::dnn; +#include +using namespace std; +/* It contains class number and probability of this class */ typedef std::pair ClassProb; -ClassProb getMaxClass(Blob &probBlob, int sampleNum = 0) +/* Find best class for the blob (i. e. class with maximal probability) */ +ClassProb getMaxClass(dnn::Blob &probBlob) { - int numClasses = (int)probBlob.total(1); - Mat probMat(1, numClasses, CV_32F, probBlob.ptrf(sampleNum)); + Mat probMat = probBlob.matRefConst().reshape(1, 1); + double classProb; + Point classNumber; + minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber); - double prob; - Point probLoc; - minMaxLoc(probMat, NULL, &prob, NULL, &probLoc); - - return std::make_pair(probLoc.x, prob); + return std::make_pair(classNumber.x, classProb); } -std::vector CLASSES_NAMES; - -void initClassesNames() +std::vector readClassNames(const char *filename = "synset_words.txt") { - std::ifstream fp("synset_words.txt"); - CV_Assert(fp.is_open()); + std::vector classNames; + + std::ifstream fp(filename); + if (!fp.is_open()) + { + std::cerr << "File with classes labels not found" << std::endl; + std::cerr << "Check it: " << filename << std::endl; + exit(-1); + } std::string name; while (!fp.eof()) { std::getline(fp, name); if (name.length()) - CLASSES_NAMES.push_back( name.substr(name.find(' ')+1) ); + classNames.push_back( name.substr(name.find(' ')+1) ); } fp.close(); + return classNames; } +/* Create batch from the image */ +dnn::Blob makeInputBlob(const String &imagefile) +{ + Mat img = imread(imagefile); + if (img.empty()) + { + std::cerr << "Can't read image from file:" << std::endl; + std::cerr << imagefile << std::endl; + exit(-1); + } + + cvtColor(img, img, COLOR_BGR2RGB); + resize(img, img, Size(227, 227)); + + return dnn::Blob(img); //construct 4-dim Blob (i. e. batch) +} int main(int argc, char **argv) { - Net net; + /* Initialize network */ + dnn::Net net; { - Ptr importer = createCaffeImporter("bvlc_googlenet.prototxt", "bvlc_googlenet.caffemodel"); + String modelTxt = "bvlc_googlenet.prototxt"; + String modelBin = "bvlc_googlenet.caffemodel"; + + Ptr importer; //Try to import Caffe GoogleNet model + try + { + importer = dnn::createCaffeImporter(modelTxt, modelBin); + } + catch(const cv::Exception &er) //importer can throw errors, we will catch them + { + std::cerr << er.msg << std::endl; + importer = Ptr(); //NULL + } + + if (!importer) + { + std::cerr << "Can't load network by using the following files: " << std::endl; + std::cerr << "prototxt: " << modelTxt << std::endl; + std::cerr << "caffemodel: " << modelBin << std::endl; + std::cerr << "Please, check them." << std::endl; + std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl; + std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl; + exit(-1); + } + importer->populateNet(net); } - String filename = (argc > 1) ? argv[1] : "space_shuttle.jpg"; + std::vector classNames = readClassNames(); - Mat img = imread(filename); - CV_Assert(!img.empty()); - cvtColor(img, img, COLOR_BGR2RGB); - resize(img, img, Size(227, 227)); - Blob imgBlob(img); + String filename = (argc > 1) ? argv[1] : "space_shuttle.jpg"; - net.setBlob(".data", imgBlob); - net.forward(); + Blob inputBlob = makeInputBlob(filename); //make batch + net.setBlob(".data", inputBlob); //set this blob to the network input + net.forward(); //compute output - Blob prob = net.getBlob("prob"); - ClassProb bc = getMaxClass(prob); + dnn::Blob prob = net.getBlob("prob"); //gather output of prob layer + ClassProb bc = getMaxClass(prob); //find best class - initClassesNames(); - std::string className = (bc.first < (int)CLASSES_NAMES.size()) ? CLASSES_NAMES[bc.first] : "unnamed"; + String className = classNames.at(bc.first); std::cout << "Best class:"; std::cout << " #" << bc.first;