Remove cv::dnn::Importer

pull/10347/head
Dmitry Kurtaev 7 years ago
parent d3a124c820
commit 6aabd6cc7a
  1. 12
      modules/dnn/CMakeLists.txt
  2. 2
      modules/dnn/include/opencv2/dnn/all_layers.hpp
  3. 76
      modules/dnn/include/opencv2/dnn/dnn.hpp
  4. 20
      modules/dnn/perf/perf_caffe.cpp
  5. 20
      modules/dnn/perf/perf_net.cpp
  6. 13
      modules/dnn/src/caffe/caffe_importer.cpp
  7. 8
      modules/dnn/src/darknet/darknet_importer.cpp
  8. 2
      modules/dnn/src/dnn.cpp
  9. 16
      modules/dnn/src/tensorflow/tf_importer.cpp
  10. 2
      modules/dnn/src/torch/THDiskFile.cpp
  11. 2
      modules/dnn/src/torch/THFile.cpp
  12. 2
      modules/dnn/src/torch/THFile.h
  13. 3
      modules/dnn/src/torch/THGeneral.cpp
  14. 41
      modules/dnn/src/torch/torch_importer.cpp
  15. 9
      modules/dnn/test/test_torch_importer.cpp

@ -97,15 +97,3 @@ if(BUILD_PERF_TESTS)
endif()
endif()
endif()
# ----------------------------------------------------------------------------
# Torch7 importer of blobs and models, produced by Torch.nn module
# ----------------------------------------------------------------------------
OCV_OPTION(${the_module}_BUILD_TORCH_IMPORTER "Build Torch model importer" ON)
if(${the_module}_BUILD_TORCH_IMPORTER)
message(STATUS "Torch importer has been enabled. To run the tests you have to install Torch "
"('th' executable should be available) "
"and generate testdata using opencv_extra/testdata/dnn/generate_torch_models.py script.")
add_definitions(-DENABLE_TORCH_IMPORTER=1)
ocv_warnings_disable(CMAKE_CXX_FLAGS /wd4702 /wd4127 /wd4267) #supress warnings in original torch files
endif()

@ -58,7 +58,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
You can use both API, but factory API is less convinient for native C++ programming and basically designed for use inside importers (see @ref readNetFromCaffe(), @ref readNetFromTorch(), @ref readNetFromTensorflow()).
Bult-in layers partially reproduce functionality of corresponding Caffe and Torch7 layers.
In partuclar, the following layers and Caffe @ref Importer were tested to reproduce <a href="http://caffe.berkeleyvision.org/tutorial/layers.html">Caffe</a> functionality:
In partuclar, the following layers and Caffe importer were tested to reproduce <a href="http://caffe.berkeleyvision.org/tutorial/layers.html">Caffe</a> functionality:
- Convolution
- Deconvolution
- Pooling

@ -426,15 +426,6 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
void forward(std::vector<std::vector<Mat> >& outputBlobs,
const std::vector<String>& outBlobNames);
//TODO:
/** @brief Optimized forward.
* @warning Not implemented yet.
* @details Makes forward only those layers which weren't changed after previous forward().
*/
void forwardOpt(LayerId toLayer);
/** @overload */
void forwardOpt(const std::vector<LayerId> &toLayers);
/**
* @brief Compile Halide layers.
* @param[in] scheduler Path to YAML file with scheduling directives.
@ -609,38 +600,18 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
Ptr<Impl> impl;
};
/**
* @deprecated Deprecated as external interface. Will be for internal needs only.
* @brief Small interface class for loading trained serialized models of different dnn-frameworks. */
class CV_EXPORTS_W Importer : public Algorithm
{
public:
/** @brief Adds loaded layers into the @p net and sets connections between them. */
CV_DEPRECATED CV_WRAP virtual void populateNet(Net net) = 0;
virtual ~Importer();
};
/** @brief Reads a network model stored in <a href="https://pjreddie.com/darknet/">Darknet</a> model files.
* @param cfgFile path to the .cfg file with text description of the network architecture.
* @param darknetModel path to the .weights file with learned network.
* @returns Network object that ready to do forward, throw an exception in failure cases.
* @details This is shortcut consisting from DarknetImporter and Net::populateNet calls.
* @returns Net object.
*/
CV_EXPORTS_W Net readNetFromDarknet(const String &cfgFile, const String &darknetModel = String());
/**
* @deprecated Use @ref readNetFromCaffe instead.
* @brief Creates the importer of <a href="http://caffe.berkeleyvision.org">Caffe</a> framework network.
* @param prototxt path to the .prototxt file with text description of the network architecture.
* @param caffeModel path to the .caffemodel file with learned network.
* @returns Pointer to the created importer, NULL in failure cases.
*/
CV_DEPRECATED CV_EXPORTS_W Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel = String());
/** @brief Reads a network model stored in Caffe model files.
* @details This is shortcut consisting from createCaffeImporter and Net::populateNet calls.
/** @brief Reads a network model stored in <a href="http://caffe.berkeleyvision.org">Caffe</a> framework's format.
* @param prototxt path to the .prototxt file with text description of the network architecture.
* @param caffeModel path to the .caffemodel file with learned network.
* @returns Net object.
*/
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
@ -651,16 +622,21 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
* @param lenProto length of bufferProto
* @param bufferModel buffer containing the content of the .caffemodel file
* @param lenModel length of bufferModel
* @returns Net object.
*/
CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel = NULL, size_t lenModel = 0);
/** @brief Reads a network model stored in Tensorflow model file.
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
/** @brief Reads a network model stored in <a href="https://www.tensorflow.org/">TensorFlow</a> framework's format.
* @param model path to the .pb file with binary protobuf description of the network architecture
* @param config path to the .pbtxt file that contains text graph definition in protobuf format.
* Resulting Net object is built by text graph using weights from a binary one that
* let us make it more flexible.
* @returns Net object.
*/
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
/** @brief Reads a network model stored in Tensorflow model in memory.
/** @brief Reads a network model stored in <a href="https://www.tensorflow.org/">TensorFlow</a> framework's format.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferModel buffer containing the content of the pb file
@ -671,27 +647,11 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
const char *bufferConfig = NULL, size_t lenConfig = 0);
/** @brief Reads a network model stored in Torch model file.
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
*/
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
/**
* @deprecated Use @ref readNetFromTensorflow instead.
* @brief Creates the importer of <a href="http://www.tensorflow.org">TensorFlow</a> framework network.
* @param model path to the .pb file with binary protobuf description of the network architecture.
* @returns Pointer to the created importer, NULL in failure cases.
*/
CV_DEPRECATED CV_EXPORTS_W Ptr<Importer> createTensorflowImporter(const String &model);
/**
* @deprecated Use @ref readNetFromTorch instead.
* @brief Creates the importer of <a href="http://torch.ch">Torch7</a> framework network.
* @param filename path to the file, dumped from Torch by using torch.save() function.
* @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
* @param model path to the file, dumped from Torch by using torch.save() function.
* @param isBinary specifies whether the network was serialized in ascii mode or binary.
* @returns Pointer to the created importer, NULL in failure cases.
*
* @warning Torch7 importer is experimental now, you need explicitly set CMake `opencv_dnn_BUILD_TORCH_IMPORTER` flag to compile its.
* @returns Net object.
*
* @note Ascii mode of Torch serializer is more preferable, because binary mode extensively use `long` type of C language,
* which has various bit-length on different systems.
@ -712,10 +672,10 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
*/
CV_DEPRECATED CV_EXPORTS_W Ptr<Importer> createTorchImporter(const String &filename, bool isBinary = true);
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
/** @brief Loads blob which was serialized as torch.Tensor object of Torch7 framework.
* @warning This function has the same limitations as createTorchImporter().
* @warning This function has the same limitations as readNetFromTorch().
*/
CV_EXPORTS_W Mat readTorchBlob(const String &filename, bool isBinary = true);
/** @brief Creates 4-dimensional blob from image. Optionally resizes and crops @p image from center,

@ -68,18 +68,18 @@ static caffe::Net<float>* initNet(std::string proto, std::string weights)
return net;
}
PERF_TEST(GoogLeNet_caffe, CaffePerfTest)
PERF_TEST(AlexNet_caffe, CaffePerfTest)
{
caffe::Net<float>* net = initNet("dnn/bvlc_googlenet.prototxt",
"dnn/bvlc_googlenet.caffemodel");
caffe::Net<float>* net = initNet("dnn/bvlc_alexnet.prototxt",
"dnn/bvlc_alexnet.caffemodel");
TEST_CYCLE() net->Forward();
SANITY_CHECK_NOTHING();
}
PERF_TEST(AlexNet_caffe, CaffePerfTest)
PERF_TEST(GoogLeNet_caffe, CaffePerfTest)
{
caffe::Net<float>* net = initNet("dnn/bvlc_alexnet.prototxt",
"dnn/bvlc_alexnet.caffemodel");
caffe::Net<float>* net = initNet("dnn/bvlc_googlenet.prototxt",
"dnn/bvlc_googlenet.caffemodel");
TEST_CYCLE() net->Forward();
SANITY_CHECK_NOTHING();
}
@ -100,6 +100,14 @@ PERF_TEST(SqueezeNet_v1_1_caffe, CaffePerfTest)
SANITY_CHECK_NOTHING();
}
PERF_TEST(MobileNet_SSD, CaffePerfTest)
{
caffe::Net<float>* net = initNet("dnn/MobileNetSSD_deploy.prototxt",
"dnn/MobileNetSSD_deploy.caffemodel");
TEST_CYCLE() net->Forward();
SANITY_CHECK_NOTHING();
}
} // namespace cvtest
#endif // HAVE_CAFFE

@ -70,7 +70,7 @@ public:
}
else if (framework == "tensorflow")
{
net = cv::dnn::readNetFromTensorflow(weights);
net = cv::dnn::readNetFromTensorflow(weights, proto);
}
else
CV_Error(Error::StsNotImplemented, "Unknown framework " + framework);
@ -148,6 +148,24 @@ PERF_TEST_P_(DNNTestNetwork, SSD)
Mat(cv::Size(300, 300), CV_32FC3), "detection_out", "caffe");
}
PERF_TEST_P_(DNNTestNetwork, OpenFace)
{
processNet("dnn/openface_nn4.small2.v1.t7", "", "",
Mat(cv::Size(96, 96), CV_32FC3), "", "torch");
}
PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
{
processNet("dnn/MobileNetSSD_deploy.caffemodel", "dnn/MobileNetSSD_deploy.prototxt", "",
Mat(cv::Size(300, 300), CV_32FC3), "detection_out", "caffe");
}
PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_TensorFlow)
{
processNet("dnn/ssd_mobilenet_v1_coco.pb", "ssd_mobilenet_v1_coco.pbtxt", "",
Mat(cv::Size(300, 300), CV_32FC3), "", "tensorflow");
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, DNNTestNetwork,
testing::Combine(
::testing::Values(TEST_DNN_BACKEND),

@ -75,7 +75,7 @@ static cv::String toString(const T &v)
return ss.str();
}
class CaffeImporter : public Importer
class CaffeImporter
{
caffe::NetParameter net;
caffe::NetParameter netBinary;
@ -390,21 +390,10 @@ public:
dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
}
~CaffeImporter()
{
}
};
}
Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel)
{
return Ptr<Importer>(new CaffeImporter(prototxt.c_str(), caffeModel.c_str()));
}
Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)
{
CaffeImporter caffeImporter(prototxt.c_str(), caffeModel.c_str());

@ -58,7 +58,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
namespace
{
class DarknetImporter : public Importer
class DarknetImporter
{
darknet::NetParameter net;
@ -173,12 +173,6 @@ public:
dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
}
~DarknetImporter()
{
}
};
}

@ -2294,8 +2294,6 @@ int64 Net::getPerfProfile(std::vector<double>& timings)
//////////////////////////////////////////////////////////////////////////
Importer::~Importer() {}
Layer::Layer() { preferableTarget = DNN_TARGET_CPU; }
Layer::Layer(const LayerParams &params)

@ -446,14 +446,13 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
net.mutable_node()->DeleteSubrange(layer_index, 1);
}
class TFImporter : public Importer {
class TFImporter {
public:
TFImporter(const char *model, const char *config = NULL);
TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig = NULL, size_t lenConfig = 0);
void populateNet(Net dstNet);
~TFImporter() {}
private:
void kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob);
@ -1315,19 +1314,6 @@ void TFImporter::populateNet(Net dstNet)
} // namespace
Ptr<Importer> createTensorflowImporter(const String &model)
{
return Ptr<Importer>(new TFImporter(model.c_str()));
}
#else //HAVE_PROTOBUF
Ptr<Importer> createTensorflowImporter(const String&)
{
CV_Error(cv::Error::StsNotImplemented, "libprotobuf required to import data from TensorFlow models");
return Ptr<Importer>();
}
#endif //HAVE_PROTOBUF
Net readNetFromTensorflow(const String &model, const String &config)

@ -1,5 +1,4 @@
#include "../precomp.hpp"
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#include "THGeneral.h"
#include "THDiskFile.h"
#include "THFilePrivate.h"
@ -517,4 +516,3 @@ THFile *THDiskFile_new(const std::string &name, const char *mode, int isQuiet)
}
}
#endif

@ -1,5 +1,4 @@
#include "../precomp.hpp"
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#include "THFile.h"
#include "THFilePrivate.h"
@ -119,4 +118,3 @@ IMPLEMENT_THFILE_SCALAR(Float, float)
IMPLEMENT_THFILE_SCALAR(Double, double)
} // namespace
#endif

@ -2,7 +2,6 @@
#define TH_FILE_INC
//#include "THStorage.h"
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#include "opencv2/core/hal/interface.h"
#include "THGeneral.h"
@ -51,5 +50,4 @@ TH_API long THFile_position(THFile *self);
TH_API void THFile_close(THFile *self);
TH_API void THFile_free(THFile *self);
} // namespace
#endif //defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#endif //TH_FILE_INC

@ -1,5 +1,4 @@
#include "../precomp.hpp"
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#if defined(TH_DISABLE_HEAP_TRACKING)
#elif (defined(__unix) || defined(_WIN32))
@ -9,5 +8,3 @@
#endif
#include "THGeneral.h"
#endif

@ -47,15 +47,12 @@
#include <iostream>
#include <fstream>
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
#include "THDiskFile.h"
#endif
namespace cv {
namespace dnn {
CV__DNN_EXPERIMENTAL_NS_BEGIN
#if defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
using namespace TH;
//#ifdef NDEBUG
@ -95,7 +92,7 @@ static inline bool endsWith(const String &str, const char *substr)
return str.rfind(substr) == str.length() - strlen(substr);
}
struct TorchImporter : public ::cv::dnn::Importer
struct TorchImporter
{
typedef std::map<String, std::pair<int, Mat> > TensorsMap;
Net net;
@ -1191,19 +1188,13 @@ struct TorchImporter : public ::cv::dnn::Importer
}
};
Ptr<Importer> createTorchImporter(const String &filename, bool isBinary)
{
return Ptr<Importer>(new TorchImporter(filename, isBinary));
}
Mat readTorchBlob(const String &filename, bool isBinary)
{
Ptr<TorchImporter> importer(new TorchImporter(filename, isBinary));
importer->readObject();
CV_Assert(importer->tensors.size() == 1);
TorchImporter importer(filename, isBinary);
importer.readObject();
CV_Assert(importer.tensors.size() == 1);
return importer->tensors.begin()->second;
return importer.tensors.begin()->second;
}
Net readNetFromTorch(const String &model, bool isBinary)
@ -1216,27 +1207,5 @@ Net readNetFromTorch(const String &model, bool isBinary)
return net;
}
#else
Ptr<Importer> createTorchImporter(const String&, bool)
{
CV_Error(Error::StsNotImplemented, "Torch importer is disabled in current build");
return Ptr<Importer>();
}
Mat readTorchBlob(const String&, bool)
{
CV_Error(Error::StsNotImplemented, "Torch importer is disabled in current build");
return Mat();
}
Net readNetFromTorch(const String &model, bool isBinary)
{
CV_Error(Error::StsNotImplemented, "Torch importer is disabled in current build");
return Net();
}
#endif //defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace

@ -39,8 +39,6 @@
//
//M*/
#ifdef ENABLE_TORCH_IMPORTER
#include "test_precomp.hpp"
#include "npy_blob.hpp"
#include <opencv2/dnn/shape_utils.hpp>
@ -316,9 +314,8 @@ OCL_TEST(Torch_Importer, ENet_accuracy)
Net net;
{
const string model = findDataFile("dnn/Enet-model-best.net", false);
Ptr<Importer> importer = createTorchImporter(model, true);
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
net = readNetFromTorch(model, true);
ASSERT_TRUE(!net.empty());
}
net.setPreferableBackend(DNN_BACKEND_DEFAULT);
@ -421,5 +418,3 @@ OCL_TEST(Torch_Importer, FastNeuralStyle_accuracy)
}
}
#endif

Loading…
Cancel
Save