Merge branch 'master' of github.com:ludv1x/opencv_contrib

pull/265/head
Vitaliy Lyudvichenko 10 years ago
commit dd15521860
  1. modules/dnn/CMakeLists.txt (11)
  2. modules/dnn/include/opencv2/dnn/blob.hpp (142)
  3. modules/dnn/include/opencv2/dnn/blob.inl.hpp (284)
  4. modules/dnn/include/opencv2/dnn/dict.hpp (244)
  5. modules/dnn/include/opencv2/dnn/dnn.hpp (117)
  6. modules/dnn/include/opencv2/dnn/dnn.inl.hpp (122)
  7. modules/dnn/samples/classify_with_googlenet.cpp (25)
  8. modules/dnn/samples/space_shuttle.jpg (BIN)
  9. modules/dnn/samples/synset_words.txt (1000)
  10. modules/dnn/scripts/download_model.py (79)
  11. modules/dnn/scripts/test_models.json (7)
  12. modules/dnn/src/blob.cpp (180)
  13. modules/dnn/src/caffe/glog_emulator.hpp (4)
  14. modules/dnn/src/caffe_importer.cpp (224)
  15. modules/dnn/src/dnn.cpp (487)
  16. modules/dnn/src/layers/blank_layer.cpp (10)
  17. modules/dnn/src/layers/concat_layer.cpp (73)
  18. modules/dnn/src/layers/convolution_layer.cpp (199)
  19. modules/dnn/src/layers/deconvolution_layer.cpp (0)
  20. modules/dnn/src/layers/elementwise_layers.cpp (91)
  21. modules/dnn/src/layers/fully_connected_layer.cpp (46)
  22. modules/dnn/src/layers/im2col.hpp (74)
  23. modules/dnn/src/layers/lrn_layer.cpp (36)
  24. modules/dnn/src/layers/pooling_layer.cpp (50)
  25. modules/dnn/src/layers/reshape_layer.cpp (137)
  26. modules/dnn/src/layers/slice_layer.cpp (103)
  27. modules/dnn/src/layers/softmax_layer.cpp (18)
  28. modules/dnn/src/layers/split_layer.cpp (58)
  29. modules/dnn/test/cnpy.cpp (247)
  30. modules/dnn/test/cnpy.h (247)
  31. modules/dnn/test/npy_blob.hpp (24)
  32. modules/dnn/test/test_alexnet.cpp (40)
  33. modules/dnn/test/test_caffe_importer.cpp (44)
  34. modules/dnn/test/test_common.hpp (24)
  35. modules/dnn/test/test_googlenet.cpp (41)
  36. modules/dnn/test/test_layers.cpp (103)
  37. modules/dnn/test/test_precomp.hpp (1)
  38. modules/dnn/testdata/dnn/bvlc_alexnet.prototxt (13)
  39. modules/dnn/testdata/dnn/bvlc_googlenet.prototxt (0)
  40. modules/dnn/testdata/dnn/googlenet_0.jpg (BIN)
  41. modules/dnn/testdata/dnn/googlenet_1.jpg (BIN)
  42. modules/dnn/testdata/dnn/googlenet_prob.npy (BIN)
  43. modules/dnn/testdata/dnn/layers/blob.npy (BIN)
  44. modules/dnn/testdata/dnn/layers/lrn_channels.prototxt (21)
  45. modules/dnn/testdata/dnn/layers/lrn_channels.prototxt.caffe.npy (BIN)
  46. modules/dnn/testdata/dnn/layers/lrn_spatial.prototxt (22)
  47. modules/dnn/testdata/dnn/layers/lrn_spatial.prototxt.caffe.npy (BIN)
  48. modules/dnn/testdata/dnn/layers/reshape_and_slice_routines.prototxt (77)
  49. modules/dnn/testdata/dnn/layers/softmax.prototxt (15)
  50. modules/dnn/testdata/dnn/layers/softmax.prototxt.caffe.npy (BIN)
  51. modules/dnn/testdata/dnn/sign_50.ppm (4)

@ -33,9 +33,18 @@ ocv_module_include_directories(include src/caffe ${PROTOBUF_INCLUDE_DIR})
ocv_create_module(${PROTOBUF_LIBRARIES})
ocv_add_samples()
ocv_add_accuracy_tests()
ocv_add_perf_tests()
ocv_add_samples()
OCV_OPTION(${the_module}_DOWNLOAD_CAFFE_MODELS "Use GoogLeNet Caffe model for testing" ON IF BUILD_TESTS AND PYTHON2_EXECUTABLE AND DEFINED ENV{OPENCV_TEST_DATA_PATH})
if(BUILD_TESTS AND ${the_module}_DOWNLOAD_CAFFE_MODELS)
add_custom_command( TARGET opencv_test_${name} POST_BUILD
COMMAND ${PYTHON2_EXECUTABLE} download_model.py test_models.json
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts )
else()
add_definitions(-DDISABLE_CAFFE_MODEL_TESTS=1)
endif()
else()#build as standalone module (for development purposes)
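When the models are not downloaded, the build defines DISABLE_CAFFE_MODEL_TESTS so that model-dependent accuracy tests can be compiled out. A minimal sketch of such a guard in a test source, assuming the gtest-style macros used by OpenCV tests (the test name and body are illustrative, not part of this commit):

#if !defined(DISABLE_CAFFE_MODEL_TESTS)
TEST(Reproducibility_GoogLeNet, Accuracy)
{
    // load bvlc_googlenet.prototxt / .caffemodel from OPENCV_TEST_DATA_PATH and run the net
}
#endif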

@ -0,0 +1,142 @@
#ifndef __OPENCV_DNN_DNN_BLOB_HPP__
#define __OPENCV_DNN_DNN_BLOB_HPP__
#include <opencv2/core.hpp>
#include <vector>
#include <ostream>
namespace cv
{
namespace dnn
{
struct BlobShape
{
explicit BlobShape(int ndims = 4, int fill = 1);
BlobShape(int num, int cn, int rows, int cols);
BlobShape(int ndims, const int *sizes);
BlobShape(const std::vector<int> &sizes);
template<int n>
BlobShape(const Vec<int, n> &shape);
int dims() const;
int size(int axis) const;
int &size(int axis);
//does the same as size()
int operator[](int axis) const;
int &operator[](int axis);
//same as size(), but the size of non-existing dimensions is treated as 1
int xsize(int axis) const;
ptrdiff_t total();
const int *ptr() const;
bool equal(const BlobShape &other) const;
private:
cv::AutoBuffer<int,4> sz;
};
bool operator== (const BlobShape &l, const BlobShape &r);
//maybe useless
CV_EXPORTS std::ostream &operator<< (std::ostream &stream, const BlobShape &shape);
/** @brief Provides convenient methods for continuous n-dimensional array processing, dedicated to convolutional neural networks.
It is realized as a wrapper over \ref cv::Mat and \ref cv::UMat and will support methods for CPU/GPU switching.
*/
class CV_EXPORTS Blob
{
public:
explicit Blob();
/** @brief constructs a 4-dimensional blob from the input
* @param in 2-dimensional or 3-dimensional single-channel image (or a vector of such images)
* @param dstCn if specified, forces the channel dimension of the output blob to this size
*/
explicit Blob(InputArray in, int dstCn = -1);
void create(const BlobShape &shape, int type = CV_32F);
void fill(InputArray in);
void fill(const BlobShape &shape, int type, void *data, bool deepCopy = true);
Mat& getMatRef();
const Mat& getMatRef() const;
Mat getMat(int n, int cn);
//shape getters
///returns the actual number of blob dimensions
int dims() const;
/** @brief returns size of corresponding dimension (axis)
@param axis dimension index
Python-like indexing is supported, so \p axis can be negative, i.e. -1 refers to the last dimension.
The size of non-existing dimensions is assumed to be 1, so the method always succeeds.
*/
int xsize(int axis) const;
/** @brief returns size of corresponding dimension (axis)
@param axis dimension index
Python-like indexing is supported, so \p axis can be negative, i.e. -1 refers to the last dimension.
@note Unlike ::xsize, if \p axis points to a non-existing dimension then an error is generated.
*/
int size(int axis) const;
/** @brief returns the total number of elements
@param startAxis starting axis (negative indexing can be used)
@param endAxis ending (excluded) axis
@see ::canonicalAxis
*/
size_t total(int startAxis = 0, int endAxis = -1) const;
/** @brief converts axis index to canonical format (where 0 <= axis < ::dims)
*/
int canonicalAxis(int axis) const;
/** @brief returns real shape of the blob
*/
BlobShape shape() const;
bool equalShape(const Blob &other) const;
//shape getters, oriented for 4-dim Blobs processing
int cols() const;
int rows() const;
int channels() const;
int num() const;
Size size2() const;
Vec4i shape4() const;
//CPU data pointer functions
int offset(int n = 0, int cn = 0, int row = 0, int col = 0) const;
uchar *ptrRaw(int n = 0, int cn = 0, int row = 0, int col = 0);
float *ptrf(int n = 0, int cn = 0, int row = 0, int col = 0);
template<typename TFloat>
TFloat *ptr(int n = 0, int cn = 0, int row = 0, int col = 0);
/** @brief shares data with another blob
@returns *this
*/
Blob &shareFrom(const Blob &blob);
/** @brief adjusts the blob shape to the required one (data is reallocated if needed)
@returns *this
*/
Blob &reshape(const BlobShape &shape);
int type() const;
bool isFloat() const;
bool isDouble() const;
private:
const int *sizes() const;
Mat m;
};
}
}
#include "blob.inl.hpp"
#endif
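For orientation, a small usage sketch of the Blob API declared above (illustrative only; it assumes the header is installed as <opencv2/dnn/blob.hpp> and that an image has already been loaded into a cv::Mat):

#include <opencv2/core.hpp>
#include <opencv2/dnn/blob.hpp>
using namespace cv;
using namespace cv::dnn;

void blobExample(const Mat &img)              // img: 2D, possibly multi-channel image
{
    Blob blob(img);                           // constructs a 4D (num x cn x rows x cols) blob, converted to CV_32F
    BlobShape s = blob.shape();               // actual shape of the blob
    int n = blob.num(), c = blob.channels();  // same as s[0] and s[1]
    int h = blob.rows(), w = blob.cols();     // same as s[-2] and s[-1] (negative indexing)
    float *data = blob.ptrf();                // CPU pointer to the first element (CV_32F data)
    size_t perImage = blob.total(1);          // number of elements per sample (axes 1..end)
    (void)n; (void)c; (void)h; (void)w; (void)data; (void)perImage;
}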

@ -0,0 +1,284 @@
#ifndef __OPENCV_DNN_DNN_BLOB_INL_HPP__
#define __OPENCV_DNN_DNN_BLOB_INL_HPP__
#include "blob.hpp"
namespace cv
{
namespace dnn
{
inline BlobShape::BlobShape(int ndims, int fill) : sz( (size_t)std::max(ndims, 1) )
{
for (int i = 0; i < ndims; i++)
sz[i] = fill;
}
inline BlobShape::BlobShape(int ndims, const int *sizes) : sz( (size_t)std::max(ndims, 1) )
{
CV_Assert(ndims > 0);
for (int i = 0; i < ndims; i++)
sz[i] = sizes[i];
}
inline BlobShape::BlobShape(int num, int cn, int rows, int cols) : sz(4)
{
sz[0] = num;
sz[1] = cn;
sz[2] = rows;
sz[3] = cols;
}
inline BlobShape::BlobShape(const std::vector<int> &sizes) : sz( sizes.size() )
{
CV_Assert(sizes.size() > 0);
for (int i = 0; i < (int)sizes.size(); i++)
sz[i] = sizes[i];
}
template<int n>
inline BlobShape::BlobShape(const Vec<int, n> &shape) : sz(n)
{
for (int i = 0; i < n; i++)
sz[i] = shape[i];
}
inline int BlobShape::dims() const
{
return (int)sz.size();
}
inline int BlobShape::xsize(int axis) const
{
if (axis < -dims() || axis >= dims())
return 1;
return sz[(axis < 0) ? axis + dims() : axis];
}
inline int BlobShape::size(int axis) const
{
CV_Assert(-dims() <= axis && axis < dims());
return sz[(axis < 0) ? axis + dims() : axis];
}
inline int &BlobShape::size(int axis)
{
CV_Assert(-dims() <= axis && axis < dims());
return sz[(axis < 0) ? axis + dims() : axis];
}
inline int BlobShape::operator[] (int axis) const
{
CV_Assert(-dims() <= axis && axis < dims());
return sz[(axis < 0) ? axis + dims() : axis];
}
inline int &BlobShape::operator[] (int axis)
{
CV_Assert(-dims() <= axis && axis < dims());
return sz[(axis < 0) ? axis + dims() : axis];
}
inline ptrdiff_t BlobShape::total()
{
CV_Assert(dims() >= 1);
ptrdiff_t res = 1;
for (int i = 0; i < dims(); i++)
res *= sz[i];
return res;
}
inline const int *BlobShape::ptr() const
{
return sz;
}
inline bool BlobShape::equal(const BlobShape &other) const
{
if (this->dims() != other.dims())
return false;
for (int i = 0; i < other.dims(); i++)
{
if (sz[i] != other.sz[i])
return false;
}
return true;
}
inline bool operator== (const BlobShape &l, const BlobShape &r)
{
return l.equal(r);
}
inline int Blob::canonicalAxis(int axis) const
{
CV_Assert(-dims() <= axis && axis < dims());
if (axis < 0)
{
return dims() + axis;
}
return axis;
}
inline int Blob::dims() const
{
return m.dims;
}
inline int Blob::xsize(int axis) const
{
if (axis < -dims() || axis >= dims())
return 1;
return sizes()[(axis < 0) ? axis + dims() : axis];
}
inline int Blob::size(int axis) const
{
CV_Assert(-dims() <= axis && axis < dims());
return sizes()[(axis < 0) ? axis + dims() : axis];
}
inline size_t Blob::total(int startAxis, int endAxis) const
{
if (startAxis < 0)
startAxis += dims();
if (endAxis == -1)
endAxis = dims();
CV_Assert(0 <= startAxis && startAxis <= endAxis && endAxis <= dims());
size_t size = 1; //assume that blob isn't empty
for (int i = startAxis; i < endAxis; i++)
size *= (size_t)sizes()[i];
return size;
}
inline int Blob::offset(int n, int cn, int row, int col) const
{
CV_DbgAssert(0 <= n && n < num() && 0 <= cn && cn < channels() && 0 <= row && row < rows() && 0 <= col && col < cols());
return ((n*channels() + cn)*rows() + row)*cols() + col;
}
inline float *Blob::ptrf(int n, int cn, int row, int col)
{
CV_Assert(type() == CV_32F);
return (float*)m.data + offset(n, cn, row, col);
}
inline uchar *Blob::ptrRaw(int n, int cn, int row, int col)
{
return m.data + m.elemSize() * offset(n, cn, row, col);
}
template<typename TFloat>
inline TFloat* Blob::ptr(int n, int cn, int row, int col)
{
CV_Assert(type() == cv::DataDepth<TFloat>::value);
return (TFloat*) ptrRaw(n, cn, row, col);
}
inline BlobShape Blob::shape() const
{
return BlobShape(dims(), sizes());
}
inline bool Blob::equalShape(const Blob &other) const
{
if (this->dims() != other.dims())
return false;
for (int i = 0; i < dims(); i++)
{
if (this->sizes()[i] != other.sizes()[i])
return false;
}
return true;
}
inline Mat& Blob::getMatRef()
{
return m;
}
inline const Mat& Blob::getMatRef() const
{
return m;
}
inline Mat Blob::getMat(int n, int cn)
{
return Mat(rows(), cols(), m.type(), this->ptrRaw(n, cn));
}
inline int Blob::cols() const
{
return xsize(3);
}
inline int Blob::rows() const
{
return xsize(2);
}
inline int Blob::channels() const
{
return xsize(1);
}
inline int Blob::num() const
{
return xsize(0);
}
inline Size Blob::size2() const
{
return Size(cols(), rows());
}
inline int Blob::type() const
{
return m.depth();
}
inline bool Blob::isFloat() const
{
return (type() == CV_32F);
}
inline bool Blob::isDouble() const
{
return (type() == CV_64F);
}
inline const int * Blob::sizes() const
{
return &m.size[0];
}
inline Blob &Blob::shareFrom(const Blob &blob)
{
this->m = blob.m;
return *this;
}
inline Blob &Blob::reshape(const BlobShape &shape)
{
m = m.reshape(1, shape.dims(), shape.ptr());
return *this;
}
}
}
#endif

@ -1,5 +1,5 @@
#ifndef __OPENCV_DNN_DICT_HPP__
#define __OPENCV_DNN_DICT_HPP__
#ifndef __OPENCV_DNN_DNN_DICT_HPP__
#define __OPENCV_DNN_DNN_DICT_HPP__
#include <opencv2/core.hpp>
#include <map>
@ -11,41 +11,77 @@ namespace dnn
struct DictValue
{
int type;
union
{
int64 i;
double d;
bool b;
String *s;
};
DictValue(const DictValue &r);
DictValue(int p = 0) : type(cv::Param::INT), i(p) {}
DictValue(unsigned p) : type(cv::Param::INT), i(p) {}
DictValue(double p) : type(cv::Param::REAL), d(p) {}
DictValue(bool p) : type(cv::Param::BOOLEAN), b(p) {}
DictValue(const String &p) : type(cv::Param::STRING), s(new String(p)) {}
DictValue(const char *str) : type(cv::Param::STRING), s(new String(str)) {}
DictValue(int p = 0) : type(Param::INT), pi(new AutoBuffer<int64,1>) { (*pi)[0] = p; }
DictValue(unsigned p) : type(Param::INT), pi(new AutoBuffer<int64,1>) { (*pi)[0] = p; }
DictValue(double p) : type(Param::REAL), pd(new AutoBuffer<double,1>) { (*pd)[0] = p; }
DictValue(const String &p) : type(Param::STRING), ps(new AutoBuffer<String,1>) { (*ps)[0] = p; }
template<typename TypeIter>
static DictValue arrayInt(TypeIter begin, int size);
template<typename TypeIter>
static DictValue arrayReal(TypeIter begin, int size);
template<typename TypeIter>
static DictValue arrayString(TypeIter begin, int size);
template<typename T>
T get() const;
T get(int idx = -1) const;
int size() const;
bool isString() const;
bool isInt() const;
bool isString() const;
bool isReal() const;
DictValue &operator=(const DictValue &r);
~DictValue();
private:
protected:
int type;
union
{
AutoBuffer<int64, 1> *pi;
AutoBuffer<double, 1> *pd;
AutoBuffer<String, 1> *ps;
void *p;
};
DictValue(int _type, void *_p) : type(_type), p(_p) {}
void release();
};
template<typename TypeIter>
DictValue DictValue::arrayInt(TypeIter begin, int size)
{
DictValue res(Param::INT, new AutoBuffer<int64, 1>(size));
for (int j = 0; j < size; begin++, j++)
(*res.pi)[j] = *begin;
return res;
}
template<typename TypeIter>
DictValue DictValue::arrayReal(TypeIter begin, int size)
{
DictValue res(Param::REAL, new AutoBuffer<double, 1>(size));
for (int j = 0; j < size; begin++, j++)
(*res.pd)[j] = *begin;
return res;
}
template<typename TypeIter>
DictValue DictValue::arrayString(TypeIter begin, int size)
{
DictValue res(Param::STRING, new AutoBuffer<String, 1>(size));
for (int j = 0; j < size; begin++, j++)
(*res.ps)[j] = *begin;
return res;
}
class CV_EXPORTS Dict
{
//TODO: maybe this mechanism is already implemented somewhere in OpenCV?
typedef std::map<String, DictValue> _Dict;
_Dict dict;
@ -62,13 +98,18 @@ public:
return (i == dict.end()) ? NULL : &i->second;
}
template <typename T>
T get(const String &name) const
const DictValue &get(const String &name) const
{
_Dict::const_iterator i = dict.find(name);
if (i == dict.end())
CV_Error(cv::Error::StsBadArg, "Required argument \"" + name + "\" not found into dictionary");
return i->second.get<T>();
CV_Error(Error::StsBadArg, "Required argument \"" + name + "\" not found into dictionary");
return i->second;
}
template <typename T>
T get(const String &name) const
{
return this->get(name).get<T>();
}
template <typename T>
@ -94,92 +135,89 @@ public:
return value;
}
inline void print()
{
for (_Dict::const_iterator i = dict.begin(); i != dict.end(); i++)
{
std::cout << i->first << std::endl;
}
}
};
template<>
inline DictValue DictValue::get<DictValue>(int idx) const
{
CV_Assert(idx == -1);
return *this;
}
template<>
inline int DictValue::get<int>() const
inline int64 DictValue::get<int64>(int idx) const
{
CV_Assert(type == cv::Param::INT);
return (int)i;
CV_Assert(isInt());
CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
return (*pi)[(idx == -1) ? 0 : idx];
}
template<>
inline unsigned DictValue::get<unsigned>() const
inline int DictValue::get<int>(int idx) const
{
CV_Assert(type == cv::Param::INT);
return (unsigned)i;
return (int)get<int64>(idx);
}
template<>
inline double DictValue::get<double>() const
inline unsigned DictValue::get<unsigned>(int idx) const
{
if (type == cv::Param::REAL)
return d;
else if (type == cv::Param::INT)
return (double)i;
else
{
CV_Assert(type == cv::Param::REAL || type == cv::Param::INT);
return 0;
}
return (unsigned)get<int64>(idx);
}
template<>
inline float DictValue::get<float>() const
inline bool DictValue::get<bool>(int idx) const
{
if (type == cv::Param::FLOAT)
return (float)d;
else if (type == cv::Param::INT)
return (float)i;
else
{
CV_Assert(type == cv::Param::FLOAT || type == cv::Param::INT);
return (float)0;
}
return (get<int64>(idx) != 0);
}
template<>
inline bool DictValue::get<bool>() const
inline double DictValue::get<double>(int idx) const
{
if (type == cv::Param::BOOLEAN)
if (type == Param::REAL)
{
return b;
CV_Assert(idx == -1 && pd->size() == 1 || idx >= 0 && idx < (int)pd->size());
return (*pd)[(idx == -1) ? 0 : idx];
}
else if (type == cv::Param::INT)
else if (type == Param::INT)
{
return i != 0;
CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
return (double)(*pi)[(idx == -1) ? 0 : idx];
}
else
{
CV_Assert(type == cv::Param::BOOLEAN || type == cv::Param::INT);
CV_Assert(isReal());
return 0;
}
}
template<>
inline String DictValue::get<String>() const
inline float DictValue::get<float>(int idx) const
{
CV_Assert(type == cv::Param::STRING);
return *s;
return (float)get<double>(idx);
}
template<>
inline String DictValue::get<String>(int idx) const
{
CV_Assert(isString());
CV_Assert(idx == -1 && ps->size() == 1 || idx >= 0 && idx < (int)ps->size());
return (*ps)[(idx == -1) ? 0 : idx];
}
inline void DictValue::release()
{
if (type == cv::Param::STRING && s != NULL)
switch (type)
{
delete s;
s = NULL;
case Param::INT:
delete pi;
break;
case Param::STRING:
delete ps;
break;
case Param::REAL:
delete pd;
break;
}
}
inline DictValue::~DictValue()
@ -192,33 +230,73 @@ inline DictValue & DictValue::operator=(const DictValue &r)
if (&r == this)
return *this;
if (r.type == Param::INT)
{
AutoBuffer<int64, 1> *tmp = new AutoBuffer<int64, 1>(*r.pi);
release();
//how to copy anonymous union without memcpy?
for (size_t i = 0; i < sizeof(*this); i++)
((uchar*)this)[i] = ((uchar*)&r)[i];
if (r.type == cv::Param::STRING)
pi = tmp;
}
else if (r.type == Param::STRING)
{
s = new String(*r.s);
AutoBuffer<String, 1> *tmp = new AutoBuffer<String, 1>(*r.ps);
release();
ps = tmp;
}
else if (r.type == Param::REAL)
{
AutoBuffer<double, 1> *tmp = new AutoBuffer<double, 1>(*r.pd);
release();
pd = tmp;
}
type = r.type;
return *this;
}
inline DictValue::DictValue(const DictValue &r)
{
*this = r;
type = r.type;
if (r.type == Param::INT)
pi = new AutoBuffer<int64, 1>(*r.pi);
else if (r.type == Param::STRING)
ps = new AutoBuffer<String, 1>(*r.ps);
else if (r.type == Param::REAL)
pd = new AutoBuffer<double, 1>(*r.pd);
}
inline bool DictValue::isString() const
{
return (type == cv::Param::STRING);
return (type == Param::STRING);
}
inline bool DictValue::isInt() const
{
return (type == cv::Param::INT);
return (type == Param::INT);
}
inline bool DictValue::isReal() const
{
return (type == Param::REAL || type == Param::INT);
}
inline int DictValue::size() const
{
switch (type)
{
case Param::INT:
return (int)pi->size();
break;
case Param::STRING:
return (int)ps->size();
break;
case Param::REAL:
return (int)pd->size();
break;
default:
return -1;
}
}
}
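A brief sketch of how the reworked DictValue behaves, using only the methods declared above (the key names are arbitrary examples, not fixed by the API):

#include <opencv2/dnn/dict.hpp>
using namespace cv::dnn;

void dictExample()
{
    Dict params;
    params.set("kernel_size", 3);                 // scalar entry
    int ks = params.get<int>("kernel_size");      // typed scalar access

    int pads[] = {1, 1};
    DictValue pad = DictValue::arrayInt(pads, 2); // array entry built from iterator + size
    int p0 = pad.get<int>(0);                     // per-element access by index
    int n  = pad.size();                          // number of stored elements
    (void)ks; (void)p0; (void)n;
}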

@ -7,62 +7,12 @@
#include <opencv2/core.hpp>
#include <opencv2/dnn/dict.hpp>
#include <opencv2/dnn/blob.hpp>
namespace cv
{
namespace dnn
{
class Layer;
class NetConfiguration;
class Net;
class Blob;
class LayerParams;
//wrapper over cv::Mat and cv::UMat
class CV_EXPORTS Blob
{
public:
explicit Blob();
explicit Blob(InputArray in);
void create(int ndims, const int *sizes, int type = CV_32F);
void create(Vec4i shape, int type = CV_32F);
void create(int num, int cn, int rows, int cols, int type = CV_32F);
void fill(InputArray in);
void fill(int ndims, const int *sizes, int type, void *data, bool deepCopy = true);
Mat& getMatRef();
const Mat& getMatRef() const;
Mat getMat();
Mat getMat(int num, int channel);
//shape getters
int cols() const;
int rows() const;
int channels() const;
int num() const;
Size size2() const;
Vec4i shape() const;
int size(int index) const;
size_t total(int startAxis = 0, int endAxis = -1) const;
uchar *rawPtr(int num = 0, int cn = 0, int row = 0, int col = 0);
template<typename TFloat>
TFloat *ptr(int num = 0, int cn = 0, int row = 0, int col = 0);
int type() const;
bool isFloat() const;
bool isDouble() const;
private:
const int *sizes() const;
int dims() const;
Mat m;
};
class CV_EXPORTS LayerParams : public Dict
{
public:
@ -70,30 +20,11 @@ namespace dnn
std::vector<Blob> learnedBlobs;
};
class CV_EXPORTS LayerRegister
{
public:
typedef Ptr<Layer> (*Constuctor)(LayerParams &params);
static void registerLayer(const String &type, Constuctor constructor);
static void unregisterLayer(const String &type);
static Ptr<Layer> createLayerInstance(const String &type, LayerParams& params);
private:
LayerRegister();
struct Impl;
static Ptr<Impl> impl;
};
//this class allows to build new Layers
//Interface class that allows building new Layers
class CV_EXPORTS Layer
{
public:
//TODO: this field must be declared as public if we want support possibility to change these params in runtime
//learned params of the layer must be stored here so that they can be read externally
std::vector<Blob> learnedParams;
virtual ~Layer();
@ -103,16 +34,13 @@ namespace dnn
virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs) = 0;
virtual int getNumInputs();
virtual int getNumOutputs();
//each input/output can be labeled for easy identification using the "layer_name.output_name" notation
virtual String getInputName(int inputNum);
virtual String getOutputName(int outputNum);
virtual int inputNameToIndex(String inputName);
virtual int outputNameToIndex(String outputName);
};
//containers for String and int
typedef DictValue LayerId;
typedef DictValue BlobId;
class CV_EXPORTS Net
{
@ -125,14 +53,10 @@ namespace dnn
int getLayerId(LayerId layer);
void deleteLayer(LayerId layer);
//each output of each layer can be labeled by unique string label (as in Caffe)
//if label not specified then %layer_name%.%layer_output_id% can be used
void setOutputNames(LayerId layer, const std::vector<String> &outputNames);
void setLayerInputs(const std::vector<String> &outputs, LayerId layer);
void setNetInputs(const std::vector<String> &inputBlobNames);
void connect(BlobId input, BlobId output);
void connect(const std::vector<BlobId> &outputs, const std::vector<BlobId> &inputs);
void connect(String outPin, String inpPin);
void connect(int outLayerId, int outNum, int inLayerId, int inNum);
void forward();
void forward(LayerId toLayer);
@ -143,11 +67,11 @@ namespace dnn
void forwardOpt(LayerId toLayer);
void forwardOpt(const std::vector<LayerId> &toLayers);
void setBlob(BlobId outputName, const Blob &blob);
Blob getBlob(BlobId outputName);
void setBlob(String outputName, const Blob &blob);
Blob getBlob(String outputName);
void setParam(LayerId layer, int numParam, const Blob &blob);
void getParam(LayerId layer, int numParam);
Blob getParam(LayerId layer, int numParam = 0);
private:
@ -164,8 +88,27 @@ namespace dnn
virtual ~Importer();
};
CV_EXPORTS Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel);
CV_EXPORTS Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel = String());
//Layer factory that allows creating instances of registered layers.
class CV_EXPORTS LayerRegister
{
public:
typedef Ptr<Layer>(*Constuctor)(LayerParams &params);
static void registerLayer(const String &type, Constuctor constructor);
static void unregisterLayer(const String &type);
static Ptr<Layer> createLayerInstance(const String &type, LayerParams& params);
private:
LayerRegister();
struct Impl;
static Ptr<Impl> impl;
};
//allows automatic registration of a created layer at module load time
struct _LayerRegisterer
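Taken together, the Net interface above lets a graph be built either by pin aliases ("layer_name.output_name") or by numeric ids. A hypothetical build of a tiny two-layer net, assuming the umbrella header <opencv2/dnn.hpp> and that the layer types shown are registered (types and parameters are placeholders, not taken from this commit):

#include <vector>
#include <opencv2/dnn.hpp>
using namespace cv;
using namespace cv::dnn;

void netExample(const Blob &input)
{
    Net net;
    std::vector<String> inputs(1, "data");
    net.setNetInputs(inputs);                  // names of the network input blobs

    LayerParams fcParams, smParams;            // real code would fill num_output etc.
    int fcId = net.addLayer("fc1", "InnerProduct", fcParams);
    int smId = net.addLayer("prob", "Softmax", smParams);

    net.connect(".data", "fc1");               // empty layer name refers to the input layer
    net.connect(fcId, 0, smId, 0);             // id-based form of the same operation
    net.setBlob(".data", input);
    net.forward();
    Blob out = net.getBlob("prob");            // first output of layer "prob"
    (void)out;
}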

@ -1,5 +1,5 @@
#ifndef __OPENCV_DNN_INL_HPP__
#define __OPENCV_DNN_INL_HPP__
#ifndef __OPENCV_DNN_DNN_INL_HPP__
#define __OPENCV_DNN_DNN_INL_HPP__
#include <opencv2/dnn.hpp>
@ -7,123 +7,7 @@ namespace cv
{
namespace dnn
{
inline Mat& Blob::getMatRef()
{
return m;
}
inline const Mat& Blob::getMatRef() const
{
return m;
}
inline Mat Blob::getMat()
{
return m;
}
inline Mat Blob::getMat(int num, int channel)
{
CV_Assert(0 <= num && num < this->num() && 0 <= channel && channel < this->channels());
return Mat(rows(), cols(), m.type(), this->rawPtr(num, channel));
}
inline int Blob::cols() const
{
CV_DbgAssert(m.dims > 2);
return m.size[m.dims-1];
}
inline int Blob::rows() const
{
CV_DbgAssert(m.dims > 2);
return m.size[m.dims-2];
}
inline Size Blob::size2() const
{
return Size(cols(), rows());
}
inline int Blob::channels() const
{
CV_DbgAssert(m.dims >= 3);
return m.size[m.dims-3];
}
inline int Blob::num() const
{
CV_DbgAssert(m.dims == 4);
return m.size[0];
}
inline Vec4i Blob::shape() const
{
CV_DbgAssert(m.dims == 4);
return Vec4i(m.size.p);
}
inline int Blob::size(int index) const
{
CV_Assert(index >= 0 && index < dims());
return sizes()[index];
}
inline size_t Blob::total(int startAxis, int endAxis) const
{
if (endAxis == -1)
endAxis = dims();
CV_Assert(0 <= startAxis && startAxis <= endAxis && endAxis <= dims());
size_t size = 1; //assume that blob isn't empty
for (int i = startAxis; i < endAxis; i++)
size *= (size_t) sizes()[i];
return size;
}
inline uchar* Blob::rawPtr(int num, int cn, int row, int col)
{
CV_DbgAssert(m.dims == 4);
return m.data + num * m.step[0] + cn * m.step[1] + row * m.step[2] + col * m.step[3];
}
template<typename TFloat>
TFloat *Blob::ptr(int n, int cn, int row, int col)
{
CV_Assert(m.type() == cv::DataType<TFloat>::type);
CV_Assert(0 <= n && n < num() && 0 <= cn && cn < channels() && 0 <= row && row < rows() && 0 <= col && col < cols());
return (TFloat*) rawPtr(n, cn, row, col);
}
inline int Blob::type() const
{
return m.depth();
}
inline bool Blob::isFloat() const
{
return (type() == CV_32F);
}
inline bool Blob::isDouble() const
{
return (type() == CV_32F);
}
inline const int * Blob::sizes() const
{
return &m.size[0];
}
inline int Blob::dims() const
{
return m.dims;
}
//code is absent ... today
}
}

@ -25,50 +25,49 @@ std::vector<String> CLASES_NAMES;
void initClassesNames()
{
std::ifstream fp("ILSVRC2012_synsets.txt");
std::ifstream fp("synset_words.txt");
CV_Assert(fp.is_open());
std::string name;
while (!fp.eof())
{
std::getline(fp, name);
CLASES_NAMES.push_back(name);
if (name.length())
CLASES_NAMES.push_back( name.substr(name.find(' ')+1) );
}
CV_Assert(CLASES_NAMES.size() == 1000);
fp.close();
}
int main(void)
int main(int argc, char **argv)
{
Net net;
{
Ptr<Importer> importer = createCaffeImporter("bvlc_alexnet.prototxt", "bvlc_alexnet.caffemodel");
Ptr<Importer> importer = createCaffeImporter("bvlc_googlenet.prototxt", "bvlc_googlenet.caffemodel");
importer->populateNet(net);
}
Mat img = imread("zebra.jpg");
String filename = (argc > 1) ? argv[1] : "space_shuttle.jpg";
Mat img = imread(filename);
CV_Assert(!img.empty());
cvtColor(img, img, COLOR_BGR2RGB);
img.convertTo(img, CV_32F);
resize(img, img, Size(227, 227));
subtract(img, cv::mean(img), img);
Blob imgBlob(img);
net.setBlob("data", imgBlob);
net.setBlob(".data", imgBlob);
net.forward();
Blob probBlob = net.getBlob("prob");
ClassProb bc = getMaxClass(probBlob);
Blob prob = net.getBlob("prob");
ClassProb bc = getMaxClass(prob);
initClassesNames();
std::string className = (bc.first < (int)CLASES_NAMES.size()) ? CLASES_NAMES[bc.first] : "unnamed";
std::cout << "Best class:";
std::cout << " #" << bc.first;
std::cout << " (from " << probBlob.total(1) << ")";
std::cout << " (from " << prob.total(1) << ")";
std::cout << " \"" + className << "\"";
std::cout << std::endl;
std::cout << "Prob: " << bc.second * 100 << "%" << std::endl;

Binary file not shown (27 KiB image added).

File diff suppressed because it is too large.

@ -0,0 +1,79 @@
#!/usr/bin/env python
import os
import sys
import time
import urllib
import hashlib
import argparse
import json
def reporthook(count, block_size, total_size):
"""
From http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/
"""
global start_time
global prev_duration
if count == 0:
start_time = time.time()
prev_duration = -1
return
duration = max(1, time.time() - start_time)
if int(duration) == int(prev_duration):
return
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
prev_duration = duration
# Function for checking SHA1.
def model_checks_out(filename, sha1):
with open(filename, 'rb') as f:  # binary mode, so the SHA1 is platform-independent
return hashlib.sha1(f.read()).hexdigest() == sha1
def model_download(filename, url, sha1):
# Check if model exists.
if os.path.exists(filename) and model_checks_out(filename, sha1):
print("Model {} already exists.".format(filename))
return
# Download and verify model.
urllib.urlretrieve(url, filename, reporthook)
print model_checks_out(filename, sha1)
if not model_checks_out(filename, sha1):
print("ERROR: model {} did not download correctly!".format(url))
sys.exit(1)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Downloading trained model binaries.")
parser.add_argument("download_list")
args = parser.parse_args()
test_dir = os.environ.get("OPENCV_TEST_DATA_PATH")
if not test_dir:
print "ERROR: OPENCV_TEST_DATA_PATH environment not specified"
sys.exit(1)
try:
with open(args.download_list, 'r') as f:
models_to_download = json.load(f)
except:
print "ERROR: Can't pasrse {}".format(args.download_list)
sys.exit(1)
for model_name in models_to_download:
model = models_to_download[model_name]
dst_dir = os.path.join(test_dir, os.path.dirname(model['file']))
dst_file = os.path.join(test_dir, model['file'])
if not os.path.exists(dst_dir):
print "ERROR: Can't find module testdata path '{}'".format(dst_dir)
sys.exit(1)
print "Downloading model '{}' to {} from {} ...".format(model_name, dst_file, model['url'])
model_download(dst_file, model['url'], model['sha1'])

@ -0,0 +1,7 @@
{
"googlenet": {
"file": "dnn/bvlc_googlenet.caffemodel",
"url": "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel",
"sha1": "405fc5acd08a3bb12de8ee5e23a96bec22f08204"
}
}

@ -0,0 +1,180 @@
#include "precomp.hpp"
namespace cv
{
namespace dnn
{
Blob::Blob()
{
int zeros[4] = { 0, 0, 0, 0 };
m = Mat(4, zeros, CV_32F, NULL);
}
static inline int getMatChannels(const Mat &mat)
{
return (mat.dims <= 2) ? mat.channels() : mat.size[0];
}
static BlobShape getBlobShape(std::vector<Mat> &vmat, int requestedCn = -1)
{
BlobShape shape(4);
int cnSum = 0, matCn;
CV_Assert(vmat.size() > 0);
for (size_t i = 0; i < vmat.size(); i++)
{
Mat &mat = vmat[i];
CV_Assert(!mat.empty());
CV_Assert((mat.dims == 3 && mat.channels() == 1) || mat.dims <= 2);
matCn = getMatChannels(mat);
cnSum += getMatChannels(mat);
if (i == 0)
{
shape[-1] = mat.cols;
shape[-2] = mat.rows;
shape[-3] = (requestedCn <= 0) ? matCn : requestedCn;
}
else
{
if (mat.cols != shape[-1] || mat.rows != shape[-2])
CV_Error(Error::StsError, "Each Mat.size() must be equal");
if (requestedCn <= 0 && matCn != shape[-3])
CV_Error(Error::StsError, "Each Mat.chnannels() (or number of planes) must be equal");
}
}
if (cnSum % shape[-3] != 0)
CV_Error(Error::StsError, "Total number of channels in vector is not a multiple of requsted channel number");
shape[0] = cnSum / shape[-3];
return shape;
}
static std::vector<Mat> extractMatVector(InputArray in)
{
if (in.isMat() || in.isUMat())
{
return std::vector<Mat>(1, in.getMat());
}
else if (in.isMatVector())
{
return *static_cast<const std::vector<Mat>*>(in.getObj());
}
else if (in.isUMatVector())
{
std::vector<Mat> vmat;
in.getMatVector(vmat);
return vmat;
}
else
{
CV_Assert(in.isMat() || in.isMatVector() || in.isUMat() || in.isUMatVector());
return std::vector<Mat>();
}
}
Blob::Blob(InputArray in, int dstCn)
{
CV_Assert(dstCn == -1 || dstCn > 0);
std::vector<Mat> inMats = extractMatVector(in);
BlobShape dstShape = getBlobShape(inMats, dstCn);
m.create(dstShape.dims(), dstShape.ptr(), CV_32F);
std::vector<Mat> wrapBuf(dstShape[-3]);
int elemSize = (int)m.elemSize();
uchar *ptr = this->ptrRaw();
for (size_t i = 0; i < inMats.size(); i++)
{
Mat inMat = inMats[i];
if (inMat.dims <= 2)
{
inMat.convertTo(inMat, m.type());
wrapBuf.resize(0);
for (int cn = 0; cn < inMat.channels(); cn++)
{
wrapBuf.push_back(Mat(inMat.rows, inMat.cols, m.type(), ptr));
ptr += elemSize * inMat.total();
}
cv::split(inMat, wrapBuf);
}
else
{
inMat.convertTo(Mat(inMat.dims, inMat.size, m.type(), ptr), m.type());
ptr += elemSize * inMat.total();
}
}
}
void Blob::fill(const BlobShape &shape, int type, void *data, bool deepCopy)
{
CV_Assert(type == CV_32F || type == CV_64F);
if (deepCopy)
{
m.create(shape.dims(), shape.ptr(), type);
memcpy(m.data, data, m.total() * m.elemSize());
}
else
{
m = Mat(shape.dims(), shape.ptr(), type, data);
}
}
void Blob::fill(InputArray in)
{
CV_Assert(in.isMat() || in.isMatVector());
//TODO
*this = Blob(in);
}
void Blob::create(const BlobShape &shape, int type)
{
CV_Assert(type == CV_32F || type == CV_64F);
m.create(shape.dims(), shape.ptr(), type);
}
inline void squeezeShape(const int srcDims, const int *srcSizes, const int dstDims, int *dstSizes)
{
const int m = std::min(dstDims, srcDims);
//copy common(last) dimensions
for (int i = 0; i < m; i++)
dstSizes[dstDims - 1 - i] = srcSizes[srcDims - 1 - i];
//either flatten extra dimensions
for (int i = m; i < srcDims; i++)
dstSizes[0] *= srcSizes[srcDims - 1 - i];
//or fill the remaining dimensions with ones
for (int i = m; i < dstDims; i++)
dstSizes[dstDims - 1 - i] = 1;
}
Vec4i Blob::shape4() const
{
return Vec4i(num(), channels(), rows(), cols());
}
std::ostream &operator<< (std::ostream &stream, const BlobShape &shape)
{
stream << "[";
for (int i = 0; i < shape.dims() - 1; i++)
stream << shape[i] << ", ";
if (shape.dims() > 0)
stream << shape[-1];
return stream << "]";
}
}
}
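The constructor above stacks several Mats (or the planes of one multi-channel Mat) along the first two axes. A small illustrative sketch, assuming four equally sized 3-channel images:

#include <vector>
#include <opencv2/core.hpp>
#include <opencv2/dnn/blob.hpp>
using namespace cv;
using namespace cv::dnn;

void batchExample(const std::vector<Mat> &images)   // e.g. four CV_8UC3 images of equal size
{
    // Each 3-channel image contributes 3 planes; with dstCn = 3 the resulting
    // shape is (images.size(), 3, rows, cols), converted to CV_32F.
    Blob batch(images, 3);
    CV_Assert(batch.num() == (int)images.size() && batch.channels() == 3);
}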

@ -1,4 +1,5 @@
#pragma once
#ifndef __OPENCV_DNN_CAFFE_GLOG_EMULATOR__
#define __OPENCV_DNN_CAFFE_GLOG_EMULATOR__
#include <stdlib.h>
#include <iostream>
#include <sstream>
@ -52,3 +53,4 @@ public:
};
}
#endif

@ -37,70 +37,79 @@ namespace
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
}
inline bool skipCaffeLayerParam(const FieldDescriptor *fd)
{
const std::string &name = fd->name();
if (fd->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
{
static const char *SKIP_FIELDS[] = { "type", "name", "top", "bottom", NULL };
for (int i = 0; SKIP_FIELDS[i]; i++)
{
if (name == SKIP_FIELDS[i])
return true;
}
return false;
}
else
{
static const std::string _param("_param");
bool endsWith_param = (name.size() >= _param.size()) && name.compare(name.size() - _param.size(), _param.size(), _param) == 0;
return !endsWith_param;
}
}
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
{
const Reflection *msgRefl = msg.GetReflection();
const Reflection *refl = msg.GetReflection();
int type = field->cpp_type();
bool isRepeated = field->is_repeated();
const std::string &name = field->name();
std::cout << field->type_name() << " " << name << ":";
#define GET_FIRST(Type) (isRepeated ? msgRefl->GetRepeated##Type(msg, field, 0) : msgRefl->Get##Type(msg, field))
#define SET_UP_FILED(getter, arrayConstr, gtype) \
if (isRepeated) { \
const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \
params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size())); \
} \
else { \
params.set(name, refl->getter(msg, field)); \
}
switch (type)
{
case FieldDescriptor::CPPTYPE_INT32:
std::cout << params.set(name, GET_FIRST(Int32));
SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);
break;
case FieldDescriptor::CPPTYPE_UINT32:
std::cout << params.set(name, GET_FIRST(UInt32));
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
break;
case FieldDescriptor::CPPTYPE_INT64:
SET_UP_FILED(GetInt64, arrayInt, ::google::protobuf::int64);
break;
case FieldDescriptor::CPPTYPE_UINT64:
SET_UP_FILED(GetUInt64, arrayInt, ::google::protobuf::uint64);
break;
case FieldDescriptor::CPPTYPE_BOOL:
SET_UP_FILED(GetBool, arrayInt, bool);
break;
case FieldDescriptor::CPPTYPE_DOUBLE:
std::cout << params.set(name, GET_FIRST(Double));
SET_UP_FILED(GetDouble, arrayReal, double);
break;
case FieldDescriptor::CPPTYPE_FLOAT:
std::cout << params.set(name, GET_FIRST(Float));
SET_UP_FILED(GetFloat, arrayReal, float);
break;
case FieldDescriptor::CPPTYPE_ENUM:
std::cout << params.set(name, GET_FIRST(Enum)->name());
case FieldDescriptor::CPPTYPE_STRING:
if (isRepeated) {
const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
}
else {
params.set(name, refl->GetString(msg, field));
}
break;
case FieldDescriptor::CPPTYPE_BOOL:
std::cout << params.set(name, GET_FIRST(Bool));
case FieldDescriptor::CPPTYPE_ENUM:
if (isRepeated) {
int size = refl->FieldSize(msg, field);
std::vector<cv::String> buf(size);
for (int i = 0; i < size; i++)
buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
params.set(name, DictValue::arrayString(buf.begin(), size));
}
else {
params.set(name, refl->GetEnum(msg, field)->name());
}
break;
default:
std::cout << "unknown";
CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
break;
}
}
std::cout << std::endl;
inline static bool ends_with_param(const std::string &str)
{
static const std::string _param("_param");
return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0;
}
void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params)
void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params, bool isInternal = false)
{
const Descriptor *msgDesc = msg.GetDescriptor();
const Reflection *msgRefl = msg.GetReflection();
@ -109,19 +118,21 @@ namespace
{
const FieldDescriptor *fd = msgDesc->field(fieldId);
if (!isInternal && !ends_with_param(fd->name()))
continue;
bool hasData = fd->is_required() ||
(fd->is_optional() && (msgRefl->HasField(msg, fd) /*|| fd->has_default_value()*/)) ||
(fd->is_optional() && msgRefl->HasField(msg, fd)) ||
(fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0);
if ( !hasData || skipCaffeLayerParam(fd) )
if (!hasData)
continue;
if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
{
if (fd->is_repeated()) //Extract only first item!
extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params);
extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
else
extractLayerParams(msgRefl->GetMessage(msg, fd), params);
extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
}
else
{
@ -130,39 +141,41 @@ namespace
}
}
void blobFromProto(const caffe::BlobProto &protoBlob, cv::dnn::Blob &dstBlob)
BlobShape blobShapeFromProto(const caffe::BlobProto &pbBlob)
{
AutoBuffer<int, 4> shape;
if (protoBlob.has_num() || protoBlob.has_channels() || protoBlob.has_height() || protoBlob.has_width())
if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())
{
shape.resize(4);
shape[0] = protoBlob.num();
shape[1] = protoBlob.channels();
shape[2] = protoBlob.height();
shape[3] = protoBlob.width();
return BlobShape(pbBlob.num(), pbBlob.channels(), pbBlob.height(), pbBlob.width());
}
else if (protoBlob.has_shape())
else if (pbBlob.has_shape())
{
const caffe::BlobShape &_shape = protoBlob.shape();
shape.resize(_shape.dim_size());
const caffe::BlobShape &_shape = pbBlob.shape();
BlobShape shape(_shape.dim_size());
for (int i = 0; i < _shape.dim_size(); i++)
shape[i] = _shape.dim(i);
shape[i] = (int)_shape.dim(i);
return shape;
}
else
{
CV_Error(cv::Error::StsAssert, "Unknown shape of input blob");
CV_Error(Error::StsError, "Unknown shape of input blob");
return BlobShape(-1);
}
}
dstBlob.create(shape.size(), shape, CV_32F);
CV_Assert(protoBlob.data_size() == (int)dstBlob.getMatRef().total());
void blobFromProto(const caffe::BlobProto &pbBlob, cv::dnn::Blob &dstBlob)
{
BlobShape shape = blobShapeFromProto(pbBlob);
dstBlob.create(shape, CV_32F);
CV_Assert(pbBlob.data_size() == (int)dstBlob.getMatRef().total());
CV_DbgAssert(protoBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
float *dstData = dstBlob.getMatRef().ptr<float>();
for (int i = 0; i < protoBlob.data_size(); i++)
dstData[i] = protoBlob.data(i);
for (int i = 0; i < pbBlob.data_size(); i++)
dstData[i] = pbBlob.data(i);
}
void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
@ -187,55 +200,100 @@ namespace
}
}
struct BlobNote
{
BlobNote(const std::string &_name, int _layerId, int _outNum) :
name(_name.c_str()), layerId(_layerId), outNum(_outNum) {}
const char *name;
int layerId, outNum;
};
void populateNet(Net dstNet)
{
int layersSize = net.layer_size();
std::vector<BlobNote> addedBlobs;
addedBlobs.reserve(layersSize + 1);
//setup input layer names
{
std::vector<String> netInputs(net.input_size());
for (int ii = 0; ii < net.input_size(); ii++)
netInputs[ii] = net.input(ii);
for (int inNum = 0; inNum < net.input_size(); inNum++)
{
addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
netInputs[inNum] = net.input(inNum);
}
dstNet.setNetInputs(netInputs);
}
int layersSize = net.layer_size();
std::vector<String> layersName(layersSize);
std::vector<int> layersId(layersSize);
std::vector<std::vector<String> > bottomsVec(layersSize);
for (int li = 0; li < layersSize; li++)
{
const caffe::LayerParameter layer = net.layer(li);
const caffe::LayerParameter &layer = net.layer(li);
String name = layer.name();
String type = layer.type();
LayerParams layerParams;
std::vector<String> tops;
tops.assign(layer.top().begin(), layer.top().end());
bottomsVec[li].assign(layer.bottom().begin(), layer.bottom().end());
std::cout << std::endl << "LAYER: " << name << std::endl;
extractLayerParams(layer, layerParams);
extractBinaryLayerParms(layer, layerParams);
int id = dstNet.addLayer(name, type, layerParams);
dstNet.setOutputNames(id, tops);
layersName[li] = name;
layersId[li] = id;
for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
addInput(layer.bottom(inNum), id, inNum, dstNet, addedBlobs);
for (int outNum = 0; outNum < layer.top_size(); outNum++)
addOutput(layer, id, outNum, addedBlobs);
}
}
for (int li = 0; li < layersSize; li++)
void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum, std::vector<BlobNote> &addedBlobs)
{
const std::string &name = layer.top(outNum);
bool haveDups = false;
for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
{
if (addedBlobs[idx].name == name)
{
dstNet.setLayerInputs(bottomsVec[li], layersId[li]);
haveDups = true;
break;
}
}
if (haveDups)
{
bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;
if (!isInplace)
CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");
}
addedBlobs.push_back(BlobNote(name, layerId, outNum));
}
void addInput(const std::string &name, int layerId, int inNum, Net &dstNet, std::vector<BlobNote> &addedBlobs)
{
int idx;
for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
{
if (addedBlobs[idx].name == name)
break;
}
if (idx < 0)
{
CV_Error(Error::StsObjectNotFound, "Can't found output blob \"" + name + "\"");
return;
}
dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
}
~CaffeImporter()
{
}
};
}
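The importer above only converts prototxt fields into LayerParams and wires layers together; the computation itself comes from layers created through LayerRegister. A hedged sketch of registering a custom layer against the interface shown in dnn.hpp (the "Scale" type name and the "factor" parameter are invented for illustration):

#include <opencv2/dnn.hpp>
using namespace cv;
using namespace cv::dnn;

// Illustrative layer: multiplies its input by a constant taken from LayerParams.
class ScaleLayer : public Layer
{
    float factor;
public:
    ScaleLayer(LayerParams &params) : factor(params.get<float>("factor")) {}

    static Ptr<Layer> create(LayerParams &params) { return Ptr<Layer>(new ScaleLayer(params)); }

    void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
    {
        outputs.resize(inputs.size());
        for (size_t i = 0; i < inputs.size(); i++)
            outputs[i].create(inputs[i]->shape());        // same shape, CV_32F by default
    }

    void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
    {
        for (size_t i = 0; i < inputs.size(); i++)
        {
            const float *src = inputs[i]->ptrf();
            float *dst = outputs[i].ptrf();
            for (size_t j = 0; j < inputs[i]->total(); j++)
                dst[j] = factor * src[j];
        }
    }
};

void registerScaleLayer()
{
    // the _LayerRegisterer helper mentioned in dnn.hpp can do this automatically at load time
    LayerRegister::registerLayer("Scale", ScaleLayer::create);
}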

@ -2,6 +2,7 @@
#include <set>
#include <algorithm>
#include <iostream>
#include <sstream>
using namespace cv;
using namespace cv::dnn;
@ -16,172 +17,111 @@ namespace cv
namespace dnn
{
Blob::Blob()
template<typename T>
String toString(const T &v)
{
int zeros[4] = {0, 0, 0, 0};
m = Mat(4, zeros, CV_32F, NULL);
std::ostringstream ss;
ss << v;
return ss.str();
}
Blob::Blob(InputArray in)
struct LayerPin
{
CV_Assert(in.isMat() || in.isUMat());
if (in.isMat())
{
Mat mat = in.getMat();
CV_Assert(mat.dims == 2);
int rows = mat.rows;
int cols = mat.cols;
int cn = mat.channels();
int type = mat.type();
int dstType = CV_MAKE_TYPE(CV_MAT_DEPTH(type), 1);
int lid;
int oid;
int size[3] = { cn, rows, cols };
this->create(3, size, dstType);
uchar *data = m.data;
int step = rows * cols * CV_ELEM_SIZE(dstType);
LayerPin(int layerId = -1, int outputId = -1)
: lid(layerId), oid(outputId) {}
if (cn == 1)
bool valid() const
{
Mat wrapper2D(rows, cols, dstType, m.data);
mat.copyTo(wrapper2D);
}
else
{
std::vector<Mat> wrappers(cn);
for (int i = 0; i < cn; i++)
{
wrappers[i] = Mat(rows, cols, dstType, data);
data += step;
return (lid >= 0 && oid >= 0);
}
cv::split(mat, wrappers);
}
}
else
bool equal(const LayerPin &r) const
{
CV_Error(cv::Error::StsNotImplemented, "Not Implemented");
return (lid == r.lid && oid == r.oid);
}
}
};
static Vec4i blobNormalizeShape(int ndims, const int *sizes)
struct LayerData
{
Vec4i shape = Vec4i::all(1);
for (int i = 0; i < std::min(3, ndims); i++)
shape[3 - i] = sizes[ndims-1 - i];
LayerData() {}
LayerData(int _id, const String &_name, const String &_type, LayerParams &_params)
: id(_id), name(_name), type(_type), params(_params) {}
for (int i = 3; i < ndims; i++)
shape[0] *= sizes[ndims-1 - i];
int id;
String name;
String type;
LayerParams params;
return shape;
}
std::vector<LayerPin> inputBlobsId;
std::set<int> inputLayersId;
std::set<int> requiredOutputs;
void Blob::fill(int ndims, const int *sizes, int type, void *data, bool deepCopy)
{
CV_Assert(type == CV_32F || type == CV_64F);
Ptr<Layer> layerInstance;
std::vector<Blob> outputBlobs;
std::vector<Blob*> inputBlobs;
Vec4i shape = blobNormalizeShape(ndims, sizes);
int flag;
if (deepCopy)
Ptr<Layer> getLayerInstance()
{
m.create(3, &shape[0], type);
size_t dataSize = m.total() * m.elemSize();
memcpy(m.data, data, dataSize);
}
else
if (layerInstance)
return layerInstance;
layerInstance = LayerRegister::createLayerInstance(type, params);
if (!layerInstance)
{
m = Mat(shape.channels, &shape[0], type, data);
CV_Error(Error::StsError, "Can't create layer \"" + name + "\" of type \"" + type + "\"");
}
}
void Blob::fill(InputArray in)
{
CV_Assert(in.isMat() || in.isMatVector());
//TODO
*this = Blob(in);
}
void Blob::create(int ndims, const int *sizes, int type)
{
CV_Assert(type == CV_32F || type == CV_64F);
Vec4i shape = blobNormalizeShape(ndims, sizes);
m.create(shape.channels, &shape[0], type);
}
void Blob::create(Vec4i shape, int type)
{
m.create(shape.channels, &shape[0], type);
}
void Blob::create(int num, int cn, int rows, int cols, int type)
{
Vec4i shape(num, cn, rows, cols);
create(4, &shape[0], type);
}
//////////////////////////////////////////////////////////////////////////
struct LayerOutId
{
int lid;
int oid;
String name;
LayerOutId() {}
LayerOutId(int layerId, int outputId, const String &outputName = String())
: lid(layerId), oid(outputId), name(outputName) {}
return layerInstance;
}
};
struct LayerData
//fake layer containing network input blobs
struct NetInputLayer : public Layer
{
LayerData() {}
LayerData(const String &_name, const String &_type, LayerParams &_params)
: name(_name), type(_type), params(_params) {}
String name;
String type;
LayerParams params;
void allocate(const std::vector<Blob*>&, std::vector<Blob>&) {}
void forward(std::vector<Blob*>&, std::vector<Blob>&) {}
std::vector<String> outputNames;
std::vector<String> inputNames;
bool hasNamedOutput(const String &name)
int outputNameToIndex(String tgtName)
{
return std::find(outputNames.begin(), outputNames.end(), name) != outputNames.end();
int idx = (int)(std::find(outNames.begin(), outNames.end(), tgtName) - outNames.begin());
return (idx < (int)outNames.size()) ? idx : -1;
}
bool hasNemedInput(const String &name)
void setNames(const std::vector<String> &names)
{
return std::find(inputNames.begin(), inputNames.end(), name) != inputNames.end();
outNames.assign(names.begin(), names.end());
}
std::vector<LayerOutId> inputBlobsId;
std::set<int> inputLayersId;
std::set<int> requiredOutputs;
Ptr<Layer> layerInstance;
std::vector<Blob> outputBlobs;
std::vector<Blob*> inputBlobs;
int flag;
private:
std::vector<String> outNames;
};
struct Net::Impl
{
Impl()
{
LayerParams paramsEmpty;
layers.insert(make_pair(0, LayerData("_input", "_noType", paramsEmpty)));
//allocate fake net input layer
netInputLayer = Ptr<NetInputLayer>(new NetInputLayer());
LayerData &inpl = layers.insert( make_pair(0, LayerData()) ).first->second;
inpl.id = 0;
inpl.name = "_input";
inpl.type = "__NetInputLayer__";
inpl.layerInstance = netInputLayer;
lastLayerId = 1;
netWasAllocated = false;
}
Ptr<NetInputLayer> netInputLayer;
std::vector<int> netOutputs;
typedef std::map<int, LayerData> MapIdToLayerData;
std::map<int, LayerData> layers;
std::map<String, int> layerNameToId;
int lastLayerId;
@ -192,9 +132,8 @@ struct Net::Impl
{
if (!netWasAllocated)
{
connectInputs();
allocateLayers();
computeNetOutputs();
computeNetOutputLayers();
netWasAllocated = true;
}
@ -206,121 +145,130 @@ struct Net::Impl
return (it != layerNameToId.end()) ? it->second : -1;
}
int getLayerId(const DictValue &v)
{
if (v.isString())
return getLayerId(v.get<String>());
else if (v.isInt())
return v.get<int>();
else
int getLayerId(int id)
{
CV_Assert(v.isString() || v.isInt());
return -1;
}
MapIdToLayerData::iterator it = layers.find(id);
return (it != layers.end()) ? id : -1;
}
LayerData& getLayerData(const DictValue &v)
int getLayerId(DictValue &layerDesc)
{
int id = getLayerId(v);
std::map<int, LayerData>::iterator it = layers.find(id);
CV_Assert(id >= 0 && it != layers.end());
return it->second;
if (layerDesc.isInt())
return getLayerId(layerDesc.get<int>());
else if (layerDesc.isString())
return getLayerId(layerDesc.get<String>());
CV_Assert(layerDesc.isInt() || layerDesc.isString());
return -1;
}
int findOutputsByName(const String &name, LayerOutId *found, int maxCount = 1)
String getLayerName(int id)
{
int count = 0;
MapIdToLayerData::iterator it = layers.find(id);
return (it != layers.end()) ? it->second.name : "(unknown layer)";
}
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end() && count < maxCount; it++)
LayerData& getLayerData(int id)
{
int lid = it->first;
LayerData &ld = it->second;
MapIdToLayerData::iterator it = layers.find(id);
for (size_t oi = 0; oi < ld.outputNames.size() && count < maxCount; oi++)
{
if (ld.outputNames[oi] == name)
found[count++] = LayerOutId(lid, (int)oi);
}
}
if (it == layers.end())
CV_Error(Error::StsError, "Layer with requested id=" + toString(id) + " not found");
return count;
return it->second;
}
void connectInputs()
LayerData& getLayerData(const String &layerName)
{
LayerOutId foundOutputs[3], out;
int id = getLayerId(layerName);
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
{
LayerData &ld = it->second;
if (id < 0)
CV_Error(Error::StsError, "Requsted layer \"" + layerName + "\" not found");
ld.inputBlobs.resize(ld.inputNames.size());
ld.inputBlobsId.resize(ld.inputNames.size());
ld.inputLayersId.clear();
return getLayerData(id);
}
for (size_t ii = 0; ii < ld.inputNames.size(); ii++)
LayerData& getLayerData(const DictValue &layerDesc)
{
const String &tgtName = ld.inputNames[ii];
int foundCount = findOutputsByName(tgtName, foundOutputs, 3);
if (layerDesc.isInt())
return getLayerData(layerDesc.get<int>());
else if (layerDesc.isString())
return getLayerData(layerDesc.get<String>());
if (foundCount > 2)
{
CV_Error(cv::Error::StsNotImplemented, "Two or more non-inplace blobs have the same name \"" + tgtName + "\"");
CV_Assert(layerDesc.isInt() || layerDesc.isString());
return *((LayerData*)NULL);
}
else if (foundCount == 2)
{
bool inPlace[2];
inPlace[0] = layers[ foundOutputs[0].lid ].hasNemedInput(tgtName);
inPlace[1] = layers[ foundOutputs[1].lid ].hasNemedInput(tgtName);
if (!inPlace[0] && !inPlace[1])
static void addLayerInput(LayerData &ld, int inNum, LayerPin from)
{
CV_Error(cv::Error::StsNotImplemented, "Two or more non-inplace blobs have the same name \"" + tgtName + "\"");
}
else if (inPlace[0] && inPlace[1])
if ((int)ld.inputBlobsId.size() <= inNum)
{
CV_Error(cv::Error::StsNotImplemented, "Two or more blobs has same in-place blob \"" + tgtName + "\"");
ld.inputBlobsId.resize(inNum + 1);
}
else
{
if (ld.hasNamedOutput(tgtName))
out = (inPlace[0]) ? foundOutputs[1] : foundOutputs[0];
else
out = (inPlace[0]) ? foundOutputs[0] : foundOutputs[1];
LayerPin storedFrom = ld.inputBlobsId[inNum];
if (storedFrom.valid() && !storedFrom.equal(from))
CV_Error(Error::StsError, "Input #" + toString(inNum) + "of layer \"" + ld.name + "\" already was connected");
}
ld.inputBlobsId[inNum] = from;
}
else if (foundCount == 0)
static void splitPin(const String &pinAlias, String &layerName, String &outName)
{
CV_Error(cv::Error::StsBadArg, "Can't find specified input blob \"" + tgtName + "\" for layer \"" + ld.name + "\"");
continue;
size_t delimPos = pinAlias.find('.');
layerName = pinAlias.substr(0, delimPos);
outName = (delimPos == String::npos) ? String() : pinAlias.substr(delimPos + 1);
}
else
int resolvePinOutputName(LayerData &ld, const String &outName, bool isOutPin)
{
out = foundOutputs[0];
}
if (outName.empty())
return 0;
if (std::isdigit(outName[0]))
{
char *lastChar;
long inum = std::strtol(outName.c_str(), &lastChar, 10);
ld.inputBlobsId[ii] = out;
ld.inputLayersId.insert(out.lid);
layers[out.lid].requiredOutputs.insert(out.oid);
if (*lastChar == 0)
{
CV_Assert(inum == (int)inum);
return (int)inum;
}
}
for (it = layers.begin(); it != layers.end(); it++)
if (isOutPin)
return ld.getLayerInstance()->outputNameToIndex(outName);
else
return ld.getLayerInstance()->inputNameToIndex(outName);
}
LayerPin getPinByAlias(const String &pinAlias, bool isOutPin = true)
{
LayerData& ld = it->second;
LayerPin pin;
String layerName, outName;
splitPin(pinAlias, layerName, outName);
pin.lid = (layerName.empty()) ? 0 : getLayerId(layerName);
if (pin.lid >= 0)
pin.oid = resolvePinOutputName(getLayerData(pin.lid), outName, isOutPin);
std::cout << ld.name << std::endl;
std::cout << "Connected:" << std::endl;
for (std::set<int>::iterator j = ld.inputLayersId.begin(); j != ld.inputLayersId.end(); j++)
std::cout << layers[*j].name << std::endl;
std::cout << std::endl;
return pin;
}
void connect(int outLayerId, int outNum, int inLayerId, int inNum)
{
LayerData &ldOut = getLayerData(outLayerId);
LayerData &ldInp = getLayerData(inLayerId);
addLayerInput(ldInp, inNum, LayerPin(outLayerId, outNum));
ldOut.requiredOutputs.insert(outNum);
}
void computeNetOutputs()
void computeNetOutputLayers()
{
netOutputs.clear();
@ -351,31 +299,18 @@ struct Net::Impl
for (set<int>::iterator i = ld.inputLayersId.begin(); i != ld.inputLayersId.end(); i++)
allocateLayer(*i);
//create instance
if (ld.layerInstance == NULL && lid != 0)
{
ld.layerInstance = LayerRegister::createLayerInstance(ld.type, ld.params);
if (ld.layerInstance == NULL)
{
std::cerr << "Can't create layer \"" << ld.name << "\" of type \"" << ld.type << "\"" << std::endl;
}
}
//bind inputs
ld.inputBlobs.resize(ld.inputBlobsId.size());
for (size_t i = 0; i < ld.inputBlobsId.size(); i++)
{
int srcLId = ld.inputBlobsId[i].lid;
int srcOId = ld.inputBlobsId[i].oid;
ld.inputBlobs[i] = &layers[srcLId].outputBlobs[srcOId];
LayerPin from = ld.inputBlobsId[i];
CV_Assert(from.valid());
ld.inputBlobs[i] = &layers[from.lid].outputBlobs[from.oid];
}
//allocate layer
ld.outputBlobs.resize(ld.outputNames.size());
if (ld.layerInstance)
ld.layerInstance->allocate(ld.inputBlobs, ld.outputBlobs);
//std::cout << ld.name << " shape:" << ld.outputBlobs[0].shape() << std::endl;
ld.outputBlobs.resize(std::max((size_t)1, ld.requiredOutputs.size())); //layer produces at least one output blob
ld.getLayerInstance()->allocate(ld.inputBlobs, ld.outputBlobs);
ld.flag = 1;
}
@ -393,7 +328,7 @@ struct Net::Impl
}
}
void forwardLayer(int layerId, bool clearFlags = true)
void forwardLayer(LayerData &ld, bool clearFlags = true)
{
if (clearFlags)
{
@ -402,8 +337,6 @@ struct Net::Impl
it->second.flag = 0;
}
LayerData &ld = layers[layerId];
//already was forwarded
if (ld.flag)
return;
@ -411,15 +344,12 @@ struct Net::Impl
//forward parents
for (set<int>::iterator i = ld.inputLayersId.begin(); i != ld.inputLayersId.end(); i++)
{
forwardLayer(*i, false);
forwardLayer(layers[*i], false);
}
//forward itself
if (ld.layerInstance && layerId != 0)
ld.layerInstance->forward(ld.inputBlobs, ld.outputBlobs);
//std::cout << ld.name << " shape:" << ld.outputBlobs[0].shape() << std::endl;
ld.flag = 1;
}
@ -430,7 +360,7 @@ struct Net::Impl
it->second.flag = 0;
for (it = layers.begin(); it != layers.end(); it++)
forwardLayer(it->first, false);
forwardLayer(it->second, false);
}
};
@ -446,35 +376,38 @@ Net::~Net()
int Net::addLayer(const String &name, const String &type, LayerParams &params)
{
if (name.find('.') != String::npos)
{
CV_Error(Error::StsBadArg, "Added layer name \"" + name + "\" should not contain dot symbol");
return -1;
}
if (impl->getLayerId(name) >= 0)
{
CV_Error(cv::Error::StsBadArg, "Layer \"" + name + "\" already into net");
CV_Error(Error::StsBadArg, "Layer \"" + name + "\" already into net");
return -1;
}
int id = ++impl->lastLayerId;
impl->layerNameToId.insert(std::make_pair(name, id));
impl->layers.insert(std::make_pair(id, LayerData(name, type, params)));
impl->layers.insert(std::make_pair(id, LayerData(id, name, type, params)));
return id;
}
//void Net::connect(BlobId input, BlobId output)
//{
//}
void Net::setOutputNames(LayerId layer, const std::vector<String> &outputNames)
void Net::connect(int outLayerId, int outNum, int inLayerId, int inNum)
{
LayerData &ld = impl->getLayerData(layer);
CV_Assert(ld.outputNames.size() == 0);
ld.outputNames.assign(outputNames.begin(), outputNames.end());
impl->connect(outLayerId, outNum, inLayerId, inNum);
}
void Net::setLayerInputs(const std::vector<String> &outputs, LayerId layer)
void Net::connect(String _outPin, String _inPin)
{
LayerData &ld = impl->getLayerData(layer);
ld.inputNames.assign(outputs.begin(), outputs.end());
LayerPin outPin = impl->getPinByAlias(_outPin);
LayerPin inpPin = impl->getPinByAlias(_inPin);
CV_Assert(outPin.valid() && inpPin.valid());
impl->connect(outPin.lid, outPin.oid, inpPin.lid, inpPin.oid);
}
void Net::forward()
@ -486,74 +419,62 @@ void Net::forward()
void Net::forward(LayerId toLayer)
{
impl->setUpNet();
impl->forwardLayer(impl->getLayerId(toLayer));
impl->forwardLayer(impl->getLayerData(toLayer));
}
void Net::setNetInputs(const std::vector<String> &inputBlobNames)
{
setOutputNames(0, inputBlobNames);
impl->netInputLayer->setNames(inputBlobNames);
}
void Net::setBlob(BlobId outputName, const Blob &blob)
void Net::setBlob(String outputName, const Blob &blob)
{
String name = outputName.get<String>();
LayerOutId found;
if (!impl->findOutputsByName(name, &found, 1))
CV_Error(cv::Error::StsObjectNotFound, "Request blob \"" + name + "\" not found");
LayerPin pin = impl->getPinByAlias(outputName);
if (!pin.valid())
CV_Error(Error::StsObjectNotFound, "Requested blob \"" + outputName + "\" not found");
LayerData &ld = impl->layers[found.lid];
ld.outputBlobs.resize(ld.outputNames.size());
ld.outputBlobs[found.oid] = blob;
LayerData &ld = impl->layers[pin.lid];
ld.outputBlobs.resize( std::max(pin.oid+1, (int)ld.requiredOutputs.size()) );
ld.outputBlobs[pin.oid] = blob;
}
Blob Net::getBlob(BlobId outputName)
Blob Net::getBlob(String outputName)
{
String name = outputName.get<String>();
LayerOutId found;
LayerPin pin = impl->getPinByAlias(outputName);
if (!pin.valid())
CV_Error(Error::StsObjectNotFound, "Requested blob \"" + outputName + "\" not found");
if (!impl->findOutputsByName(name, &found, 1))
CV_Error(cv::Error::StsObjectNotFound, "Request blob \"" + name + "\" not found");
LayerData &ld = impl->layers[found.lid];
return ld.outputBlobs[found.oid];
LayerData &ld = impl->layers[pin.lid];
if ((size_t)pin.oid >= ld.outputBlobs.size())
{
CV_Error(Error::StsOutOfRange, "Layer \"" + ld.name + "\" produces only " + toString(ld.outputBlobs.size()) +
" outputs, the #" + toString(pin.oid) + " was requested");
}
return ld.outputBlobs[pin.oid];
}
Importer::~Importer()
Blob Net::getParam(LayerId layer, int numParam)
{
LayerData &ld = impl->getLayerData(layer);
std::vector<Blob> &layerBlobs = ld.layerInstance->learnedParams;
CV_Assert(numParam < (int)layerBlobs.size());
return layerBlobs[numParam];
}
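A minimal usage sketch of the public Net API gathered above; the layer and blob names are illustrative and error handling is omitted:
Net net;
LayerParams lp;
int id = net.addLayer("relu1", "ReLU", lp);        // layer 0 is the implicit input layer
net.setNetInputs(std::vector<String>(1, "data"));  // name the outputs of the input layer
net.connect(".data", "relu1");                     // alias form of connect(0, 0, id, 0)
Mat data = Mat::ones(4, 4, CV_32F);
net.setBlob(".data", Blob(data));                  // fill the input blob
net.forward();
Blob out = net.getBlob("relu1");                   // output #0 of layer "relu1"
// net.getParam(id, 0) would return the layer's learned blob #0 (ReLU has none)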
//////////////////////////////////////////////////////////////////////////
#include <sstream>
template<typename T>
String toString(const T &v)
{
std::stringstream ss;
ss << v;
return ss.str();
}
int Layer::getNumInputs()
Importer::~Importer()
{
return 1;
}
int Layer::getNumOutputs()
{
return 1;
}
cv::String Layer::getInputName(int inputNum)
int Layer::inputNameToIndex(String)
{
return "input" + toString(inputNum);
return -1;
}
cv::String Layer::getOutputName(int outputNum)
int Layer::outputNameToIndex(String)
{
return "output" + toString(outputNum);
return -1;
}
Layer::~Layer()
@ -604,5 +525,11 @@ Ptr<Layer> LayerRegister::createLayerInstance(const String &_type, LayerParams&
}
}
int Net::getLayerId(LayerId)
{
CV_Error(Error::StsNotImplemented, "");
return -1;
}
}
}

@ -18,7 +18,7 @@ namespace dnn
{
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
outputs[i] = *inputs[i];
outputs[i].shareFrom(*inputs[i]);
}
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
@ -28,12 +28,6 @@ namespace dnn
}
};
static Ptr<Layer> blankLayerRegisterer(LayerParams &params)
{
return Ptr<Layer>(new BlankLayer(params));
}
REGISTER_LAYER_FUNC(Dropout, blankLayerRegisterer)
REGISTER_LAYER_CLASS(Dropout, BlankLayer)
}
}

@ -0,0 +1,73 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
namespace cv
{
namespace dnn
{
class ConcatLayer : public Layer
{
int axis;
public:
ConcatLayer(LayerParams& params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
REGISTER_LAYER_CLASS(Concat, ConcatLayer)
ConcatLayer::ConcatLayer(LayerParams &params)
{
axis = params.get<int>("axis", 1);
CV_Assert(axis >= 0);
}
void ConcatLayer::allocate(const std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
int refType = inputs[0]->type();
BlobShape refShape = inputs[0]->shape();
CV_Assert(axis < refShape.dims());
int axisSum = 0;
for (size_t i = 0; i < inputs.size(); i++)
{
BlobShape curShape = inputs[i]->shape();
CV_Assert(curShape.dims() == refShape.dims() && inputs[i]->type() == refType);
for (int axisId = 0; axisId < refShape.dims(); axisId++)
{
if (axisId != axis && refShape[axisId] != curShape[axisId])
CV_Error(Error::StsBadSize, "Inconsitent shape for ConcatLayer");
}
axisSum += curShape[axis];
}
refShape[axis] = axisSum;
outputs.resize(1);
outputs[0].create(refShape);
}
void ConcatLayer::forward(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
const Mat& outMat = outputs[0].getMatRef();
std::vector<Range> ranges(outputs[0].dims(), Range::all());
int sizeStart = 0;
for (size_t i = 0; i < inputs.size(); i++)
{
int sizeEnd = sizeStart + inputs[i]->size(axis);
ranges[axis] = Range(sizeStart, sizeEnd);
Mat outSubMat = outMat(&ranges[0]);
inputs[i]->getMatRef().copyTo(outSubMat);
sizeStart = sizeEnd;
}
}
}
}
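A shape-level walk-through of the allocate/forward logic above, assuming two 4-D inputs concatenated along the channel axis (sizes are illustrative):
// inputs[0]->shape() == (2, 3, 4, 4), inputs[1]->shape() == (2, 5, 4, 4), axis == 1
// axisSum = 3 + 5 = 8            -> outputs[0] is created with shape (2, 8, 4, 4)
// forward(): ranges[1] = [0, 3)  -> inputs[0] copied into outMat(&ranges[0])
//            ranges[1] = [3, 8)  -> inputs[1] copied into outMat(&ranges[0])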

@ -1,5 +1,6 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "im2col.hpp"
namespace cv
{
@ -8,19 +9,24 @@ namespace dnn
//TODO: fuse convolution and bias addition to improve cache locality
class ConvolutionLayer : public Layer
{
protected:
bool bias;
int numOutput, group;
int padH, padW;
int kerH, kerW;
int strideH, strideW;
int kernelH, kernelW;
int inH, inW, inCn, kerSize;
int outH, outW;
int groupCn, groupCnOut;
int inpH, inpW, inpCn;
int outH, outW, outCn;
int topH, topW, topCn; //switched between inp/out on deconv/conv
int inpGroupCn, outGroupCn;
int ksize;
Mat srcColsMat, biasOnesMat;
Mat colMat, biasOnesMat;
void computeOutputShape(int inH, int inW);
inline bool is1x1() const;
virtual void computeInpOutShape(const Blob &inpBlob);
void im2col(Blob &inpBlob, int imNum, int cnGroup);
public:
ConvolutionLayer(LayerParams &params);
@ -28,13 +34,25 @@ namespace dnn
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
class DeConvolutionLayer : public ConvolutionLayer
{
protected:
void computeInpOutShape(const Blob &inpBlob);
void col2im(Mat &dstMat);
public:
DeConvolutionLayer(LayerParams &params) : ConvolutionLayer(params) {}
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
REGISTER_LAYER_CLASS(Convolution, ConvolutionLayer)
REGISTER_LAYER_CLASS(Deconvolution, DeConvolutionLayer)
ConvolutionLayer::ConvolutionLayer(LayerParams &params)
{
getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW);
getKernelParams(params, kerH, kerW, padH, padW, strideH, strideW);
numOutput = params.get<int>("num_output");
bias = params.get<bool>("bias_term", true);
@ -44,8 +62,8 @@ namespace dnn
CV_Assert(params.learnedBlobs.size() >= 1 && (!bias || params.learnedBlobs.size() >= 2));
learnedParams.assign(params.learnedBlobs.begin(), params.learnedBlobs.begin() + (bias ? 2 : 1));
Blob &weightBlob = learnedParams[0];
CV_Assert(weightBlob.cols() == kernelW && weightBlob.rows() == kernelH && weightBlob.num() == numOutput);
const Blob &wgtBlob = learnedParams[0];
CV_Assert(wgtBlob.dims() == 4 && wgtBlob.cols() == kerW && wgtBlob.rows() == kerH && wgtBlob.num() == numOutput);
if (bias)
{
@ -58,94 +76,141 @@ namespace dnn
{
CV_Assert(inputs.size() > 0);
Blob &weightBlob = learnedParams[0];
const Blob &inpBlob = *inputs[0];
CV_Assert(inpBlob.dims() == 4 && inpBlob.type() == CV_32F);
computeInpOutShape(inpBlob);
inCn = inputs[0]->channels();
CV_Assert(inCn % group == 0 && numOutput % group == 0 && weightBlob.channels() == inCn/group);
groupCnOut = numOutput / group;
groupCn = inCn / group;
CV_Assert(inpCn % group == 0 && outCn % group == 0);
CV_Assert(learnedParams[0].channels() == inpCn / group);
CV_Assert(learnedParams[0].num() == outCn);
inH = inputs[0]->rows();
inW = inputs[0]->cols();
computeOutputShape(inH, inW);
outGroupCn = outCn / group;
inpGroupCn = inpCn / group;
ksize = inpGroupCn * kerH * kerW;
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->rows() == inH && inputs[i]->cols() == inW && inputs[i]->channels() == inCn);
int num = inputs[i]->num();
CV_Assert(inputs[i]->type() == inpBlob.type());
CV_Assert(inputs[i]->dims() == 4 && inputs[i]->channels() == inpBlob.channels());
CV_Assert(inputs[i]->rows() == inpBlob.rows() && inputs[i]->cols() == inpBlob.cols());
outputs[i].create(num, numOutput, outH, outW);
outputs[i].create(BlobShape(inputs[i]->num(), topCn, topH, topW));
}
kerSize = kernelH * kernelW * groupCn;
srcColsMat.create(kerSize, outH * outW, CV_32F);
if (!is1x1())
colMat.create(ksize, outH * outW, inpBlob.type());
if (bias)
{
biasOnesMat = Mat::ones(1, outH * outW, CV_32F);
biasOnesMat = Mat::ones(1, topH * topW, inpBlob.type());
}
inline bool ConvolutionLayer::is1x1() const
{
return (kerH == 1 && kerW == 1);
}
template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
Dtype* data_col)
void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
Blob &wgtBlob = learnedParams[0];
for (size_t ii = 0; ii < outputs.size(); ii++)
{
Blob &inpBlob = *inputs[ii];
Blob &outBlob = outputs[ii];
for (int n = 0; n < inpBlob.num(); n++)
{
for (int g = 0; g < group; g++)
{
im2col(inpBlob, n, g);
Mat kerMat(outGroupCn, ksize, wgtBlob.type(), wgtBlob.ptrRaw(g*outGroupCn));
Mat dstMat(outGroupCn, outH*outW, outBlob.type(), outBlob.ptrRaw(n, g*outGroupCn));
cv::gemm(kerMat, colMat, 1, noArray(), 0, dstMat);
if (bias)
{
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int channels_col = channels * kernel_h * kernel_w;
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % kernel_w;
int h_offset = (c / kernel_w) % kernel_h;
int c_im = c / kernel_h / kernel_w;
for (int h = 0; h < height_col; ++h) {
for (int w = 0; w < width_col; ++w) {
int h_pad = h * stride_h - pad_h + h_offset;
int w_pad = w * stride_w - pad_w + w_offset;
if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
data_col[(c * height_col + h) * width_col + w] =
data_im[(c_im * height + h_pad) * width + w_pad];
else
data_col[(c * height_col + h) * width_col + w] = 0;
float *biasPtr = learnedParams[1].ptrf() + g*outGroupCn;
Mat biasMat(outGroupCn, 1, CV_32F, biasPtr);
cv::gemm(biasMat, biasOnesMat, 1, dstMat, 1, dstMat);
}
}
}
}
}
void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
void ConvolutionLayer::im2col(Blob &inpBlob, int imNum, int cnGroup)
{
uchar *srcPtr = inpBlob.ptrRaw(imNum, cnGroup*inpGroupCn);
if (is1x1())
{
colMat = Mat(ksize, inpBlob.rows()*inpBlob.cols(), inpBlob.type(), srcPtr);
return;
}
if (inpBlob.type() == CV_32F)
im2col_cpu((float *)srcPtr, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (float *)colMat.ptr());
if (inpBlob.type() == CV_64F)
im2col_cpu((double*)srcPtr, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (double*)colMat.ptr());
}
void ConvolutionLayer::computeInpOutShape(const Blob &inpBlob)
{
CV_Assert(inputs.size() == outputs.size());
inpH = inpBlob.rows();
inpW = inpBlob.cols();
inpCn = inpBlob.channels();
float *srcColPtr = srcColsMat.ptr<float>();
outH = (inpH + 2 * padH - kerH) / strideH + 1;
outW = (inpW + 2 * padW - kerW) / strideW + 1;
outCn = learnedParams[0].num();
topH = outH; topW = outW; topCn = outCn;
}
void DeConvolutionLayer::computeInpOutShape(const Blob &inpBlob)
{
outH = inpBlob.rows();
outW = inpBlob.cols();
outCn = inpBlob.channels();
inpH = strideH * (outH - 1) + kerH - 2 * padH;
inpW = strideW * (outW - 1) + kerW - 2 * padW;
inpCn = learnedParams[0].channels();
topH = inpH; topW = inpW; topCn = inpCn;
}
void DeConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
Blob &wghtBlob = learnedParams[0];
for (size_t ii = 0; ii < outputs.size(); ii++)
{
Blob &input = *inputs[ii];
Blob &output = outputs[ii];
int num = input.num();
Blob &convBlob = *inputs[ii];
Blob &decnBlob = outputs[ii];
for (int n = 0; n < num; n++)
for (int n = 0; n < convBlob.num(); n++)
{
for (int g = 0; g < group; g++)
{
float *srcPtr = input.ptr<float>(n, g*groupCn);
im2col_cpu(srcPtr, groupCn, inH, inW, kernelH, kernelW, padH, padW, strideH, strideW, srcColPtr);
Mat dstMat(inpGroupCn, inpH*inpW, decnBlob.type(), decnBlob.ptrRaw(n, g*inpGroupCn));
float *kerPtr = learnedParams[0].ptr<float>(g*groupCnOut);
float *dstPtr = output.ptr<float>(n, g*groupCnOut);
if (is1x1())
colMat = dstMat;
Mat kerMat(groupCnOut, kerSize, CV_32F, kerPtr);
Mat dstMat(groupCnOut, outH*outW, CV_32F, dstPtr);
Mat convMat(outGroupCn, outH*outW, convBlob.type(), convBlob.ptrRaw(n, g*inpGroupCn));
Mat wghtMat(outGroupCn, ksize, wghtBlob.type(), wghtBlob.ptrRaw(g*inpGroupCn));
cv::gemm(wghtMat, convMat, 1, noArray(), 0, colMat, GEMM_1_T);
cv::gemm(kerMat, srcColsMat, 1, noArray(), 0, dstMat);
col2im(dstMat);
if (bias)
{
float *biasPtr = learnedParams[1].ptr<float>() + g*groupCnOut;
Mat biasMat(groupCnOut, 1, CV_32F, biasPtr);
float *biasPtr = learnedParams[1].ptrf() + g*outGroupCn;
Mat biasMat(outGroupCn, 1, CV_32F, biasPtr);
cv::gemm(biasMat, biasOnesMat, 1, dstMat, 1, dstMat);
}
}
@ -153,10 +218,14 @@ namespace dnn
}
}
void ConvolutionLayer::computeOutputShape(int inH, int inW)
void DeConvolutionLayer::col2im(Mat &dstMat)
{
outH = (inH + 2 * padH - kernelH) / strideH + 1;
outW = (inW + 2 * padW - kernelW) / strideW + 1;
if (is1x1()) return;
if (dstMat.type() == CV_32F)
col2im_cpu((float*)colMat.ptr(), inpCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (float*)dstMat.ptr());
if (dstMat.type() == CV_64F)
col2im_cpu((double*)colMat.ptr(), inpCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (double*)dstMat.ptr());
}
}
}
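The per-group GEMM in ConvolutionLayer::forward multiplies a (outGroupCn x ksize) kernel matrix by the (ksize x outH*outW) column matrix built by im2col. A dimension sketch with illustrative sizes:
// inpCn = 16, numOutput = 32, group = 2, kerH = kerW = 3, inpH = inpW = 32, pad = 1, stride = 1
// inpGroupCn = 8, outGroupCn = 16, ksize = 8 * 3 * 3 = 72
// outH = (32 + 2*1 - 3) / 1 + 1 = 32, outW = 32
// per image and per group: kerMat (16 x 72) * colMat (72 x 1024) -> dstMat (16 x 1024)
// bias (if present): biasMat (16 x 1) * biasOnesMat (1 x 1024) added onto dstMat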

@ -1,6 +1,10 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <math.h>
#include <cmath>
using std::abs;
using std::exp;
using std::tanh;
using std::pow;
namespace cv
{
@ -19,22 +23,34 @@ namespace dnn
{
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
outputs[i] = *inputs[i]; //no data copy
outputs[i].shareFrom(*inputs[i]); //no data copy
}
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == outputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->ptr<float>() == outputs[i].ptr<float>());
float *data = outputs[i].ptr<float>();
CV_Assert(inputs[i]->ptrRaw() == outputs[i].ptrRaw() && inputs[i]->type() == outputs[i].type());
size_t size = outputs[i].total();
if (outputs[i].isFloat())
{
float *data = outputs[i].ptrf();
for (size_t j = 0; j < size; j++)
data[j] = func(data[j]);
}
else if (outputs[i].isDouble())
{
double *data = outputs[i].ptr<double>();
for (size_t j = 0; j < size; j++)
data[j] = func(data[j]);
}
else
{
CV_Error(Error::StsNotImplemented, "Only CV_32F and CV_64F blobs are supported");
}
}
}
};
@ -51,9 +67,10 @@ namespace dnn
negative_slope = 0.f;
}
inline float operator()(float x)
template<typename TFloat>
inline TFloat operator()(TFloat x)
{
return (x >= 0) ? x : negative_slope * x;
return (x >= (TFloat)0) ? x : negative_slope * x;
}
};
@ -61,14 +78,70 @@ namespace dnn
{
TanHFunctor(LayerParams&) {}
inline float operator()(float x)
template<typename TFloat>
inline TFloat operator()(TFloat x)
{
return tanh(x);
}
};
struct SigmoidFunctor
{
SigmoidFunctor(LayerParams&) {}
template<typename TFloat>
inline TFloat operator()(TFloat x)
{
return (TFloat)1 / ((TFloat)1 + exp(-x));
}
};
struct AbsValFunctor
{
AbsValFunctor(LayerParams&) {}
template<typename TFloat>
inline TFloat operator()(TFloat x)
{
return abs(x);
}
};
struct PowerFunctor
{
float power, scale, shift;
PowerFunctor(LayerParams &params)
{
power = params.get<float>("power", 1.0f);
scale = params.get<float>("scale", 1.0f);
shift = params.get<float>("shift", 0.0f);
}
template<typename TFloat>
inline TFloat operator()(TFloat x)
{
return pow((TFloat)shift + (TFloat)scale * x, (TFloat)power);
}
};
struct BNLLFunctor
{
BNLLFunctor(LayerParams&) {}
template<typename TFloat>
inline TFloat operator()(TFloat x)
{
return log((TFloat)1 + exp(-abs(x)));
}
};
REGISTER_LAYER_CLASS(ReLU, ElementWiseLayer<ReLUFunctor>)
REGISTER_LAYER_CLASS(TanH, ElementWiseLayer<TanHFunctor>)
REGISTER_LAYER_CLASS(BNLL, ElementWiseLayer<BNLLFunctor>)
REGISTER_LAYER_CLASS(Power, ElementWiseLayer<PowerFunctor>)
REGISTER_LAYER_CLASS(AbsVal, ElementWiseLayer<AbsValFunctor>)
REGISTER_LAYER_CLASS(Sigmoid, ElementWiseLayer<SigmoidFunctor>)
}
}
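Any further activation can be added by writing one more functor with the same templated operator() and registering it through ElementWiseLayer; a hypothetical example, not part of this patch:
struct NegateFunctor
{
    NegateFunctor(LayerParams&) {}

    template<typename TFloat>
    inline TFloat operator()(TFloat x)
    {
        return -x; // applied in place to every element of the shared blob
    }
};
//REGISTER_LAYER_CLASS(Negate, ElementWiseLayer<NegateFunctor>)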

@ -1,18 +1,20 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <iostream>
namespace cv
{
namespace dnn
{
//TODO: implement axis number parameter
class FullyConnectedLayer : public Layer
{
bool bias;
int numOutputs;
int axis_, axis;
int inC, inH, inW;
int inSize;
int innerSize;
void reshape(const Blob &inp, Blob &out);
public:
FullyConnectedLayer(LayerParams &params);
@ -28,6 +30,7 @@ namespace dnn
{
numOutputs = params.get<int>("num_output");
bias = params.get<bool>("bias_term", true);
axis_ = params.get<int>("axis", 1);
CV_Assert(params.learnedBlobs.size() >= 1);
CV_Assert(!bias || (params.learnedBlobs.size() >= 2 && (int)params.learnedBlobs[1].total() == numOutputs));
@ -44,41 +47,50 @@ namespace dnn
{
CV_Assert(inputs.size() > 0);
inC = inputs[0]->channels();
inH = inputs[0]->rows();
inW = inputs[0]->cols();
inSize = inC * inH * inW;
axis = inputs[0]->canonicalAxis(axis_);
innerSize = (int)inputs[0]->total(axis);
CV_Assert((size_t)inSize * (size_t)numOutputs == learnedParams[0].total());
CV_Assert((size_t)innerSize * (size_t)numOutputs == learnedParams[0].total());
CV_Assert(learnedParams[0].rows() == numOutputs && learnedParams[0].cols() == innerSize);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
if (i != 0)
CV_Assert(inputs[i]->channels() == inC && inputs[i]->rows() == inH && inputs[i]->cols() == inW);
CV_Assert(inputs[i]->equalShape(*inputs[0]));
outputs[i].create(inputs[i]->num(), numOutputs, 1, 1);
this->reshape(*inputs[i], outputs[i]);
}
}
void FullyConnectedLayer::reshape(const Blob &inp, Blob &out)
{
BlobShape inpShape = inp.shape();
BlobShape outShape(axis+1, inpShape.ptr());
outShape[axis] = numOutputs;
out.create(outShape, inp.type());
}
void FullyConnectedLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
for (size_t i = 0; i < inputs.size(); i++)
{
int M = inputs[i]->num();
int M = (int)inputs[i]->total(0, axis);
int N = numOutputs;
int K = inSize;
int K = innerSize;
Mat srcMat(M, K, CV_32F, inputs[i]->ptr<float>());
Mat weights(K, N, CV_32F, learnedParams[0].ptr<float>());
Mat dstMat(M, N, CV_32F, outputs[i].ptr<float>());
Mat srcMat(M, K, inputs[i]->type(), inputs[i]->ptrf());
Mat weight(N, K, learnedParams[0].type(), learnedParams[0].ptrf());
Mat dstMat(M, N, outputs[i].type(), outputs[i].ptrf());
cv::gemm(srcMat, weights, 1, noArray(), 0, dstMat);
//important: Caffe stores weights as transposed array
cv::gemm(srcMat, weight, 1, noArray(), 0, dstMat, GEMM_2_T);
if (bias)
{
Mat biasOnesMat = Mat::ones(M, 1, CV_32F);
Mat biasMat(1, N, CV_32F, learnedParams[1].ptr<float>());
Mat biasMat(1, N, CV_32F, learnedParams[1].ptrf());
cv::gemm(biasOnesMat, biasMat, 1, dstMat, 1, dstMat);
}
}
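Because Caffe stores InnerProduct weights as a (numOutputs x innerSize) matrix, the forward pass above passes GEMM_2_T instead of materializing a transposed copy. A dimension sketch with illustrative (AlexNet-like) sizes:
// input blob (10, 256, 6, 6), axis = 1 -> M = 10, K = innerSize = 256*6*6 = 9216, N = numOutputs = 4096
// srcMat (10 x 9216) * weight^T (9216 x 4096) -> dstMat (10 x 4096)
// bias (if present): biasOnesMat (10 x 1) * biasMat (1 x 4096) added onto dstMat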

@ -0,0 +1,74 @@
#ifndef __OPENCV_DNN_LAYERS_IM2COL_HPP__
#define __OPENCV_DNN_LAYERS_IM2COL_HPP__
namespace cv
{
namespace dnn
{
template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
Dtype* data_col)
{
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int channels_col = channels * kernel_h * kernel_w;
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % kernel_w;
int h_offset = (c / kernel_w) % kernel_h;
int c_im = c / kernel_h / kernel_w;
for (int h = 0; h < height_col; ++h) {
for (int w = 0; w < width_col; ++w) {
int h_pad = h * stride_h - pad_h + h_offset;
int w_pad = w * stride_w - pad_w + w_offset;
if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
data_col[(c * height_col + h) * width_col + w] =
data_im[(c_im * height + h_pad) * width + w_pad];
else
data_col[(c * height_col + h) * width_col + w] = 0;
}
}
}
}
template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
const int height, const int width, const int patch_h, const int patch_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
Dtype* data_im)
{
memset(data_im, 0, height * width * channels * sizeof(Dtype));
int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
int channels_col = channels * patch_h * patch_w;
for (int c = 0; c < channels_col; ++c)
{
int w_offset = c % patch_w;
int h_offset = (c / patch_w) % patch_h;
int c_im = c / patch_h / patch_w;
for (int h = 0; h < height_col; ++h)
{
for (int w = 0; w < width_col; ++w)
{
int h_pad = h * stride_h - pad_h + h_offset;
int w_pad = w * stride_w - pad_w + w_offset;
if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
data_im[(c_im * height + h_pad) * width + w_pad] +=
data_col[(c * height_col + h) * width_col + w];
}
}
}
}
}
}
#endif
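A tiny worked example of im2col_cpu, assuming a single-channel 3x3 image, a 2x2 kernel, no padding and stride 1 (values are illustrative):
// data_im (1 x 3 x 3):   height_col = width_col = (3 - 2)/1 + 1 = 2, channels_col = 1*2*2 = 4
//   a b c                data_col is 4 x 4, one row per (channel, kernel offset):
//   d e f                  a b d e    <- offset (0,0)
//   g h i                  b c e f    <- offset (0,1)
//                          d e g h    <- offset (1,0)
//                          e f h i    <- offset (1,1)
// col2im_cpu is the adjoint operation: it accumulates the columns back into the image.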

@ -1,6 +1,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <opencv2/imgproc.hpp>
#include <algorithm>
namespace cv
{
@ -45,8 +46,8 @@ namespace dnn
CV_Error(cv::Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");
size = params.get<int>("local_size", 5);
if (size % 2 != 1)
CV_Error(cv::Error::StsBadArg, "LRN layer only supports odd values for local_size");
if (size % 2 != 1 || size <= 0)
CV_Error(cv::Error::StsBadArg, "LRN layer supports only positive odd values for local_size");
alpha = params.get<double>("alpha", 1);
beta = params.get<double>("beta", 0.75);
@ -57,10 +58,10 @@ namespace dnn
CV_Assert(inputs.size() == 1);
outputs.resize(1);
Vec4i shape = inputs[0]->shape();
Vec4i shape = inputs[0]->shape4();
outputs[0].create(shape);
shape[1] = 1; //maybe make shape[0] = 1 too
shape[0] = 1; //maybe make shape[0] = 1 too
bufBlob.create(shape);
}
@ -85,26 +86,37 @@ namespace dnn
void LRNLayer::channelNoramlization(Blob &srcBlob, Blob &dstBlob)
{
CV_DbgAssert(srcBlob.ptrRaw() != dstBlob.ptrRaw());
int num = srcBlob.num();
int channels = srcBlob.channels();
int ksize = (size - 1) / 2;
for (int n = 0; n < num; n++)
{
Mat buf = bufBlob.getMat(n, 0);
Mat accum = dstBlob.getMat(n, 0); //memory saving
Mat accum = dstBlob.getMat(n, channels-1); //trick for memory saving
accum.setTo(0);
for (int cn = 0; cn < std::min(ksize, channels); cn++)
cv::accumulateSquare(srcBlob.getMat(n, cn), accum);
for (int cn = 0; cn < channels; cn++)
{
cv::accumulateSquare(srcBlob.getMat(n, cn), accum);
if (cn + ksize < channels)
{
cv::accumulateSquare(srcBlob.getMat(n, cn + ksize), accum);
}
accum.convertTo(accum, accum.type(), alpha/channels, 1);
cv::pow(accum, beta, accum);
for (int cn = channels - 1; cn >= 0; cn--)
if (cn - ksize - 1 >= 0)
{
cv::divide(srcBlob.getMat(n, cn), accum, dstBlob.getMat(n, cn));
Mat left = srcBlob.getMat(n, cn - ksize - 1);
cv::subtract(accum, left.mul(left), accum); //subtractSquare
}
Mat dst = dstBlob.getMat(n, cn);
accum.convertTo(dst, dst.type(), alpha/size, 1);
cv::pow(dst, beta, dst);
cv::divide(srcBlob.getMat(n, cn), dst, dst);
}
}
}
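The sliding-window accumulation above computes the usual cross-channel LRN response; with k = (size - 1) / 2 it evaluates, per spatial position:
// dst(n, c, y, x) = src(n, c, y, x) /
//     (1 + alpha/size * sum over c' in [max(0, c-k), min(C-1, c+k)] of src(n, c', y, x)^2) ^ beta
// "accum" adds the square of the channel entering the window (c + k) and subtracts the one
// leaving it (c - k - 1), so each output channel costs O(H*W) rather than O(size*H*W).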

@ -27,6 +27,7 @@ namespace dnn
void computeOutputShape(int inH, int inW);
void maxPooling(Blob &input, Blob &output);
void avePooling(Blob &input, Blob &output);
public:
PoolingLayer(LayerParams &params);
@ -64,15 +65,15 @@ namespace dnn
{
CV_Assert(inputs.size() > 0);
inH = inputs[0]->cols();
inW = inputs[0]->rows();
inW = inputs[0]->cols();
inH = inputs[0]->rows();
computeOutputShape(inH, inW);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->rows() == inH && inputs[i]->cols() == inW);
outputs[i].create(inputs[i]->num(), inputs[i]->channels(), pooledH, pooledW);
outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), pooledH, pooledW));
}
}
@ -85,6 +86,9 @@ namespace dnn
case MAX:
maxPooling(*inputs[ii], outputs[ii]);
break;
case AVE:
avePooling(*inputs[ii], outputs[ii]);
break;
default:
CV_Error(cv::Error::StsNotImplemented, "Not implemented");
break;
@ -100,8 +104,8 @@ namespace dnn
{
for (int c = 0; c < input.channels(); ++c)
{
float *srcData = input.ptr<float>(n, c);
float *dstData = output.ptr<float>(n, c);
float *srcData = input.ptrf(n, c);
float *dstData = output.ptrf(n, c);
for (int ph = 0; ph < pooledH; ++ph)
{
@ -131,6 +135,42 @@ namespace dnn
}
}
void PoolingLayer::avePooling(Blob &input, Blob &output)
{
for (int n = 0; n < input.num(); ++n)
{
for (int c = 0; c < input.channels(); ++c)
{
float *srcData = input.ptrf(n, c);
float *dstData = output.ptrf(n, c);
for (int ph = 0; ph < pooledH; ++ph)
{
for (int pw = 0; pw < pooledW; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideW - padW;
int hend = min(hstart + kernelH, inH + padH);
int wend = min(wstart + kernelW, inW + padW);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, inH);
wend = min(wend, inW);
dstData[ph * pooledW + pw] = 0.f;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
dstData[ph * pooledW + pw] += srcData[h * inW + w];
dstData[ph * pooledW + pw] /= pool_size;
}
}
}
}
}
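A worked example of the window arithmetic in avePooling above, assuming a 4x4 input plane, kernelH = kernelW = 3, strideH = strideW = 2 and no padding, which yields a 2x2 output (values are illustrative):
// ph = pw = 0: hstart = 0, hend = min(0 + 3, 4) = 3, pool_size = 3 * 3 = 9
//              dstData[0] = (sum of the top-left 3x3 block) / 9
// ph = pw = 1: hstart = 2, hend = min(2 + 3, 4) = 4, pool_size = 2 * 2 = 4
//              dstData[3] = (sum of the bottom-right 2x2 block) / 4  (window clipped at the border)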
void PoolingLayer::computeOutputShape(int inH, int inW)
{
//Note: Caffe uses a somewhat unusual rounding scheme here

@ -0,0 +1,137 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
namespace cv
{
namespace dnn
{
//TODO: Extend cv::Mat::reshape method
class ReshapeLayer : public Layer
{
public:
ReshapeLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*>&, std::vector<Blob>&) {}
protected:
BlobShape shapeDesc;
int inAxis, inNumAxes, autoAxisIdx;
void computeOutputShape(int startAxis, int endAxis, BlobShape &inpShape, BlobShape &outShape);
};
ReshapeLayer::ReshapeLayer(LayerParams &params)
{
DictValue paramShape = params.get("dim");
shapeDesc = BlobShape(paramShape.size());
autoAxisIdx = -1;
for (int i = 0; i < paramShape.size(); i++)
{
int dim = paramShape.get<int>(i);
CV_Assert(dim >= -1);
if (dim == -1)
{
if (autoAxisIdx != -1)
CV_Error(Error::StsBadArg, "New shape contains multiple -1 dims");
autoAxisIdx = i;
}
shapeDesc[i] = dim;
}
inAxis = params.get<int>("axis", 0);
inNumAxes = params.get<int>("num_axes", -1);
CV_Assert(inNumAxes >= -1);
}
void ReshapeLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
outputs.resize(1);
Blob &inpBlob = *inputs[0];
Blob &outBlob = outputs[0];
BlobShape inpShape = inpBlob.shape();
int startAxis = (inAxis >= 0) ? inAxis : inpShape.dims() + 1 + inAxis;
int endAxis = (inNumAxes == -1) ? inpShape.dims() : startAxis + inNumAxes;
CV_Assert(0 <= startAxis && startAxis <= inpShape.dims());
CV_Assert(0 <= endAxis && endAxis <= inpShape.dims());
int newDims = inpShape.dims() - (endAxis - startAxis) + shapeDesc.dims();
BlobShape outShape(newDims);
computeOutputShape(startAxis, endAxis, inpShape, outShape);
outBlob.shareFrom(inpBlob);
outBlob.reshape(outShape);
}
void ReshapeLayer::computeOutputShape(int startAxis, int endAxis, BlobShape &inpShape, BlobShape &outShape)
{
int idx = 0;
for (int i = 0; i < startAxis; i++)
outShape[idx++] = inpShape[i];
for (int i = 0; i < shapeDesc.dims(); i++)
{
if (shapeDesc[i] == 0)
{
int inpAxisIdx = startAxis + i;
if (inpAxisIdx < 0 || inpShape.dims() <= inpAxisIdx)
CV_Error(Error::StsOutOfRange, "new shape contains a 0, but there was no corresponding bottom axis to copy");
outShape[idx++] = inpShape[startAxis + i];
}
else
{
outShape[idx++] = (shapeDesc[i] > 0) ? shapeDesc[i] : 1;
}
}
for (int i = endAxis; i < inpShape.dims(); i++)
outShape[idx++] = inpShape[i];
if (autoAxisIdx >= 0)
{
size_t total = inpShape.total();
size_t curTotal = 1;
for (int i = 0; i < outShape.dims(); i++)
{
if (i != startAxis + autoAxisIdx)
curTotal *= outShape[i];
}
CV_DbgAssert(curTotal <= total && total % curTotal == 0);
outShape[startAxis + autoAxisIdx] = (int)(total / curTotal);
}
if (inpShape.total() != outShape.total())
{
CV_Error(Error::StsBadArg, "Mismatch between input and output blob elements count");
}
}
Ptr<Layer> createFlattenLayer(LayerParams&)
{
LayerParams params;
int shapeDesc[] = {0, -1};
params.set("dim", DictValue::arrayInt(shapeDesc, 2));
return Ptr<Layer>(new ReshapeLayer(params));
}
REGISTER_LAYER_CLASS(Reshape, ReshapeLayer)
REGISTER_LAYER_FUNC(Flatten, createFlattenLayer)
}
}
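A worked example of the Flatten registration above, which is simply a Reshape with dim = {0, -1}:
// input shape (2, 3, 4, 5), shapeDesc = {0, -1}, axis = 0, num_axes = -1
//   dim 0 == 0  -> copy the matching input axis: outShape[0] = 2
//   dim 1 == -1 -> inferred from the remaining elements: 120 / 2 = 60
// output shape (2, 60); the output blob shares the input data, no copy is made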

@ -0,0 +1,103 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
namespace cv
{
namespace dnn
{
class SliceLayer : public Layer
{
public:
SliceLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
private:
int inAxis;
std::vector<int> slicePoints;
};
REGISTER_LAYER_CLASS(Slice, SliceLayer)
SliceLayer::SliceLayer(LayerParams &params)
{
inAxis = params.get<int>("axis", 1);
const DictValue &_slicePoints = params.get("slice_point");
slicePoints.resize(_slicePoints.size());
for (int i = 0; i < _slicePoints.size(); i++)
{
slicePoints[i] = _slicePoints.get<int>(i);
CV_Assert(slicePoints[i] > 0 && (i == 0 || slicePoints[i-1] < slicePoints[i]));
}
}
void SliceLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
const Blob inpBlob = *inputs[0];
int axis = inpBlob.canonicalAxis(inAxis);
int axisSize = inpBlob.size(axis);
BlobShape inpShape = inpBlob.shape();
if (slicePoints.size()) //divide blob with respect to passed parameters
{
std::vector<int> outAxisSize;
int prevSlice = 0;
for (size_t i = 0; i < slicePoints.size(); i++)
{
CV_Assert(prevSlice < slicePoints[i] && slicePoints[i] < axisSize);
outAxisSize.push_back(slicePoints[i] - prevSlice);
prevSlice = slicePoints[i];
}
outAxisSize.push_back(axisSize - prevSlice);
outputs.resize(outAxisSize.size());
for (size_t i = 0; i < outAxisSize.size(); i++)
{
inpShape[axis] = outAxisSize[i];
outputs[i].create(inpShape, inpBlob.type());
}
}
else //divide blob with respect to count of output blobs
{
CV_Assert(outputs.size() > 0 && axisSize % outputs.size() == 0);
int outAxisSize = axisSize / (int)outputs.size();
for (size_t i = 0; i < outputs.size(); i++)
{
inpShape[axis] = outAxisSize;
outputs[i].create(inpShape, inpBlob.type());
}
}
}
void SliceLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
Blob &inpBlob = *inputs[0];
const int axis = inpBlob.canonicalAxis(inAxis);
const Mat& inpMat = inpBlob.getMatRef();
std::vector<Range> ranges(inpBlob.dims(), Range::all());
int sizeStart = 0;
for (size_t i = 0; i < outputs.size(); i++)
{
int sizeEnd = sizeStart + outputs[i].size(axis);
ranges[axis] = Range(sizeStart, sizeEnd);
Mat inpSubMat = inpMat(&ranges[0]);
inpSubMat.copyTo(outputs[i].getMatRef());
sizeStart = sizeEnd;
}
}
}
}
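A worked example of the slice_point handling above (sizes are illustrative):
// input shape (2, 6, 4, 4), axis = 1, slice_point = {1, 2}
// outAxisSize = {1, 1, 4} -> three outputs of shapes (2, 1, 4, 4), (2, 1, 4, 4), (2, 4, 4, 4)
// forward(): channel ranges [0, 1), [1, 2) and [2, 6) of the input are copied into them
// without slice_point, the axis size must instead divide evenly by the number of connected outputs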

@ -8,9 +8,10 @@ namespace cv
{
namespace dnn
{
//TODO: set default axis number to 1, and add custom shape length in FullyConnected
class SoftMaxLayer : public Layer
{
int axis;
int axis_, axis;
Blob maxAggregator;
public:
@ -25,15 +26,16 @@ namespace dnn
SoftMaxLayer::SoftMaxLayer(LayerParams &params)
{
axis = params.get<int>("axis", 1);
CV_Assert(0 <= axis && axis < 4);
//hotfix!!!
axis_ = params.get<int>("axis", 1);
}
void SoftMaxLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
axis = inputs[0]->canonicalAxis(axis_);
Vec4i shape = inputs[0]->shape();
BlobShape shape = inputs[0]->shape();
outputs.resize(1);
outputs[0].create(shape);
@ -46,9 +48,9 @@ namespace dnn
Blob &src = *inputs[0];
Blob &dst = outputs[0];
float *srcPtr = src.ptr<float>();
float *dstPtr = dst.ptr<float>();
float *bufPtr = maxAggregator.ptr<float>();
float *srcPtr = src.ptrf();
float *dstPtr = dst.ptrf();
float *bufPtr = maxAggregator.ptrf();
size_t outerSize = src.total(0, axis);
size_t channels = src.size(axis);
@ -85,7 +87,7 @@ namespace dnn
}
}
cv::exp(dst.getMat(), dst.getMat());
cv::exp(dst.getMatRef(), dst.getMatRef());
for (size_t outerDim = 0; outerDim < outerSize; outerDim++)
{

@ -0,0 +1,58 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
namespace cv
{
namespace dnn
{
//TODO: maybe "top_count" param is useless because it can be determined by output connections number?
class SplitLayer : public Layer
{
public:
SplitLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
private:
int outputsNum;
};
REGISTER_LAYER_CLASS(Split, SplitLayer)
SplitLayer::SplitLayer(LayerParams &params)
{
if (params.has("top_count"))
{
outputsNum = params.get<int>("top_count");
CV_Assert(outputsNum >= 0);
}
else
{
outputsNum = -1;
}
}
void SplitLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
if (outputsNum >= 0)
outputs.resize(outputsNum);
for (size_t i = 0; i < outputs.size(); i++)
outputs[i].create(inputs[0]->shape(), inputs[0]->type());
}
void SplitLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
for (size_t i = 0; i < outputs.size(); i++)
inputs[0]->getMatRef().copyTo(outputs[i].getMatRef());
}
}
}

@ -0,0 +1,247 @@
//Copyright (C) 2011 Carl Rogers
//Released under MIT License
//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php
#include"cnpy.h"
#include<complex>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<iomanip>
char cnpy::BigEndianTest() {
unsigned char x[] = {1,0};
short y = *(short*) x;
return y == 1 ? '<' : '>';
}
char cnpy::map_type(const std::type_info& t)
{
if(t == typeid(float) ) return 'f';
if(t == typeid(double) ) return 'f';
if(t == typeid(long double) ) return 'f';
if(t == typeid(int) ) return 'i';
if(t == typeid(char) ) return 'i';
if(t == typeid(short) ) return 'i';
if(t == typeid(long) ) return 'i';
if(t == typeid(long long) ) return 'i';
if(t == typeid(unsigned char) ) return 'u';
if(t == typeid(unsigned short) ) return 'u';
if(t == typeid(unsigned long) ) return 'u';
if(t == typeid(unsigned long long) ) return 'u';
if(t == typeid(unsigned int) ) return 'u';
if(t == typeid(bool) ) return 'b';
if(t == typeid(std::complex<float>) ) return 'c';
if(t == typeid(std::complex<double>) ) return 'c';
if(t == typeid(std::complex<long double>) ) return 'c';
else return '?';
}
template<> std::vector<char>& cnpy::operator+=(std::vector<char>& lhs, const std::string rhs) {
lhs.insert(lhs.end(),rhs.begin(),rhs.end());
return lhs;
}
template<> std::vector<char>& cnpy::operator+=(std::vector<char>& lhs, const char* rhs) {
//write in little endian
size_t len = strlen(rhs);
lhs.reserve(len);
for(size_t byte = 0; byte < len; byte++) {
lhs.push_back(rhs[byte]);
}
return lhs;
}
void cnpy::parse_npy_header(FILE* fp, unsigned int& word_size, unsigned int*& shape, unsigned int& ndims, bool& fortran_order) {
char buffer[256];
size_t res = fread(buffer,sizeof(char),11,fp);
if(res != 11)
throw std::runtime_error("parse_npy_header: failed fread");
std::string header = fgets(buffer,256,fp);
assert(header[header.size()-1] == '\n');
size_t loc1, loc2;
//fortran order
loc1 = header.find("fortran_order")+16;
fortran_order = (header.substr(loc1,5) == "True" ? true : false);
//shape
loc1 = header.find("(");
loc2 = header.find(")");
std::string str_shape = header.substr(loc1+1,loc2-loc1-1);
if(str_shape[str_shape.size()-1] == ',') ndims = 1;
else ndims = (unsigned)std::count(str_shape.begin(),str_shape.end(),',')+1;
shape = new unsigned int[ndims];
for(unsigned int i = 0;i < ndims;i++) {
loc1 = str_shape.find(",");
shape[i] = atoi(str_shape.substr(0,loc1).c_str());
str_shape = str_shape.substr(loc1+1);
}
//endian, word size, data type
//byte order code | stands for not applicable.
//not sure when this applies except for byte array
loc1 = header.find("descr")+9;
bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false);
assert(littleEndian);
//char type = header[loc1+1];
//assert(type == map_type(T));
std::string str_ws = header.substr(loc1+2);
loc2 = str_ws.find("'");
word_size = atoi(str_ws.substr(0,loc2).c_str());
}
void cnpy::parse_zip_footer(FILE* fp, unsigned short& nrecs, unsigned int& global_header_size, unsigned int& global_header_offset)
{
std::vector<char> footer(22);
fseek(fp,-22,SEEK_END);
size_t res = fread(&footer[0],sizeof(char),22,fp);
if(res != 22)
throw std::runtime_error("parse_zip_footer: failed fread");
unsigned short disk_no, disk_start, nrecs_on_disk, comment_len;
disk_no = *(unsigned short*) &footer[4];
disk_start = *(unsigned short*) &footer[6];
nrecs_on_disk = *(unsigned short*) &footer[8];
nrecs = *(unsigned short*) &footer[10];
global_header_size = *(unsigned int*) &footer[12];
global_header_offset = *(unsigned int*) &footer[16];
comment_len = *(unsigned short*) &footer[20];
assert(disk_no == 0);
assert(disk_start == 0);
assert(nrecs_on_disk == nrecs);
assert(comment_len == 0);
}
cnpy::NpyArray load_the_npy_file(FILE* fp) {
unsigned int* shape;
unsigned int ndims, word_size;
bool fortran_order;
cnpy::parse_npy_header(fp,word_size,shape,ndims,fortran_order);
unsigned long long size = 1; //long long so no overflow when multiplying by word_size
for(unsigned int i = 0;i < ndims;i++) size *= shape[i];
cnpy::NpyArray arr;
arr.word_size = word_size;
arr.shape = std::vector<unsigned int>(shape,shape+ndims);
delete[] shape;
arr.data = new char[size*word_size];
arr.fortran_order = fortran_order;
size_t nread = fread(arr.data,word_size,size,fp);
if(nread != size)
throw std::runtime_error("load_the_npy_file: failed fread");
return arr;
}
cnpy::npz_t cnpy::npz_load(std::string fname) {
FILE* fp = fopen(fname.c_str(),"rb");
if(!fp) printf("npz_load: Error! Unable to open file %s!\n",fname.c_str());
assert(fp);
cnpy::npz_t arrays;
while(1) {
std::vector<char> local_header(30);
size_t headerres = fread(&local_header[0],sizeof(char),30,fp);
if(headerres != 30)
throw std::runtime_error("npz_load: failed fread");
//if we've reached the global header, stop reading
if(local_header[2] != 0x03 || local_header[3] != 0x04) break;
//read in the variable name
unsigned short name_len = *(unsigned short*) &local_header[26];
std::string varname(name_len,' ');
size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp);
if(vname_res != name_len)
throw std::runtime_error("npz_load: failed fread");
//erase the lagging .npy
varname.erase(varname.end()-4,varname.end());
//read in the extra field
unsigned short extra_field_len = *(unsigned short*) &local_header[28];
if(extra_field_len > 0) {
std::vector<char> buff(extra_field_len);
size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp);
if(efield_res != extra_field_len)
throw std::runtime_error("npz_load: failed fread");
}
arrays[varname] = load_the_npy_file(fp);
}
fclose(fp);
return arrays;
}
cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) {
FILE* fp = fopen(fname.c_str(),"rb");
if(!fp) {
printf("npz_load: Error! Unable to open file %s!\n",fname.c_str());
abort();
}
while(1) {
std::vector<char> local_header(30);
size_t header_res = fread(&local_header[0],sizeof(char),30,fp);
if(header_res != 30)
throw std::runtime_error("npz_load: failed fread");
//if we've reached the global header, stop reading
if(local_header[2] != 0x03 || local_header[3] != 0x04) break;
//read in the variable name
unsigned short name_len = *(unsigned short*) &local_header[26];
std::string vname(name_len,' ');
size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp);
if(vname_res != name_len)
throw std::runtime_error("npz_load: failed fread");
vname.erase(vname.end()-4,vname.end()); //erase the lagging .npy
//read in the extra field
unsigned short extra_field_len = *(unsigned short*) &local_header[28];
fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field
if(vname == varname) {
NpyArray array = load_the_npy_file(fp);
fclose(fp);
return array;
}
else {
//skip past the data
unsigned int size = *(unsigned int*) &local_header[22];
fseek(fp,size,SEEK_CUR);
}
}
fclose(fp);
printf("npz_load: Error! Variable name %s not found in %s!\n",varname.c_str(),fname.c_str());
abort();
}
cnpy::NpyArray cnpy::npy_load(std::string fname) {
FILE* fp = fopen(fname.c_str(), "rb");
if(!fp) {
printf("npy_load: Error! Unable to open file %s!\n",fname.c_str());
abort();
}
NpyArray arr = load_the_npy_file(fp);
fclose(fp);
return arr;
}
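A minimal usage sketch of the loaders above (the file name is illustrative):
cnpy::NpyArray arr = cnpy::npy_load("blob.npy");
float *data = reinterpret_cast<float*>(arr.data); // valid only if the file stores float32 (word_size == 4)
size_t total = 1;
for (size_t i = 0; i < arr.shape.size(); i++)
    total *= arr.shape[i];
// ... read data[0 .. total) ...
arr.destruct();                                   // NpyArray does not free its buffer automatically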

@ -0,0 +1,247 @@
//Copyright (C) 2011 Carl Rogers
//Released under MIT License
//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php
#ifndef LIBCNPY_H_
#define LIBCNPY_H_
#include<string>
#include<stdexcept>
#include<sstream>
#include<vector>
#include<cstdio>
#include<typeinfo>
#include<iostream>
#include<cassert>
#include<map>
#if defined(HAVE_ZLIB) && HAVE_ZLIB
#include<zlib.h>
#endif
namespace cnpy {
struct NpyArray {
char* data;
std::vector<unsigned int> shape;
unsigned int word_size;
bool fortran_order;
void destruct() {delete[] data;}
};
struct npz_t : public std::map<std::string, NpyArray>
{
void destruct()
{
npz_t::iterator it = this->begin();
for(; it != this->end(); ++it) (*it).second.destruct();
}
};
char BigEndianTest();
char map_type(const std::type_info& t);
template<typename T> std::vector<char> create_npy_header(const T* data, const unsigned int* shape, const unsigned int ndims);
void parse_npy_header(FILE* fp,unsigned int& word_size, unsigned int*& shape, unsigned int& ndims, bool& fortran_order);
void parse_zip_footer(FILE* fp, unsigned short& nrecs, unsigned int& global_header_size, unsigned int& global_header_offset);
npz_t npz_load(std::string fname);
NpyArray npz_load(std::string fname, std::string varname);
NpyArray npy_load(std::string fname);
template<typename T> std::vector<char>& operator+=(std::vector<char>& lhs, const T rhs) {
//write in little endian
for(char byte = 0; byte < sizeof(T); byte++) {
char val = *((char*)&rhs+byte);
lhs.push_back(val);
}
return lhs;
}
template<> std::vector<char>& operator+=(std::vector<char>& lhs, const std::string rhs);
template<> std::vector<char>& operator+=(std::vector<char>& lhs, const char* rhs);
template<typename T> std::string tostring(T i, int = 0, char = ' ') {
std::stringstream s;
s << i;
return s.str();
}
template<typename T> void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w") {
FILE* fp = NULL;
if(mode == "a") fp = fopen(fname.c_str(),"r+b");
if(fp) {
//file exists. we need to append to it. read the header, modify the array size
unsigned int word_size, tmp_dims;
unsigned int* tmp_shape = 0;
bool fortran_order;
parse_npy_header(fp,word_size,tmp_shape,tmp_dims,fortran_order);
assert(!fortran_order);
if(word_size != sizeof(T)) {
std::cout<<"libnpy error: "<<fname<<" has word size "<<word_size<<" but npy_save appending data sized "<<sizeof(T)<<"\n";
assert( word_size == sizeof(T) );
}
if(tmp_dims != ndims) {
std::cout<<"libnpy error: npy_save attempting to append misdimensioned data to "<<fname<<"\n";
assert(tmp_dims == ndims);
}
for(unsigned i = 1; i < ndims; i++) {
if(shape[i] != tmp_shape[i]) {
std::cout<<"libnpy error: npy_save attempting to append misshaped data to "<<fname<<"\n";
assert(shape[i] == tmp_shape[i]);
}
}
tmp_shape[0] += shape[0];
fseek(fp,0,SEEK_SET);
std::vector<char> header = create_npy_header(data,tmp_shape,ndims);
fwrite(&header[0],sizeof(char),header.size(),fp);
fseek(fp,0,SEEK_END);
delete[] tmp_shape;
}
else {
fp = fopen(fname.c_str(),"wb");
std::vector<char> header = create_npy_header(data,shape,ndims);
fwrite(&header[0],sizeof(char),header.size(),fp);
}
unsigned int nels = 1;
for(unsigned i = 0;i < ndims;i++) nels *= shape[i];
fwrite(data,sizeof(T),nels,fp);
fclose(fp);
}
template<typename T> void npz_save(std::string zipname, std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w")
{
//first, append a .npy to the fname
fname += ".npy";
//now, on with the show
FILE* fp = NULL;
unsigned short nrecs = 0;
unsigned int global_header_offset = 0;
std::vector<char> global_header;
if(mode == "a") fp = fopen(zipname.c_str(),"r+b");
if(fp) {
//zip file exists. we need to add a new npy file to it.
//first read the footer. this gives us the offset and size of the global header
//then read and store the global header.
//below, we will write the new data at the start of the global header, then append the global header and footer after it
unsigned int global_header_size;
parse_zip_footer(fp,nrecs,global_header_size,global_header_offset);
fseek(fp,global_header_offset,SEEK_SET);
global_header.resize(global_header_size);
size_t res = fread(&global_header[0],sizeof(char),global_header_size,fp);
if(res != global_header_size){
throw std::runtime_error("npz_save: header read error while adding to existing zip");
}
fseek(fp,global_header_offset,SEEK_SET);
}
else {
fp = fopen(zipname.c_str(),"wb");
}
std::vector<char> npy_header = create_npy_header(data,shape,ndims);
unsigned long nels = 1;
for (int m=0; m<ndims; m++ ) nels *= shape[m];
int nbytes = nels*sizeof(T) + npy_header.size();
//get the CRC of the data to be added
#if defined(HAVE_ZLIB) && HAVE_ZLIB
unsigned int crc = crc32(0L,(unsigned char*)&npy_header[0],npy_header.size());
crc = crc32(crc,(unsigned char*)data,nels*sizeof(T));
#else
unsigned int crc = 0;
#endif
//build the local header
std::vector<char> local_header;
local_header += "PK"; //first part of sig
local_header += (unsigned short) 0x0403; //second part of sig
local_header += (unsigned short) 20; //min version to extract
local_header += (unsigned short) 0; //general purpose bit flag
local_header += (unsigned short) 0; //compression method
local_header += (unsigned short) 0; //file last mod time
local_header += (unsigned short) 0; //file last mod date
local_header += (unsigned int) crc; //crc
local_header += (unsigned int) nbytes; //compressed size
local_header += (unsigned int) nbytes; //uncompressed size
local_header += (unsigned short) fname.size(); //fname length
local_header += (unsigned short) 0; //extra field length
local_header += fname;
//build global header
global_header += "PK"; //first part of sig
global_header += (unsigned short) 0x0201; //second part of sig
global_header += (unsigned short) 20; //version made by
global_header.insert(global_header.end(),local_header.begin()+4,local_header.begin()+30);
global_header += (unsigned short) 0; //file comment length
global_header += (unsigned short) 0; //disk number where file starts
global_header += (unsigned short) 0; //internal file attributes
global_header += (unsigned int) 0; //external file attributes
global_header += (unsigned int) global_header_offset; //relative offset of local file header, since it begins where the global header used to begin
global_header += fname;
//build footer
std::vector<char> footer;
footer += "PK"; //first part of sig
footer += (unsigned short) 0x0605; //second part of sig
footer += (unsigned short) 0; //number of this disk
footer += (unsigned short) 0; //disk where footer starts
footer += (unsigned short) (nrecs+1); //number of records on this disk
footer += (unsigned short) (nrecs+1); //total number of records
footer += (unsigned int) global_header.size(); //nbytes of global headers
footer += (unsigned int) (global_header_offset + nbytes + local_header.size()); //offset of start of global headers, since global header now starts after newly written array
footer += (unsigned short) 0; //zip file comment length
//write everything
fwrite(&local_header[0],sizeof(char),local_header.size(),fp);
fwrite(&npy_header[0],sizeof(char),npy_header.size(),fp);
fwrite(data,sizeof(T),nels,fp);
fwrite(&global_header[0],sizeof(char),global_header.size(),fp);
fwrite(&footer[0],sizeof(char),footer.size(),fp);
fclose(fp);
}
template<typename T> std::vector<char> create_npy_header(const T*, const unsigned int* shape, const unsigned int ndims) {
std::vector<char> dict;
dict += "{'descr': '";
dict += BigEndianTest();
dict += map_type(typeid(T));
dict += tostring(sizeof(T));
dict += "', 'fortran_order': False, 'shape': (";
dict += tostring(shape[0]);
for(unsigned i = 1;i < ndims;i++) {
dict += ", ";
dict += tostring(shape[i]);
}
if(ndims == 1) dict += ",";
dict += "), }";
//pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 bytes. dict needs to end with \n
int remainder = 16 - (10 + dict.size()) % 16;
dict.insert(dict.end(),remainder,' ');
dict.back() = '\n';
std::vector<char> header;
header += (unsigned char) 0x93;
header += "NUMPY";
header += (char) 0x01; //major version of numpy format
header += (char) 0x00; //minor version of numpy format
header += (unsigned short) dict.size();
header.insert(header.end(),dict.begin(),dict.end());
return header;
}
}
#endif
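For reference, create_npy_header above emits the standard NPY 1.0 preamble; for a float32 array of shape (2, 3) on a little-endian machine the produced bytes are, schematically:
// \x93 N U M P Y \x01 \x00 <dict length, 2 bytes LE>
// {'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }   ...space padding...\n
// (the padding makes the 10-byte preamble plus the dict a multiple of 16 bytes)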

@ -0,0 +1,24 @@
#ifndef __OPENCV_DNN_TEST_NPY_BLOB_HPP__
#define __OPENCV_DNN_TEST_NPY_BLOB_HPP__
#include "test_precomp.hpp"
#include "cnpy.h"
inline cv::dnn::Blob blobFromNPY(const cv::String &path)
{
cnpy::NpyArray npyBlob = cnpy::npy_load(path.c_str());
cv::dnn::BlobShape shape((int)npyBlob.shape.size(), (int*)&npyBlob.shape[0]);
cv::dnn::Blob blob;
blob.fill(shape, CV_32F, npyBlob.data);
npyBlob.destruct();
return blob;
}
inline void saveBlobToNPY(cv::dnn::Blob &blob, const cv::String &path)
{
cv::Vec4i shape = blob.shape4();
cnpy::npy_save(path.c_str(), blob.ptr<float>(), (unsigned*)&shape[0], 4);
}
#endif

@ -0,0 +1,40 @@
#include "test_precomp.hpp"
#include "npy_blob.hpp"
namespace cvtest
{
using namespace std;
using namespace testing;
using namespace cv;
using namespace cv::dnn;
template<typename TString>
static std::string getTestFile(TString filename)
{
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
TEST(Reproducibility_AlexNet, Accuracy)
{
Net net;
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("bvlc_alexnet.prototxt"), getTestFile("bvlc_alexnet.caffemodel"));
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
std::vector<Mat> inpMats;
inpMats.push_back( imread(getTestFile("alexnet_0.png")) );
inpMats.push_back( imread(getTestFile("alexnet_1.png")) );
ASSERT_TRUE(!inpMats[0].empty() && !inpMats[1].empty());
net.setBlob(".data", Blob(inpMats));
net.forward();
Blob out = net.getBlob("prob");
Blob ref = blobFromNPY(getTestFile("alexnet.npy"));
normAssert(ref, out, "prob");
}
}

@ -4,7 +4,6 @@ namespace cvtest
{
using namespace std;
using namespace std::tr1;
using namespace testing;
using namespace cv;
using namespace cv::dnn;
@ -20,43 +19,24 @@ static std::string getTestFile(TStr filename)
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
TEST(ReadCaffePrototxt_gtsrb, Accuracy)
TEST(ReadCaffe_GTSRB, Accuracy)
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("gtsrb.prototxt"), getTestFile("gtsrb_iter_36000.caffemodel"));
Net net;
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("gtsrb.prototxt"), "");
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
}
Mat img = imread(getTestFile("sign_50.ppm"));
CV_Assert(!img.empty());
img.convertTo(img, CV_32F, 1.0 / 255);
resize(img, img, cv::Size(48, 48));
Blob imgBlob(img);
net.setBlob("input", imgBlob);
net.forward();
Blob res = net.getBlob("loss");
for (int n = 0; n < 1; n++)
TEST(ReadCaffe_GoogLeNet, Accuracy)
{
Net net;
{
Mat slice = Mat(res.channels() * res.rows(), res.cols(), CV_32F, res.ptr<float>(n));
double maxv;
std::vector<int> maxIdx;
minMaxLoc(slice, NULL, &maxv, NULL, &maxIdx);
int bestClass = maxIdx[0];
std::cout << "Best class: #" << bestClass << std::endl;
//imwrite(getTestFile("vis.png"), slice*(255.0 / maxv));
Ptr<Importer> importer = createCaffeImporter(getTestFile("bvlc_googlenet.prototxt"), "");
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
}
//TEST(ReadCaffePrototxt_GoogleNet, Accuracy)
//{
// Ptr<Importer> importer = createCaffeImporter(getOpenCVExtraDir() + "/dnn/googlenet_deploy.prototxt", "");
// Net net;
// importer->populateNet(net);
// net.forward();
//}
}

@ -0,0 +1,24 @@
#ifndef __OPENCV_TEST_COMMON_HPP__
#define __OPENCV_TEST_COMMON_HPP__
inline const std::string &getOpenCVExtraDir()
{
return cvtest::TS::ptr()->get_data_path();
}
inline void normAssert(cv::InputArray ref, cv::InputArray get, const char *comment = "")
{
double normL1 = cvtest::norm(ref, get, cv::NORM_L1)/ ref.getMat().total();
EXPECT_NEAR(normL1, 0, 0.0001) << comment;
double normInf = cvtest::norm(ref, get, cv::NORM_INF);
EXPECT_NEAR(normInf, 0, 0.001) << comment;
}
inline void normAssert(cv::dnn::Blob &ref, cv::dnn::Blob &test, const char *comment = "")
{
EXPECT_EQ(ref.shape(), test.shape());
normAssert(ref.getMatRef(), test.getMatRef(), comment);
}
#endif

@ -0,0 +1,41 @@
#include "test_precomp.hpp"
#include "npy_blob.hpp"
namespace cvtest
{
using namespace std;
using namespace testing;
using namespace cv;
using namespace cv::dnn;
template<typename TString>
static std::string getTestFile(TString filename)
{
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
TEST(Reproducibility_GoogLeNet, Accuracy)
{
Net net;
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("bvlc_googlenet.prototxt"), getTestFile("bvlc_googlenet.caffemodel"));
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
std::vector<Mat> inpMats;
inpMats.push_back( imread(getTestFile("googlenet_0.jpg")) );
inpMats.push_back( imread(getTestFile("googlenet_1.jpg")) );
ASSERT_TRUE(!inpMats[0].empty() && !inpMats[1].empty());
Blob inp(inpMats);
net.setBlob(".data", inp);
net.forward();
Blob out = net.getBlob("prob");
Blob ref = blobFromNPY(getTestFile("googlenet_prob.npy"));
normAssert(out, ref);
}
}

@ -0,0 +1,103 @@
#include "test_precomp.hpp"
#include <iostream>
#include "npy_blob.hpp"
namespace cvtest
{
using namespace std;
using namespace testing;
using namespace cv;
using namespace cv::dnn;
static std::string getOpenCVExtraDir()
{
return cvtest::TS::ptr()->get_data_path();
}
template<typename TStr>
static std::string getTestFile(TStr filename)
{
return (getOpenCVExtraDir() + "/dnn/layers/") + filename;
}
template<typename T, int n>
bool isEqual(const cv::Vec<T, n> &l, const cv::Vec<T, n> &r)
{
for (int i = 0; i < n; i++)
{
if (l[i] != r[i])
return false;
}
return true;
}
static void testLayer(String proto, String caffemodel = String())
{
Blob inp = blobFromNPY(getTestFile("blob.npy"));
Blob ref = blobFromNPY(getTestFile(proto + ".caffe.npy"));
Net net;
{
Ptr<Importer> importer = createCaffeImporter(getTestFile(proto), caffemodel);
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
net.setBlob(".input", inp);
net.forward();
Blob out = net.getBlob("output");
EXPECT_EQ(ref.shape(), out.shape());
Mat &mRef = ref.getMatRef();
Mat &mOut = out.getMatRef();
double normL1 = cvtest::norm(mRef, mOut, NORM_L1) / ref.total();
EXPECT_LE(normL1, 0.0001);
double normInf = cvtest::norm(mRef, mOut, NORM_INF);
EXPECT_LE(normInf, 0.0001);
}
TEST(Layer_Softmax_Test, Accuracy)
{
testLayer("softmax.prototxt");
}
TEST(Layer_LRN_spatial_Test, Accuracy)
{
testLayer("lrn_spatial.prototxt");
}
TEST(Layer_LRN_channels_Test, Accuracy)
{
testLayer("lrn_channels.prototxt");
}
TEST(Layer_Reshape_Split_Slice_Test, Accuracy)
{
Net net;
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("reshape_and_slice_routines.prototxt"));
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
BlobShape shape = BlobShape(Vec2i(6, 12));
Mat1f inputMat(shape[0], shape[1]);
RNG rng(0);
rng.fill(inputMat, RNG::UNIFORM, -1, 1);
Blob input(inputMat);
input.reshape(shape);
net.setBlob(".input", input);
net.forward();
Blob output = net.getBlob("output");
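// Restore the input blob from the untouched source matrix before comparing, in case
// the forward pass (e.g. the in-place Sigmoid) modified the blob's data.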
input.fill(shape, CV_32F, inputMat.data);
normAssert(input, output);
}
}
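Adding coverage for another layer follows the same pattern. A minimal sketch, assuming a hypothetical pooling.prototxt and its Caffe-generated pooling.prototxt.caffe.npy reference were added to the layers test data:
TEST(Layer_Pooling_Test, Accuracy)
{
    testLayer("pooling.prototxt");  // compares "output" against pooling.prototxt.caffe.npy
}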

@ -16,5 +16,6 @@
#include "opencv2/ts.hpp"
#include <opencv2/ts/ts_perf.hpp>
#include <opencv2/core/utility.hpp>
#include "test_common.hpp"
#endif

@ -4,7 +4,6 @@ input_dim: 10
input_dim: 3
input_dim: 227
input_dim: 227
layer {
name: "conv1"
type: "Convolution"
@ -52,7 +51,6 @@ layer {
stride: 2
}
}
layer {
name: "conv2"
type: "Convolution"
@ -101,7 +99,6 @@ layer {
stride: 2
}
}
layer {
name: "conv3"
type: "Convolution"
@ -127,7 +124,6 @@ layer {
bottom: "conv3"
top: "conv3"
}
layer {
name: "conv4"
type: "Convolution"
@ -154,7 +150,6 @@ layer {
bottom: "conv4"
top: "conv4"
}
layer {
name: "conv5"
type: "Convolution"
@ -219,7 +214,7 @@ layer {
name: "drop6"
type: "Dropout"
bottom: "fc6"
top: "fc61"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
}
@ -227,7 +222,7 @@ layer {
layer {
name: "fc7"
type: "InnerProduct"
bottom: "fc61"
bottom: "fc6"
top: "fc7"
param {
lr_mult: 1
@ -251,7 +246,7 @@ layer {
name: "drop7"
type: "Dropout"
bottom: "fc7"
top: "fc71"
top: "fc7"
dropout_param {
dropout_ratio: 0.5
}
@ -259,7 +254,7 @@ layer {
layer {
name: "fc8"
type: "InnerProduct"
bottom: "fc71"
bottom: "fc7"
top: "fc8"
param {
lr_mult: 1

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@ -0,0 +1,21 @@
name: "test_LRN_channels"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113
layer {
type: "LRN"
lrn_param {
norm_region: ACROSS_CHANNELS;
local_size: 5
alpha: 1.1
beta: 0.75
}
name: "output"
bottom: "input"
top: "output"
}

@ -0,0 +1,22 @@
name: "test_LRN_spatial"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113
layer {
type: "LRN"
lrn_param {
norm_region: WITHIN_CHANNEL;
local_size: 5
alpha: 0.9
beta: 0.75
}
name: "output"
bottom: "input"
top: "output"
}
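Both LRN fixtures exercise Caffe's local response normalization. Per the Caffe layer documentation, each input value x is scaled by a local energy term,
    y = x \cdot \left(1 + \tfrac{\alpha}{n} \sum_i x_i^2\right)^{-\beta},
where n is local_size and the sum runs over the local region: the neighbouring channels at the same pixel for ACROSS_CHANNELS, or a local_size x local_size spatial window inside the channel for WITHIN_CHANNEL.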

@ -0,0 +1,77 @@
name: "test_reshape_splice_split"
input: "input"
layer{
type: "Split"
name: "dummy_split"
bottom: "input"
top: "dummy_split_0"
top: "dummy_split_1"
}
layer{
type: "Slice"
name: "dummy_slice_0"
bottom: "dummy_split_0"
slice_param{
slice_point: 1
slice_point: 2
}
top: "dummy_slice_0_0"
top: "dummy_slice_0_1"
top: "dummy_slice_0_2"
}
layer{
type: "Slice"
name: "dummy_slice_1"
bottom: "dummy_split_1"
slice_param{
slice_point: 1
slice_point: 2
}
top: "dummy_slice_1_0"
top: "dummy_slice_1_1"
top: "dummy_slice_1_2"
}
layer{
type: "Sigmoid"
name: "alter_sliced_split"
bottom: "dummy_slice_1_2"
top: "dummy_slice_1_2"
}
layer{
type: "Concat"
name: "dummy_concat"
bottom: "dummy_slice_0_0"
bottom: "dummy_slice_1_1"
bottom: "dummy_slice_0_2"
top: "dummy_concat"
}
layer{
type: "Reshape"
name: "dummy_reshape"
bottom: "dummy_concat"
reshape_param{
shape{
dim: 0
dim: 1
dim: 1
dim: -1
dim: 1
}
axis: 1
num_axes: 1
}
top: "dummy_reshape"
}
layer{
type: "Flatten"
name: "dummy_reshape_undo"
bottom: "dummy_reshape"
top: "dummy_reshape_undo"
}
layer{
type: "Split"
name: "output"
bottom: "dummy_reshape_undo"
top: "output"
}

@ -0,0 +1,15 @@
name: "test_Softmax"
input: "input"
input_dim: 2
input_dim: 5
input_dim: 75
input_dim: 113
layer {
type: "Softmax"
name: "output"
bottom: "input"
top: "output"
}

File diff suppressed because one or more lines are too long