[Bomb commit] Implemented 4 main layers. Changed the API. Added a working GTSRB classification example to the tests.

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 194271df50
commit ed1c5691f4
  1. modules/dnn/include/opencv2/dnn/dict.hpp (60 changed lines)
  2. modules/dnn/include/opencv2/dnn/dnn.hpp (21 changed lines)
  3. modules/dnn/include/opencv2/dnn/dnn.inl.hpp (18 changed lines)
  4. modules/dnn/src/caffe_importer.cpp (10 changed lines)
  5. modules/dnn/src/dnn.cpp (154 changed lines)
  6. modules/dnn/src/layers.cpp (266 changed lines)
  7. modules/dnn/src/layers.hpp (56 changed lines)
  8. modules/dnn/test/test_caffe_importer.cpp (25 changed lines)
  9. modules/dnn/test/test_precomp.hpp (1 changed line)

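The API changes are easiest to see end-to-end. Below is a minimal sketch of how the reworked Net interface is driven, mirroring the GTSRB test added in this commit; the file paths are shortened placeholders and the blob names ("input", "layer8") come from that test's prototxt, so treat them as assumptions.

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

using namespace cv;
using namespace cv::dnn;

int main()
{
    // Build the network from Caffe definition and weights (paths are placeholders).
    Ptr<Importer> importer = createCaffeImporter("gtsrb.prototxt", "gtsrb_iter_36000.caffemodel");
    Net net;
    importer->populateNet(net);

    // Prepare a 48x48 float image and wrap it into a Blob, as the test does.
    Mat img = imread("sign_50.ppm");
    CV_Assert(!img.empty());
    img.convertTo(img, CV_32F, 1.0 / 255);
    resize(img, img, Size(48, 48));
    Blob imgBlob(img);

    // New API: bind the input blob by name, run the whole net, read an output blob.
    net.setBlob("input", imgBlob);
    net.forward();
    Blob res = net.getBlob("layer8");
    return 0;
}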
@@ -2,6 +2,7 @@
#define __OPENCV_DNN_DICT_HPP__
#include <opencv2/core.hpp>
#include <map>
namespace cv
{
@@ -14,8 +15,7 @@ struct DictValue
union
{
int i;
unsigned u;
int64 i;
double d;
bool b;
String *s;
@@ -23,10 +23,11 @@ struct DictValue
DictValue(const DictValue &r);
DictValue(int p = 0) : type(cv::Param::INT), i(p) {}
DictValue(unsigned p) : type(cv::Param::UNSIGNED_INT), u(p) {}
DictValue(unsigned p) : type(cv::Param::INT), i(p) {}
DictValue(double p) : type(cv::Param::REAL), d(p) {}
DictValue(bool p) : type(cv::Param::BOOLEAN), b(p) {}
DictValue(const String &p) : type(cv::Param::STRING), s(new String(p)) {}
DictValue(const char *str) : type(cv::Param::STRING), s(new String(str)) {}
template<typename T>
T get() const;
@@ -62,15 +63,16 @@ public:
}
template <typename T>
const T &get(const String &name) const
T get(const String &name) const
{
_Dict::const_iterator i = dict.find(name);
CV_Assert(i != dict.end());
if (i == dict.end())
CV_Error(cv::Error::StsBadArg, "Required argument \"" + name + "\" not found in the dictionary");
return i->second.get<T>();
}
template <typename T>
const T &get(const String &name, const T &default_value) const
T get(const String &name, const T &default_value) const
{
_Dict::const_iterator i = dict.find(name);
@@ -92,51 +94,73 @@ public:
return value;
}
inline void print()
{
for (_Dict::const_iterator i = dict.begin(); i != dict.end(); i++)
{
std::cout << i->first << std::endl;
}
}
};
template<>
inline int DictValue::get<int>() const
{
CV_Assert(type == cv::ParamType<int>::type || type == cv::ParamType<unsigned>::type && (int)u >= 0);
return i;
CV_Assert(type == cv::Param::INT);
return (int)i;
}
template<>
inline unsigned DictValue::get<unsigned>() const
{
CV_Assert(type == cv::ParamType<unsigned>::type || type == cv::ParamType<int>::type && i >= 0);
return u;
CV_Assert(type == cv::Param::INT);
return (unsigned)i;
}
template<>
inline double DictValue::get<double>() const
{
CV_Assert(type == cv::ParamType<double>::type);
return d;
if (type == cv::Param::FLOAT)
return d;
else if (type == cv::Param::INT)
return i;
else
{
CV_Assert(type == cv::Param::FLOAT || type == cv::Param::INT);
return 0;
}
}
template<>
inline float DictValue::get<float>() const
{
CV_Assert(type == cv::ParamType<double>::type);
return (float)d;
if (type == cv::Param::FLOAT)
return (float)d;
else if (type == cv::Param::INT)
return (float)i;
else
{
CV_Assert(type == cv::Param::FLOAT || type == cv::Param::INT);
return (float)0;
}
}
template<>
inline bool DictValue::get<bool>() const
{
if (type == cv::ParamType<bool>::type)
if (type == cv::Param::BOOLEAN)
{
return b;
}
else if (type == cv::ParamType<int>::type || type == cv::ParamType<unsigned>::type)
else if (type == cv::Param::INT)
{
return i;
}
else
{
CV_Assert(type == cv::ParamType<bool>::type || type == cv::ParamType<int>::type || type == cv::ParamType<unsigned>::type);
CV_Assert(type == cv::Param::BOOLEAN || type == cv::Param::INT);
return 0;
}
}
@@ -144,7 +168,7 @@ inline bool DictValue::get<bool>() const
template<>
inline String DictValue::get<String>() const
{
CV_Assert(type == cv::ParamType<String>::type);
CV_Assert(type == cv::Param::STRING);
return *s;
}

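For reference, the get accessors above now return by value: get<T>(name) raises an error when the key is missing, while the two-argument form falls back to the supplied default, and integer entries can be read as floating point. A minimal sketch, assuming the Dict class is exposed as cv::dnn::Dict via opencv2/dnn.hpp:

#include <opencv2/dnn.hpp>

static void readParams(cv::dnn::Dict &params)
{
    // Throws via CV_Error if "kernel_size" is absent.
    int kernel = params.get<int>("kernel_size");

    // Uses the supplied default when "stride" is absent.
    int stride = params.get<int>("stride", 1);

    // Integer-typed entries may also be read as floating point (see get<double> above).
    double asReal = params.get<double>("kernel_size");
    (void)kernel; (void)stride; (void)asReal;
}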
@@ -1,9 +1,11 @@
#ifndef __OPENCV_DNN_DNN_HPP__
#define __OPENCV_DNN_DNN_HPP__
#include <opencv2/core.hpp>
#include <map>
#include <vector>
#include <iostream>
#include <opencv2/core.hpp>
#include <opencv2/dnn/dict.hpp>
namespace cv
@@ -20,12 +22,15 @@ namespace dnn
class CV_EXPORTS Blob
{
public:
Blob();
Blob(InputArray in);
explicit Blob();
explicit Blob(InputArray in);
void create(int ndims, const int *sizes, int type = CV_32F);
void create(Vec4i shape, int type = CV_32F);
void create(int num, int cn, int rows, int cols, int type = CV_32F);
void fill(InputArray in);
void fill(int ndims, const int *sizes, int type, void *data, bool deepCopy = true);
void create(int ndims, const int *sizes, int type = CV_32F);
bool empty() const;
@@ -43,6 +48,8 @@ namespace dnn
Vec4i shape() const;
size_t total() const;
uchar *rawPtr(int num = 0, int cn = 0, int row = 0, int col = 0);
template<typename TFloat>
TFloat *ptr(int num = 0, int cn = 0, int row = 0, int col = 0);
@@ -109,20 +116,18 @@ namespace dnn
~Net();
int addLayer(const String &name, const String &type, LayerParams &params = LayerParams());
int getLayerId(LayerId layer);
void deleteLayer(LayerId layer);
//each output of each layer can be labeled by a unique string label (as in Caffe)
//if a label is not specified then %layer_name%.%layer_output_id% can be used
void setOutputNames(LayerId layer, const std::vector<String> &outputNames);
void setLayerInputs(const std::vector<String> &outputs, LayerId layer);
void setNetInputs(const std::vector<String> &inputBlobNames);
void connect(BlobId input, BlobId output);
void connect(const std::vector<BlobId> &outputs, const std::vector<BlobId> &inputs);
int getOutputId(LayerId layer, int outputNum);
int getInputId(LayerId layer, int inputNum);
int getLayerId(LayerId layer);
void forward();
void forward(LayerId toLayer);
void forward(LayerId startLayer, LayerId toLayer);

@@ -22,11 +22,10 @@ namespace dnn
return m;
}
inline Mat Blob::getMat(int num, int channel)
{
CV_Assert(false);
return Mat();
CV_Assert(0 <= num && num < this->num() && 0 <= channel && channel < this->channels());
return Mat(rows(), cols(), m.type(), this->rawPtr(num, channel));
}
inline int Blob::cols() const
@@ -64,18 +63,23 @@ namespace dnn
return Vec4i(m.size.p);
}
inline size_t Blob::total() const
{
CV_DbgAssert(m.dims == 4);
return (size_t) m.size[0] * m.size[1] * m.size[2] * m.size[3];
}
inline uchar* Blob::rawPtr(int num, int cn, int row, int col)
{
CV_DbgAssert(m.dims == 4);
return m.data + num * m.step[0] + cn * m.step[1] + row * m.step[2] + col * m.step[3];
}
template<typename TFloat>
TFloat *ptr(int num = 0, int cn = 0, int row = 0, int col = 0)
TFloat *Blob::ptr(int num, int cn, int row, int col)
{
CV_Assert(m.type() = cv::DataType<TFloat>::type && m.dims == 4);
return (TFloat*) (m.data + num * m.step[0] + cn * m.step[1] + row * m.step[2] + col * m.step[3]);
CV_Assert(m.type() == cv::DataType<TFloat>::type && m.dims == 4);
return (TFloat*) rawPtr(num, cn, row, col);
}
}

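To make the rawPtr/ptr arithmetic above concrete, here is a self-contained sketch of the same byte-offset computation for a hypothetical dense CV_32F blob of shape (num=2, cn=3, rows=4, cols=5); the shape values are only illustrative of the stride layout.

#include <cstddef>
#include <cstdio>

int main()
{
    const size_t elem = sizeof(float);   // step[3]: one element
    const size_t cols = 5, rows = 4, cn = 3;
    const size_t step3 = elem;
    const size_t step2 = cols * step3;   // one row
    const size_t step1 = rows * step2;   // one channel plane
    const size_t step0 = cn * step1;     // one sample
    // Offset of element (num=1, cn=2, row=3, col=4), exactly as rawPtr computes it:
    const size_t off = 1 * step0 + 2 * step1 + 3 * step2 + 4 * step3;
    std::printf("byte offset = %zu, element index = %zu\n", off, off / elem);
    return 0;
}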
@@ -188,6 +188,14 @@ namespace
void populateNet(Net dstNet)
{
//setup input layer names
{
std::vector<String> netInputs(net.input_size());
for (int ii = 0; ii < net.input_size(); ii++)
netInputs[ii] = net.input(ii);
dstNet.setNetInputs(netInputs);
}
int layersSize = net.layer_size();
std::vector<String> layersName(layersSize);
@@ -210,7 +218,7 @@ namespace
extractLayerParams(layer, layerParams);
extractBinaryLayerParms(layer, layerParams);
int id = dstNet.addLayer(name, type);
int id = dstNet.addLayer(name, type, layerParams);
dstNet.setOutputNames(id, tops);
layersName[li] = name;

@@ -24,8 +24,45 @@ Blob::Blob()
Blob::Blob(InputArray in)
{
CV_Assert(in.isMat());
m = in.getMat();
CV_Assert(in.isMat() || in.isUMat());
if (in.isMat())
{
Mat mat = in.getMat();
CV_Assert(mat.dims == 2);
int rows = mat.rows;
int cols = mat.cols;
int cn = mat.channels();
int type = mat.type();
int dstType = CV_MAKE_TYPE(CV_MAT_DEPTH(type), 1);
int size[3] = { cn, rows, cols };
this->create(3, size, dstType);
uchar *data = m.data;
int step = rows * cols * CV_ELEM_SIZE(dstType);
if (cn == 1)
{
Mat wrapper2D(rows, cols, dstType, m.data);
mat.copyTo(wrapper2D);
}
else
{
std::vector<Mat> wrappers(cn);
for (int i = 0; i < cn; i++)
{
wrappers[i] = Mat(rows, cols, dstType, data);
data += step;
}
cv::split(mat, wrappers);
}
}
else
{
CV_Error(cv::Error::StsNotImplemented, "Not Implemented");
}
}
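A standalone sketch of the interleaved (HWC) to planar (CHW) copy performed by the Blob(InputArray) constructor above, using only cv::split; the image size and fill values are illustrative, not taken from the module.

#include <opencv2/core.hpp>
#include <vector>

int main()
{
    cv::Mat img(48, 48, CV_32FC3, cv::Scalar(0.1, 0.2, 0.3)); // interleaved HWC image
    int sizes[3] = { img.channels(), img.rows, img.cols };    // planar CHW layout
    cv::Mat planar(3, sizes, CV_32F);

    // Wrap each destination plane as a 2D view into the planar buffer, then split.
    std::vector<cv::Mat> planes(img.channels());
    uchar *data = planar.data;
    size_t planeBytes = (size_t)img.rows * img.cols * sizeof(float);
    for (int c = 0; c < img.channels(); c++, data += planeBytes)
        planes[c] = cv::Mat(img.rows, img.cols, CV_32F, data);
    cv::split(img, planes); // each channel now occupies one contiguous plane
    return 0;
}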
static Vec4i blobNormalizeShape(int ndims, const int *sizes)
@@ -62,14 +99,29 @@ void Blob::fill(int ndims, const int *sizes, int type, void *data, bool deepCopy
void Blob::fill(InputArray in)
{
CV_Assert(in.isMat() || in.isMatVector());
//TODO
*this = Blob(in);
}
void Blob::create(int ndims, const int *sizes, int type /*= CV_32F*/)
void Blob::create(int ndims, const int *sizes, int type)
{
CV_Assert(type == CV_32F || type == CV_64F);
Vec4i shape = blobNormalizeShape(ndims, sizes);
m.create(shape.channels, &shape[0], type);
}
void Blob::create(Vec4i shape, int type)
{
m.create(shape.channels, &shape[0], type);
}
void Blob::create(int num, int cn, int rows, int cols, int type)
{
Vec4i shape(num, cn, rows, cols);
create(4, &shape[0], type);
}
//////////////////////////////////////////////////////////////////////////
struct LayerOutId
@@ -129,13 +181,9 @@ struct Net::Impl
{
layers.insert(make_pair(0, LayerData("_input", "_input")));
lastLayerId = 1;
netInputAliases.push_back("input");
netInputAliases.push_back("data");
netWasAllocated = false;
}
std::vector<String> netInputAliases;
std::vector<int> netOutputs;
typedef std::map<int, LayerData> MapIdToLayerData;
@@ -145,6 +193,20 @@ struct Net::Impl
int lastLayerId;
bool netWasAllocated;
void setUpNet()
{
if (!netWasAllocated)
{
connectInputs();
allocateLayers();
computeNetOutputs();
netWasAllocated = true;
}
}
int getLayerId(const String &layerName)
{
std::map<String, int>::iterator it = layerNameToId.find(layerName);
@@ -169,7 +231,7 @@ struct Net::Impl
return it->second;
}
int findOutputsByName(const String &name, LayerOutId *found, int maxCount)
int findOutputsByName(const String &name, LayerOutId *found, int maxCount = 1)
{
int count = 0;
@@ -237,32 +299,14 @@ struct Net::Impl
}
else if (foundCount == 0)
{
if (std::find(netInputAliases.begin(), netInputAliases.end(), tgtName) == netInputAliases.end())
{
CV_Error(cv::Error::StsBadArg, "Can't find specified input blob \"" + tgtName + "\" for layer \"" + ld.name + "\"");
continue;
}
LayerData &inputLayer = layers[0];
int outIndex = std::find(inputLayer.outputNames.begin(), inputLayer.outputNames.end(), tgtName) - inputLayer.outputNames.begin();
if (outIndex < inputLayer.outputNames.size())
{
out = LayerOutId(0, outIndex);
}
else
{
inputLayer.outputNames.push_back(tgtName);
inputLayer.outputBlobs.resize(inputLayer.outputNames.size());
out = LayerOutId(0, (int)inputLayer.outputNames.size() - 1);
}
CV_Error(cv::Error::StsBadArg, "Can't find specified input blob \"" + tgtName + "\" for layer \"" + ld.name + "\"");
continue;
}
else
{
out = foundOutputs[0];
}
ld.inputBlobs[ii] = &layers[out.lid].outputBlobs[out.oid];
ld.inputBlobsId[ii] = out;
ld.inputLayersId.insert(out.lid);
layers[out.lid].requiredOutputs.insert(out.oid);
@@ -312,12 +356,15 @@ struct Net::Impl
void allocateLayers()
{
allocateOutputBlobs();
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
{
int lid = it->first;
LayerData &ld = it->second;
//create instance
if (ld.layerInstance == NULL && lid != 0)
{
ld.layerInstance = LayerRegister::createLayerInstance(ld.type, ld.params);
@@ -326,6 +373,19 @@ struct Net::Impl
std::cerr << "Can't create layer \"" << ld.name << "\" of type \"" << ld.type << "\"" << std::endl;
}
}
//bind inputs
ld.inputBlobs.resize(ld.inputBlobsId.size());
for (size_t i = 0; i < ld.inputBlobsId.size(); i++)
{
int srcLId = ld.inputBlobsId[i].lid;
int srcOId = ld.inputBlobsId[i].oid;
ld.inputBlobs[i] = &layers[srcLId].outputBlobs[srcOId];
}
//allocate layer
if (ld.layerInstance)
ld.layerInstance->allocate(ld.inputBlobs, ld.outputBlobs);
}
}
@@ -408,19 +468,45 @@ void Net::setLayerInputs(const std::vector<String> &outputs, LayerId layer)
void Net::forward()
{
impl->allocateOutputBlobs();
impl->connectInputs();
impl->computeNetOutputs();
impl->allocateLayers();
impl->setUpNet();
impl->forwardAll();
}
void Net::forward(LayerId toLayer)
{
impl->setUpNet();
impl->forwardLayer(impl->getLayerId(toLayer));
}
void Net::setNetInputs(const std::vector<String> &inputBlobNames)
{
setOutputNames(0, inputBlobNames);
}
void Net::setBlob(BlobId outputName, const Blob &blob)
{
String name = outputName.get<String>();
LayerOutId found;
if (!impl->findOutputsByName(name, &found, 1))
CV_Error(cv::Error::StsObjectNotFound, "Requested blob \"" + name + "\" not found");
impl->allocateOutputBlobs();
impl->layers[found.lid].outputBlobs[found.oid] = blob;
}
Blob Net::getBlob(BlobId outputName)
{
String name = outputName.get<String>();
LayerOutId found;
if (!impl->findOutputsByName(name, &found, 1))
CV_Error(cv::Error::StsObjectNotFound, "Requested blob \"" + name + "\" not found");
impl->allocateOutputBlobs();
return impl->layers[found.lid].outputBlobs[found.oid];
}
Importer::~Importer()
{

@@ -1,6 +1,11 @@
#include "precomp.hpp"
#include "layers.hpp"
#include <math.h>
#include <float.h>
#include <iostream>
#include <algorithm>
using std::max;
using std::min;
namespace cv
{
@@ -43,31 +48,211 @@ REGISTER_LAYER_CLASS(InnerProduct, FullyConnectedLayer)
//////////////////////////////////////////////////////////////////////////
static void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW)
{
if (params.has("kernel_h") && params.has("kernel_w"))
{
kernelH = params.get<int>("kernel_h");
kernelW = params.get<int>("kernel_w");
}
else if (params.has("kernel_size"))
{
kernelH = kernelW = params.get<int>("kernel_size");
}
else
{
CV_Error(cv::Error::StsBadArg, "kernel_size (or kernel_h and kernel_w) not specified");
}
if (params.has("pad_h") && params.has("pad_w"))
{
padH = params.get<int>("pad_h");
padW = params.get<int>("pad_w");
}
else
{
padH = padW = params.get<int>("pad", 0);
}
if (params.has("stride_h") && params.has("stride_w"))
{
strideH = params.get<int>("stride_h");
strideW = params.get<int>("stride_w");
}
else
{
strideH = strideW = params.get<int>("stride", 1);
}
CV_Assert(kernelH > 0 && kernelW > 0 && padH >= 0 && padW >= 0 && strideH > 0 && strideW > 0);
}
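As a worked example of the resolution order above: a layer that specifies only kernel_size: 5 resolves to kernelH = kernelW = 5 with padH = padW = 0 and strideH = strideW = 1, while explicit kernel_h/kernel_w (and likewise pad_h/pad_w, stride_h/stride_w) take precedence over the square forms when both members of a pair are present.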
PoolingLayer::PoolingLayer(LayerParams &params)
{
if (params.has("pool"))
{
String pool = params.get<String>("pool").toLowerCase();
if (pool == "max")
type = MAX;
else if (pool == "ave")
type = AVE;
else if (pool == "stochastic")
type = STOCHASTIC;
else
CV_Error(cv::Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
}
else
{
type = MAX;
}
getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW);
}
void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
inH = inputs[0]->rows();
inW = inputs[0]->cols();
computeOutputShape(inH, inW);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->rows() == inH && inputs[i]->cols() == inW);
outputs[i].create(inputs[i]->num(), inputs[i]->channels(), pooledH, pooledW);
}
}
void PoolingLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
for (size_t ii = 0; ii < inputs.size(); ii++)
{
switch (type)
{
case MAX:
maxPooling(*inputs[ii], outputs[ii]);
break;
default:
CV_Error(cv::Error::StsNotImplemented, "Not implemented");
break;
}
}
}
void PoolingLayer::maxPooling(Blob &input, Blob &output)
{
CV_DbgAssert(output.rows() == pooledH && output.cols() == pooledW);
for (int n = 0; n < input.num(); ++n)
{
for (int c = 0; c < input.channels(); ++c)
{
float *srcData = input.ptr<float>(n, c);
float *dstData = output.ptr<float>(n, c);
for (int ph = 0; ph < pooledH; ++ph)
{
for (int pw = 0; pw < pooledW; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideW - padW;
int hend = min(hstart + kernelH, inH);
int wend = min(wstart + kernelW, inW);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
const int pool_index = ph * pooledW + pw;
float max_val = -FLT_MAX;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
{
const int index = h * inW + w;
if (srcData[index] > max_val)
max_val = srcData[index];
}
dstData[pool_index] = max_val;
}
}
}
}
}
void PoolingLayer::computeOutputShape(int inH, int inW)
{
//Output size follows Caffe's somewhat unusual pooling scheme
pooledH = static_cast<int>(ceil(static_cast<float>(inH + 2 * padH - kernelH) / strideH)) + 1;
pooledW = static_cast<int>(ceil(static_cast<float>(inW + 2 * padW - kernelW) / strideW)) + 1;
if (padH || padW)
{
// If we have padding, ensure that the last pooling starts strictly
// inside the image (instead of at the padding); otherwise clip the last.
if ((pooledH - 1) * strideH >= inH + padH)
--pooledH;
if ((pooledW - 1) * strideW >= inW + padW)
--pooledW;
CV_Assert((pooledH - 1) * strideH < inH + padH);
CV_Assert((pooledW - 1) * strideW < inW + padW);
}
}
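As a worked example of the formula above (values are illustrative): with inH = 48, kernelH = 3, strideH = 2 and padH = 0, pooledH = ceil((48 - 3) / 2) + 1 = 23 + 1 = 24, and since there is no padding the clipping branch is skipped.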
//////////////////////////////////////////////////////////////////////////
ConvolutionLayer::ConvolutionLayer(LayerParams &params)
{
getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW);
numOutput = params.get<int>("num_output");
bias = params.get<bool>("bias_term", true);
group = params.get<int>("group", 1);
CV_Assert(numOutput % group == 0);
CV_Assert(params.learnedBlobs.size() >= 1 && (!bias || params.learnedBlobs.size() >= 2));
learnedParams.assign(params.learnedBlobs.begin(), params.learnedBlobs.begin() + (bias ? 2 : 1));
Blob &weightBlob = learnedParams[0];
CV_Assert(weightBlob.cols() == kernelW && weightBlob.rows() == kernelH && weightBlob.num() == numOutput);
if (bias)
{
Blob &biasBlob = learnedParams[1];
CV_Assert(biasBlob.total() == numOutput);
}
}
void ConvolutionLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
Blob &weightBlob = learnedParams[0];
inCn = inputs[0]->channels();
CV_Assert(inCn % group == 0 && weightBlob.channels() == inCn);
inH = inputs[0]->rows();
inW = inputs[0]->cols();
computeOutputShape(inH, inW);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->rows() == inH && inputs[i]->cols() == inW && inputs[i]->channels() == inCn);
int num = inputs[i]->num();
outputs[i].create(num, numOutput, outH, outW);
}
colCn = kernelH * kernelW * inCn;
imColsMat.create(colCn, outH * outW, CV_32F);
if (bias)
{
biasOnesMat = Mat::ones(1, outH * outW, CV_32F);
}
}
@@ -102,29 +287,104 @@ void im2col_cpu(const Dtype* data_im, const int channels,
void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == outputs.size());
float *colPtr = imColsMat.ptr<float>();
float *weigtsPtr = learnedParams[0].ptr<float>();
float *biasPtr = (bias) ? learnedParams[1].ptr<float>() : NULL;
CV_Assert(group == 1);
for (size_t i = 0; i < outputs.size(); i++)
for (size_t ii = 0; ii < outputs.size(); ii++)
{
int num = inputs[ii]->num();
for (int n = 0; n < num; n++)
{
float *srcImPtr = inputs[ii]->ptr<float>(n);
float *dstImPtr = outputs[ii].ptr<float>(n);
im2col_cpu(srcImPtr, inCn, inH, inW, kernelH, kernelW, padH, padW, strideH, strideW, colPtr);
Mat weightsMat(numOutput, colCn, CV_32F, weigtsPtr);
Mat dstIm(numOutput, outH*outW, CV_32F, dstImPtr);
cv::gemm(weightsMat, imColsMat, 1, noArray(), 0, dstIm);
if (bias)
{
Mat biasMat(numOutput, 1, CV_32F, biasPtr);
cv::gemm(biasMat, biasOnesMat, 1, dstIm, 1, dstIm);
}
}
}
}
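In the GEMM above, weightsMat is numOutput x colCn (with colCn = kernelH * kernelW * inCn), imColsMat is colCn x (outH * outW), so dstIm comes out as numOutput x (outH * outW), one row per output channel; the optional bias is then added as a rank-one update, biasMat (numOutput x 1) times biasOnesMat (1 x outH * outW).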
//////////////////////////////////////////////////////////////////////////
void ConvolutionLayer::computeOutputShape(int inH, int inW)
{
outH = (inH + 2 * padH - kernelH) / strideH + 1;
outW = (inW + 2 * padW - kernelW) / strideW + 1;
}
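For example (illustrative values, not taken from the GTSRB model): inH = 48 with kernelH = 5, padH = 2 and strideH = 1 gives outH = (48 + 4 - 5) / 1 + 1 = 48, i.e. the spatial size is preserved.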
//////////////////////////////////////////////////////////////////////////
FullyConnectedLayer::FullyConnectedLayer(LayerParams &params)
{
numOutputs = params.get<int>("num_output");
bias = params.get<bool>("bias_term", true);
CV_Assert(params.learnedBlobs.size() >= 1);
CV_Assert(!bias || (params.learnedBlobs.size() >= 2 && params.learnedBlobs[1].total() == numOutputs));
learnedParams.resize(bias ? 2 : 1);
learnedParams[0] = params.learnedBlobs[0];
if (bias)
{
learnedParams[1] = params.learnedBlobs[1];
}
}
void FullyConnectedLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
inC = inputs[0]->channels();
inH = inputs[0]->rows();
inW = inputs[0]->cols();
inSize = inC * inH * inW;
CV_Assert(inSize * numOutputs == learnedParams[0].total());
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
if (i != 0)
CV_Assert(inputs[i]->channels() == inC && inputs[i]->rows() == inH && inputs[i]->cols() == inW);
outputs[i].create(inputs[i]->num(), numOutputs, 1, 1);
}
}
void FullyConnectedLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
for (size_t i = 0; i < inputs.size(); i++)
{
int M = inputs[i]->num();
int N = numOutputs;
int K = inSize;
Mat srcMat(M, K, CV_32F, inputs[i]->ptr<float>());
Mat weights(K, N, CV_32F, learnedParams[0].ptr<float>());
Mat dstMat(M, N, CV_32F, outputs[i].ptr<float>());
cv::gemm(srcMat, weights, 1, noArray(), 0, dstMat);
if (bias)
{
Mat biasOnesMat = Mat::ones(M, 1, CV_32F);
Mat biasMat(1, N, CV_32F, learnedParams[1].ptr<float>());
cv::gemm(biasOnesMat, biasMat, 1, dstMat, 1, dstMat);
}
}
}
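Here the fully connected product is srcMat (M x K) times weights (K x N) into dstMat (M x N), with M the batch size (num), K = inSize = inC * inH * inW and N = numOutputs; the bias row is broadcast over the batch via biasOnesMat (M x 1) times biasMat (1 x N).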
}

@@ -17,30 +17,49 @@ namespace dnn
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
outputs[0] = *inputs[0];
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
outputs[i] = *inputs[i];
}
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1 && outputs.size() == 1);
CV_Assert(inputs[0]->getMatRef().data == outputs[0].getMatRef().data);
CV_Assert(inputs.size() == outputs.size());
float *data = outputs[0].getMatRef().ptr<float>();
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->ptr<float>() == outputs[i].ptr<float>());
float *data = outputs[i].ptr<float>();
size_t size = outputs[i].total();
//Vec4i shape = outputs[0].shape();
//CV_Assert(pitch[i] == shape[i] * sizeof(float) );
//Vec4i shape = outputs[0].shape();
//CV_Assert(pitch[i] == shape[i] * sizeof(float) );
for (size_t i = 0; i < outputs[0].total(); i++)
data[i] = func(data[i]);
for (size_t j = 0; j < size; j++)
data[j] = func(data[j]);
}
}
};
class PoolingLayer : public Layer
{
enum
{
MAX,
AVE,
STOCHASTIC
};
int type;
int padH, padW;
int strideH, strideW;
int sizeH, sizeW;
int kernelH, kernelW;
int inH, inW;
int pooledH, pooledW;
void computeOutputShape(int inH, int inW);
void maxPooling(Blob &input, Blob &output);
public:
PoolingLayer(LayerParams &params);
@@ -50,9 +69,18 @@ namespace dnn
class ConvolutionLayer : public Layer
{
int groups;
bool bias;
int numOutput, group;
int padH, padW;
int strideH, strideW;
int sizeH, sizeW;
int kernelH, kernelW;
int inH, inW, inCn, colCn;
int outH, outW;
Mat imColsMat, biasOnesMat;
void computeOutputShape(int inH, int inW);
public:
ConvolutionLayer(LayerParams &params);
@@ -62,8 +90,12 @@ namespace dnn
class FullyConnectedLayer : public Layer
{
bool bias;
int numOutputs;
int inC, inH, inW;
size_t inSize;
public:
FullyConnectedLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);

@@ -14,7 +14,8 @@ static std::string getOpenCVExtraDir()
return cvtest::TS::ptr()->get_data_path();
}
static std::string getTestFile(const char *filename)
template<typename TStr>
static std::string getTestFile(TStr filename)
{
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
@@ -24,7 +25,29 @@ TEST(ReadCaffePrototxt_gtsrb, Accuracy)
Ptr<Importer> importer = createCaffeImporter(getTestFile("gtsrb.prototxt"), getTestFile("gtsrb_iter_36000.caffemodel"));
Net net;
importer->populateNet(net);
Mat img = imread(getTestFile("sign_50.ppm"));
CV_Assert(!img.empty());
img.convertTo(img, CV_32F, 1.0/255);
resize(img, img, cv::Size(48, 48));
Blob imgBlob(img);
net.setBlob("input", imgBlob);
net.forward();
Blob res = net.getBlob("layer8");
for (int n = 0; n < 1; n++)
{
Mat slice = Mat(res.channels() * res.rows(), res.cols(), CV_32F, res.ptr<float>(n));
double maxv;
std::vector<int> maxIdx;
minMaxLoc(slice, NULL, &maxv, NULL, &maxIdx);
std::cout << "Best class: #" << maxIdx[0] << std::endl;
//imwrite(getTestFile("vis.png"), slice*(255.0 / maxv));
}
}
//TEST(ReadCaffePrototxt_GoogleNet, Accuracy)

@@ -12,6 +12,7 @@
#include "opencv2/core.hpp"
#include "opencv2/dnn.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/ts.hpp"
#include <opencv2/ts/ts_perf.hpp>
#include <opencv2/core/utility.hpp>
