mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
6.3 KiB
179 lines
6.3 KiB
/**M/////////////////////////////////////////////////////////////////////////////////////// |
|
// |
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. |
|
// |
|
// By downloading, copying, installing or using the software you agree to this license. |
|
// If you do not agree to this license, do not download, install, |
|
// copy or use the software. |
|
// |
|
// |
|
// License Agreement |
|
// For Open Source Computer Vision Library |
|
// |
|
// Copyright (C) 2013, OpenCV Foundation, all rights reserved. |
|
// Third party copyrights are property of their respective owners. |
|
// |
|
// Redistribution and use in source and binary forms, with or without modification, |
|
// are permitted provided that the following conditions are met: |
|
// |
|
// * Redistribution's of source code must retain the above copyright notice, |
|
// this list of conditions and the following disclaimer. |
|
// |
|
// * Redistribution's in binary form must reproduce the above copyright notice, |
|
// this list of conditions and the following disclaimer in the documentation |
|
// and/or other materials provided with the distribution. |
|
// |
|
// * The name of the copyright holders may not be used to endorse or promote products |
|
// derived from this software without specific prior written permission. |
|
// |
|
// This software is provided by the copyright holders and contributors "as is" and |
|
// any express or implied warranties, including, but not limited to, the implied |
|
// warranties of merchantability and fitness for a particular purpose are disclaimed. |
|
// In no event shall the Intel Corporation or contributors be liable for any direct, |
|
// indirect, incidental, special, exemplary, or consequential damages |
|
// (including, but not limited to, procurement of substitute goods or services; |
|
// loss of use, data, or profits; or business interruption) however caused |
|
// and on any theory of liability, whether in contract, strict liability, |
|
// or tort (including negligence or otherwise) arising in any way out of |
|
// the use of this software, even if advised of the possibility of such damage. |
|
// |
|
//M*/ |
|
#include <opencv2/dnn.hpp> |
|
#include <opencv2/imgproc.hpp> |
|
#include <opencv2/highgui.hpp> |
|
#include <opencv2/core/utils/trace.hpp> |
|
using namespace cv; |
|
using namespace cv::dnn; |
|
|
|
#include <fstream> |
|
#include <iostream> |
|
#include <cstdlib> |
|
using namespace std; |
|
|
|
/* Find best class for the blob (i. e. class with maximal probability) */ |
|
static void getMaxClass(const Mat &probBlob, int *classId, double *classProb) |
|
{ |
|
Mat probMat = probBlob.reshape(1, 1); //reshape the blob to 1x1000 matrix |
|
Point classNumber; |
|
|
|
minMaxLoc(probMat, NULL, classProb, NULL, &classNumber); |
|
*classId = classNumber.x; |
|
} |
|
|
|
static std::vector<String> readClassNames(const char *filename = "synset_words.txt") |
|
{ |
|
std::vector<String> classNames; |
|
|
|
std::ifstream fp(filename); |
|
if (!fp.is_open()) |
|
{ |
|
std::cerr << "File with classes labels not found: " << filename << std::endl; |
|
exit(-1); |
|
} |
|
|
|
std::string name; |
|
while (!fp.eof()) |
|
{ |
|
std::getline(fp, name); |
|
if (name.length()) |
|
classNames.push_back( name.substr(name.find(' ')+1) ); |
|
} |
|
|
|
fp.close(); |
|
return classNames; |
|
} |
|
|
|
// Option table for cv::CommandLineParser: each "{ key | default | help }"
// entry declares one command-line option that main() reads by key.
// `static` gives it internal linkage (no clash with a `params` symbol in
// another translation unit); the `const` pointer prevents reassignment.
static const char* const params
    = "{ help | false | Sample app for loading googlenet model }"
      "{ proto | bvlc_googlenet.prototxt | model configuration }"
      "{ model | bvlc_googlenet.caffemodel | model weights }"
      "{ image | space_shuttle.jpg | path to image file }"
      "{ opencl | false | enable OpenCL }"
      ;
|
|
|
int main(int argc, char **argv) |
|
{ |
|
CV_TRACE_FUNCTION(); |
|
|
|
CommandLineParser parser(argc, argv, params); |
|
|
|
if (parser.get<bool>("help")) |
|
{ |
|
parser.printMessage(); |
|
return 0; |
|
} |
|
|
|
String modelTxt = parser.get<string>("proto"); |
|
String modelBin = parser.get<string>("model"); |
|
String imageFile = parser.get<String>("image"); |
|
|
|
Net net; |
|
try { |
|
//! [Read and initialize network] |
|
net = dnn::readNetFromCaffe(modelTxt, modelBin); |
|
//! [Read and initialize network] |
|
} |
|
catch (cv::Exception& e) { |
|
std::cerr << "Exception: " << e.what() << std::endl; |
|
//! [Check that network was read successfully] |
|
if (net.empty()) |
|
{ |
|
std::cerr << "Can't load network by using the following files: " << std::endl; |
|
std::cerr << "prototxt: " << modelTxt << std::endl; |
|
std::cerr << "caffemodel: " << modelBin << std::endl; |
|
std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl; |
|
std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl; |
|
exit(-1); |
|
} |
|
//! [Check that network was read successfully] |
|
} |
|
|
|
if (parser.get<bool>("opencl")) |
|
{ |
|
net.setPreferableTarget(DNN_TARGET_OPENCL); |
|
} |
|
|
|
//! [Prepare blob] |
|
Mat img = imread(imageFile); |
|
if (img.empty()) |
|
{ |
|
std::cerr << "Can't read image from the file: " << imageFile << std::endl; |
|
exit(-1); |
|
} |
|
|
|
//GoogLeNet accepts only 224x224 BGR-images |
|
Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224), |
|
Scalar(104, 117, 123), false); //Convert Mat to batch of images |
|
//! [Prepare blob] |
|
net.setInput(inputBlob, "data"); //set the network input |
|
Mat prob = net.forward("prob"); //compute output |
|
|
|
cv::TickMeter t; |
|
for (int i = 0; i < 10; i++) |
|
{ |
|
CV_TRACE_REGION("forward"); |
|
//! [Set input blob] |
|
net.setInput(inputBlob, "data"); //set the network input |
|
//! [Set input blob] |
|
t.start(); |
|
//! [Make forward pass] |
|
prob = net.forward("prob"); //compute output |
|
//! [Make forward pass] |
|
t.stop(); |
|
} |
|
|
|
//! [Gather output] |
|
int classId; |
|
double classProb; |
|
getMaxClass(prob, &classId, &classProb);//find the best class |
|
//! [Gather output] |
|
|
|
//! [Print results] |
|
std::vector<String> classNames = readClassNames(); |
|
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl; |
|
std::cout << "Probability: " << classProb * 100 << "%" << std::endl; |
|
//! [Print results] |
|
std::cout << "Time: " << (double)t.getTimeMilli() / t.getCounter() << " ms (average from " << t.getCounter() << " iterations)" << std::endl; |
|
|
|
return 0; |
|
} //main
|
|
|