Repository for OpenCV's extra modules
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

176 lines
4.9 KiB

/*
Sample of using OpenCV dnn module with Torch ENet model.
*/
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;
#include <fstream>
#include <iostream>
#include <cstdlib>
#include <sstream>
using namespace std;
// Command-line options understood by this sample, written in the
// cv::CommandLineParser key syntax: "{name [alias] | default value | help text}".
// An empty default (the "||" form) leaves the option empty when it is not given.
const String keys =
"{help h || Sample app for loading ENet Torch model. "
"The model and class names list can be downloaded here: "
"https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
"{model m || path to Torch .net model file (model_best.net) }"
"{image i || path to image file }"
"{i_blob | .0 | input blob name }"
"{o_blob || output blob name }"
"{c_names c || path to file with classnames for channels (categories.txt) }"
"{result r || path to save output blob (optional, binary format, NCHW order) }"
;
// Reads class names from a text file, one name per non-empty line.
// Terminates the process with exit(-1) if the file cannot be opened.
std::vector<String> readClassNames(const char *filename);
int main(int argc, char **argv)
{
cv::CommandLineParser parser(argc, argv, keys);
if (parser.has("help"))
{
parser.printMessage();
return 0;
}
String modelFile = parser.get<String>("model");
String imageFile = parser.get<String>("image");
String inBlobName = parser.get<String>("i_blob");
String outBlobName = parser.get<String>("o_blob");
if (!parser.check())
{
parser.printErrors();
return 0;
}
String classNamesFile = parser.get<String>("c_names");
String resultFile = parser.get<String>("result");
//! [Create the importer of TensorFlow model]
Ptr<dnn::Importer> importer;
try //Try to import TensorFlow AlexNet model
{
importer = dnn::createTorchImporter(modelFile);
}
catch (const cv::Exception &err) //Importer can throw errors, we will catch them
{
std::cerr << err.msg << std::endl;
}
//! [Create the importer of Caffe model]
if (!importer)
{
std::cerr << "Can't load network by using the mode file: " << std::endl;
std::cerr << modelFile << std::endl;
exit(-1);
}
//! [Initialize network]
dnn::Net net;
importer->populateNet(net);
importer.release(); //We don't need importer anymore
//! [Initialize network]
//! [Prepare blob]
Mat img = imread(imageFile);
if (img.empty())
{
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
exit(-1);
}
cv::Size inputImgSize = cv::Size(512, 512);
if (inputImgSize != img.size())
resize(img, img, inputImgSize); //Resize image to input size
if(img.channels() == 3)
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
img.convertTo(img, CV_32F, 1/255.0);
dnn::Blob inputBlob = dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob image batch
//! [Prepare blob]
//! [Set input blob]
net.setBlob(inBlobName, inputBlob); //set the network input
//! [Set input blob]
cv::TickMeter tm;
tm.start();
//! [Make forward pass]
net.forward(); //compute output
//! [Make forward pass]
tm.stop();
//! [Gather output]
dnn::Blob prob = net.getBlob(outBlobName); //gather output of "prob" layer
Mat& result = prob.matRef();
BlobShape shape = prob.shape();
if (!resultFile.empty()) {
CV_Assert(result.isContinuous());
ofstream fout(resultFile.c_str(), ios::out | ios::binary);
fout.write((char*)result.data, result.total() * sizeof(float));
fout.close();
}
std::cout << "Output blob shape " << shape << std::endl;
std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl;
std::vector<String> classNames;
if(!classNamesFile.empty()) {
classNames = readClassNames(classNamesFile.c_str());
if (classNames.size() > prob.channels())
classNames = std::vector<String>(classNames.begin() + classNames.size() - prob.channels(),
classNames.end());
}
for(int i_c = 0; i_c < prob.channels(); i_c++) {
ostringstream convert;
convert << "Channel #" << i_c;
if(classNames.size() == prob.channels())
convert << ": " << classNames[i_c];
imshow(convert.str().c_str(), prob.getPlane(0, i_c));
}
waitKey();
return 0;
} //main
/**
 * Reads class names from a text file, one name per line.
 *
 * Empty lines are skipped. A trailing carriage return (left over from CRLF
 * line endings) is removed from each line before it is stored.
 *
 * @param filename path of the text file to read
 * @return the non-empty lines of the file, in file order
 *
 * If the file cannot be opened, an error is printed to stderr and the
 * process terminates with exit code -1.
 */
std::vector<String> readClassNames(const char *filename)
{
    std::vector<String> classNames;

    std::ifstream fp(filename);
    if (!fp.is_open())
    {
        std::cerr << "File with classes labels not found: " << filename << std::endl;
        std::exit(-1);
    }

    std::string name;
    // Test the result of getline() itself rather than eof(): a stream that
    // fails without ever reaching end-of-file would otherwise loop forever.
    while (std::getline(fp, name))
    {
        if (!name.empty() && name[name.size() - 1] == '\r')
            name.erase(name.size() - 1);

        if (!name.empty())
            classNames.push_back(name);
    }

    // The stream is closed by its destructor.
    return classNames;
}