Implemented allocation of the DAG and its forward pass.

Added wrappers for basic layers.
pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent ee837c1132
commit 194271df50
1. modules/dnn/include/opencv2/dnn/dict.hpp (33 changed lines)
2. modules/dnn/include/opencv2/dnn/dnn.hpp (54 changed lines)
3. modules/dnn/include/opencv2/dnn/dnn.inl.hpp (33 changed lines)
4. modules/dnn/src/caffe_importer.cpp (66 changed lines)
5. modules/dnn/src/dnn.cpp (396 changed lines)
6. modules/dnn/src/layers.cpp (131 changed lines)
7. modules/dnn/src/layers.hpp (76 changed lines)
8. modules/dnn/test/test_caffe_importer.cpp (16 changed lines)
9. modules/dnn/testdata/dnn/gtsrb.prototxt (5 changed lines)
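
The new pieces are exercised end to end by the updated test at the bottom of this diff. A minimal usage sketch (the file paths are placeholders; the calls themselves come from the dnn.hpp API touched by this commit):

#include <opencv2/dnn.hpp>

using namespace cv;
using namespace cv::dnn;

int main()
{
    // Paths are placeholders; the test below pairs gtsrb.prototxt with gtsrb_iter_36000.caffemodel.
    Ptr<Importer> importer = createCaffeImporter("gtsrb.prototxt", "gtsrb_iter_36000.caffemodel");

    Net net;
    importer->populateNet(net); // adds layers and records their input/output blob names
    net.forward();              // connects the DAG, allocates layer instances and runs them
    return 0;
}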

@@ -1,4 +1,6 @@
#pragma once
#ifndef __OPENCV_DNN_DICT_HPP__
#define __OPENCV_DNN_DICT_HPP__
#include <opencv2/core.hpp>
namespace cv
@@ -29,8 +31,8 @@ struct DictValue
template<typename T>
T get() const;
template<typename T>
const T &get() const;
bool isString() const;
bool isInt() const;
DictValue &operator=(const DictValue &r);
@@ -48,6 +50,17 @@ class Dict
public:
bool has(const String &name)
{
return dict.count(name);
}
DictValue *ptr(const String &name)
{
_Dict::iterator i = dict.find(name);
return (i == dict.end()) ? NULL : &i->second;
}
template <typename T>
const T &get(const String &name) const
{
@@ -129,7 +142,7 @@ inline bool DictValue::get<bool>() const
}
template<>
inline const String &DictValue::get<String>() const
inline String DictValue::get<String>() const
{
CV_Assert(type == cv::ParamType<String>::type);
return *s;
@@ -174,5 +187,17 @@ inline DictValue::DictValue(const DictValue &r)
*this = r;
}
inline bool DictValue::isString() const
{
return (type == cv::Param::STRING);
}
inline bool DictValue::isInt() const
{
return (type == cv::Param::INT);
}
}
}
#endif
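
The dict.hpp changes add presence and type queries (has(), ptr(), isString(), isInt()) on top of the existing get<T>(). A small sketch of how layer code might read an optional parameter, assuming Dict lives in the cv::dnn namespace; the "pad" key is purely illustrative:

#include <opencv2/dnn/dict.hpp>

// Return the integer "pad" entry if present, otherwise a default value.
int readPad(cv::dnn::Dict &params)
{
    if (!params.has("pad"))                       // has() was added in this commit
        return 0;
    cv::dnn::DictValue *v = params.ptr("pad");    // ptr() returns NULL when the key is absent
    return (v && v->isInt()) ? v->get<int>() : 0;
}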

@@ -34,14 +34,17 @@ namespace dnn
Mat getMat();
Mat getMat(int num, int channel);
//shape getters
int cols() const;
int rows() const;
Size size() const;
int channels() const;
int num() const;
Vec4i shape() const;
size_t total() const;
template<typename TFloat>
TFloat *ptr(int num = 0, int cn = 0, int row = 0, int col = 0);
private:
Mat m;
@@ -58,19 +61,19 @@ namespace dnn
{
public:
typedef Layer* (*Constuctor)();
typedef Ptr<Layer> (*Constuctor)(LayerParams &params);
static void registerLayer(const String &type, Constuctor constructor);
static void unregisterLayer(const String &type);
static Ptr<Layer> createLayerInstance(const String &type);
static Ptr<Layer> createLayerInstance(const String &type, LayerParams& params = LayerParams());
private:
LayerRegister();
LayerRegister(const LayerRegister &lr);
static std::map<String, Constuctor> registeredLayers;
struct Impl;
static Ptr<Impl> impl;
};
//this class allows to build new Layers
@@ -82,17 +85,10 @@ namespace dnn
virtual ~Layer();
//type of Layer
virtual String type() const = 0;
//setUp calls once (think that it's constructor)
virtual void setUp(LayerParams &params) = 0;
//maybe useless function
//shape of output blobs must be adjusted with respect to shape of input blobs
virtual void adjustShape(const std::vector<Blob> &inputs, std::vector<Blob> &outputs) = 0;
virtual void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs) = 0;
virtual void forward(std::vector<Blob> &inputs, std::vector<Blob> &outputs) = 0;
virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs) = 0;
virtual int getNumInputs();
virtual int getNumOutputs();
@@ -118,10 +114,10 @@ namespace dnn
//each output of each layer can be labeled by unique string label (as in Caffe)
//if label not specified then %layer_name%.%layer_output_id% can be used
void setOutputNames(LayerId layer, const std::vector<String> &outputNames);
void setLayerInputs(const std::vector<String> &outputs, LayerId layer);
void connect(BlobId input, BlobId output);
void connect(const std::vector<BlobId> &outputs, const std::vector<BlobId> &inputs);
void connect(const std::vector<BlobId> &outputs, LayerId layer);
int getOutputId(LayerId layer, int outputNum);
int getInputId(LayerId layer, int inputNum);
@@ -159,6 +155,32 @@ namespace dnn
CV_EXPORTS Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel);
//allows automatically register created layer on module load time
struct _LayerRegisterer
{
String type;
_LayerRegisterer(const String &type, LayerRegister::Constuctor constuctor)
{
this->type = type;
LayerRegister::registerLayer(type, constuctor);
}
~_LayerRegisterer()
{
LayerRegister::unregisterLayer(type);
}
};
//registers layer on module load time
#define REGISTER_LAYER(type, constuctorFunc) \
static _LayerRegisterer __layerRegisterer_##type(#type, constuctorFunc);
#define REGISTER_LAYER_CLASS(type, class) \
Ptr<Layer> __layerRegisterer_func_##type(LayerParams &params) \
{ return Ptr<Layer>(new class(params)); } \
static _LayerRegisterer __layerRegisterer_##type(#type, __layerRegisterer_func_##type);
}
}
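
The _LayerRegisterer helper and the REGISTER_LAYER_CLASS macro above register a layer constructor at module load time. A sketch of what a user-defined layer could look like against the Layer interface from this commit; the ScaleLayer class and its "factor" parameter are illustrative, not part of the module:

#include <opencv2/dnn.hpp>

// Illustrative element-wise layer: multiplies every value by a constant factor.
class ScaleLayer : public cv::dnn::Layer
{
    float factor;
public:
    ScaleLayer(cv::dnn::LayerParams &params)
    {
        factor = params.has("factor") ? params.get<float>("factor") : 1.f;
    }

    cv::String type() const { return "Scale"; }

    void allocate(const std::vector<cv::dnn::Blob*> &inputs, std::vector<cv::dnn::Blob> &outputs)
    {
        CV_Assert(inputs.size() == 1 && outputs.size() == 1);
        outputs[0] = *inputs[0]; // run in-place, like ElementWiseLayer in layers.hpp
    }

    void forward(std::vector<cv::dnn::Blob*> &inputs, std::vector<cv::dnn::Blob> &outputs)
    {
        float *data = outputs[0].getMatRef().ptr<float>();
        for (size_t i = 0; i < outputs[0].total(); i++)
            data[i] *= factor;
    }
};

REGISTER_LAYER_CLASS(Scale, ScaleLayer)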

@@ -29,46 +29,55 @@ namespace dnn
return Mat();
}
inline
int Blob::cols() const
inline int Blob::cols() const
{
CV_DbgAssert(m.dims > 2);
return m.size[m.dims-1];
}
inline
int Blob::rows() const
inline int Blob::rows() const
{
CV_DbgAssert(m.dims > 2);
return m.size[m.dims-2];
}
inline
Size Blob::size() const
inline Size Blob::size() const
{
return Size(cols(), rows());
}
inline
int Blob::channels() const
inline int Blob::channels() const
{
CV_DbgAssert(m.dims >= 3);
return m.size[m.dims-3];
}
inline
int Blob::num() const
inline int Blob::num() const
{
CV_DbgAssert(m.dims == 4);
return m.size[0];
}
inline
Vec4i Blob::shape() const
inline Vec4i Blob::shape() const
{
CV_DbgAssert(m.dims == 4);
return Vec4i(m.size.p);
}
inline size_t Blob::total() const
{
CV_DbgAssert(m.dims == 4);
return (size_t) m.size[0] * m.size[1] * m.size[2] * m.size[3];
}
template<typename TFloat>
inline TFloat *Blob::ptr(int num, int cn, int row, int col)
{
CV_Assert(m.type() == cv::DataType<TFloat>::type && m.dims == 4);
return (TFloat*) (m.data + num * m.step[0] + cn * m.step[1] + row * m.step[2] + col * m.step[3]);
}
}
}
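
total() and the templated ptr() accessor defined above give direct element access into the num x channels x rows x cols layout. A brief sketch (the helper function is illustrative):

#include <opencv2/dnn.hpp>

// Fill a float Blob with a constant, walking it with the new shape accessors.
void fillBlob(cv::dnn::Blob &blob, float value)
{
    for (int n = 0; n < blob.num(); n++)
        for (int c = 0; c < blob.channels(); c++)
            for (int y = 0; y < blob.rows(); y++)
            {
                float *row = blob.ptr<float>(n, c, y); // column index defaults to 0
                for (int x = 0; x < blob.cols(); x++)
                    row[x] = value;
            }
}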

@@ -164,56 +164,62 @@ namespace
dstData[i] = protoBlob.data(i);
}
void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
{
const std::string &name = layer.name();
int li;
for (li = 0; li != netBinary.layer_size(); li++)
{
if (netBinary.layer(li).name() == name)
break;
}
if (li == netBinary.layer_size() || netBinary.layer(li).blobs_size() == 0)
return;
const caffe::LayerParameter &binLayer = netBinary.layer(li);
layerParams.learnedBlobs.resize(binLayer.blobs_size());
for (int bi = 0; bi < binLayer.blobs_size(); bi++)
{
blobFromProto(binLayer.blobs(bi), layerParams.learnedBlobs[bi]);
}
}
void populateNet(Net dstNet)
{
int layersSize = net.layer_size();
std::vector<String> layersName(layersSize);
std::vector<LayerParams> layersParam(layersSize);
std::vector<int> layersId(layersSize);
std::vector<std::vector<String>> bottomsVec(layersSize);
for (int li = 0; li < layersSize; li++)
{
const caffe::LayerParameter layer = net.layer(li);
String name = layer.name();
String type = layer.type();
LayerParams layerParams;
std::vector<String> bottoms, tops;
bottoms.assign(layer.bottom().begin(), layer.bottom().end());
std::vector<String> tops;
tops.assign(layer.top().begin(), layer.top().end());
bottomsVec[li].assign(layer.bottom().begin(), layer.bottom().end());
std::cout << std::endl << "LAYER: " << name << std::endl;
extractLayerParams(layer, layersParam[li]);
layersName[li] = name;
extractLayerParams(layer, layerParams);
extractBinaryLayerParms(layer, layerParams);
int id = dstNet.addLayer(name, type);
dstNet.setOutputNames(id, tops);
//SetUp
//int id = config->addLayer(name, type);
//config->setLayerOutputLabels(id, bottoms);
layersName[li] = name;
layersId[li] = id;
}
for (int li = 0; li < netBinary.layer_size(); li++)
for (int li = 0; li < layersSize; li++)
{
const caffe::LayerParameter layer = netBinary.layer(li);
if (layer.blobs_size() == 0)
continue;
String name = layer.name();
int index = std::find(layersName.begin(), layersName.end(), name) - layersName.begin();
if (index < layersName.size())
{
std::vector<Blob> &layerBlobs = layersParam[index].learnedBlobs;
layerBlobs.resize(layer.blobs_size());
for (int bi = 0; bi < layer.blobs_size(); bi++)
{
blobFromProto(layer.blobs(bi), layerBlobs[bi]);
}
}
else
{
std::cerr << "Unknown layer name " << name << " into" << std::endl;
}
dstNet.setLayerInputs(bottomsVec[li], layersId[li]);
}
}

@@ -1,7 +1,15 @@
#include "opencv2/dnn.hpp"
#include "precomp.hpp"
#include <set>
#include <algorithm>
#include <iostream>
using namespace cv;
using namespace cv::dnn;
#include <algorithm>
using std::vector;
using std::map;
using std::make_pair;
using std::set;
namespace cv
{
@@ -10,7 +18,8 @@ namespace dnn
Blob::Blob()
{
int zeros[4] = {0, 0, 0, 0};
m = Mat(4, zeros, CV_32F, NULL);
}
Blob::Blob(InputArray in)
@@ -61,9 +70,297 @@ void Blob::create(int ndims, const int *sizes, int type /*= CV_32F*/)
m.create(shape.channels, &shape[0], type);
}
//////////////////////////////////////////////////////////////////////////
struct LayerOutId
{
int lid;
int oid;
String name;
LayerOutId() {}
LayerOutId(int layerId, int outputId, const String &outputName = String())
: lid(layerId), oid(outputId), name(outputName) {}
struct UnaryMatchName
{
const String &name;
UnaryMatchName(const String &_name) : name(_name) {}
bool operator()(const String &other) { return name == other; }
};
};
struct LayerData
{
LayerData() {}
LayerData(const String &_name, const String &_type, LayerParams &_params = LayerParams())
: name(_name), type(_type), params(_params)
{}
String name;
String type;
LayerParams params;
std::vector<String> outputNames;
std::vector<String> inputNames;
bool hasNamedOutput(const String &name)
{
return std::find(outputNames.begin(), outputNames.end(), name) != outputNames.end();
}
bool hasNamedInput(const String &name)
{
return std::find(inputNames.begin(), inputNames.end(), name) != inputNames.end();
}
std::vector<LayerOutId> inputBlobsId;
std::set<int> inputLayersId;
std::set<int> requiredOutputs;
Ptr<Layer> layerInstance;
std::vector<Blob> outputBlobs;
std::vector<Blob*> inputBlobs;
int flag;
};
struct Net::Impl
{
Impl()
{
layers.insert(make_pair(0, LayerData("_input", "_input")));
lastLayerId = 1;
netInputAliases.push_back("input");
netInputAliases.push_back("data");
}
std::vector<String> netInputAliases;
std::vector<int> netOutputs;
typedef std::map<int, LayerData> MapIdToLayerData;
std::map<int, LayerData> layers;
std::map<String, int> layerNameToId;
int lastLayerId;
int getLayerId(const String &layerName)
{
std::map<String, int>::iterator it = layerNameToId.find(layerName);
return (it != layerNameToId.end()) ? it->second : -1;
}
int getLayerId(const DictValue &v)
{
if (v.isString())
return getLayerId(v.get<String>());
else if (v.isInt())
return v.get<int>();
CV_Assert(v.isString() || v.isInt());
return -1;
}
LayerData& getLayerData(const DictValue &v)
{
int id = getLayerId(v);
std::map<int, LayerData>::iterator it = layers.find(id);
CV_Assert(id >= 0 && it != layers.end());
return it->second;
}
int findOutputsByName(const String &name, LayerOutId *found, int maxCount)
{
int count = 0;
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end() && count < maxCount; it++)
{
int lid = it->first;
LayerData &ld = it->second;
for (size_t oi = 0; oi < ld.outputNames.size() && count < maxCount; oi++)
{
if (ld.outputNames[oi] == name)
found[count++] = LayerOutId(lid, oi);
}
}
return count;
}
void connectInputs()
{
LayerOutId foundOutputs[3], out;
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
{
int lid = it->first;
LayerData &ld = it->second;
ld.inputBlobs.resize(ld.inputNames.size());
ld.inputBlobsId.resize(ld.inputNames.size());
ld.inputLayersId.clear();
for (size_t ii = 0; ii < ld.inputNames.size(); ii++)
{
const String &tgtName = ld.inputNames[ii];
int foundCount = findOutputsByName(tgtName, foundOutputs, 3);
if (foundCount > 2)
{
CV_Error(cv::Error::StsNotImplemented, "Two or more non-inplace blobs have the same name \"" + tgtName + "\"");
}
else if (foundCount == 2)
{
bool inPlace[2];
inPlace[0] = layers[ foundOutputs[0].lid ].hasNamedInput(tgtName);
inPlace[1] = layers[ foundOutputs[1].lid ].hasNamedInput(tgtName);
if (!inPlace[0] && !inPlace[1])
{
CV_Error(cv::Error::StsNotImplemented, "Two or more non-inplace blobs have the same name \"" + tgtName + "\"");
}
else if (inPlace[0] && inPlace[1])
{
CV_Error(cv::Error::StsNotImplemented, "Two or more blobs has same in-place blob \"" + tgtName + "\"");
}
else
{
if (ld.hasNamedOutput(tgtName))
out = (inPlace[0]) ? foundOutputs[1] : foundOutputs[0];
else
out = (inPlace[0]) ? foundOutputs[0] : foundOutputs[1];
}
}
else if (foundCount == 0)
{
if (std::find(netInputAliases.begin(), netInputAliases.end(), tgtName) == netInputAliases.end())
{
CV_Error(cv::Error::StsBadArg, "Can't find specified input blob \"" + tgtName + "\" for layer \"" + ld.name + "\"");
continue;
}
LayerData &inputLayer = layers[0];
int outIndex = std::find(inputLayer.outputNames.begin(), inputLayer.outputNames.end(), tgtName) - inputLayer.outputNames.begin();
if (outIndex < inputLayer.outputNames.size())
{
out = LayerOutId(0, outIndex);
}
else
{
inputLayer.outputNames.push_back(tgtName);
inputLayer.outputBlobs.resize(inputLayer.outputNames.size());
out = LayerOutId(0, (int)inputLayer.outputNames.size() - 1);
}
}
else
{
out = foundOutputs[0];
}
ld.inputBlobs[ii] = &layers[out.lid].outputBlobs[out.oid];
ld.inputBlobsId[ii] = out;
ld.inputLayersId.insert(out.lid);
layers[out.lid].requiredOutputs.insert(out.oid);
}
}
for (it = layers.begin(); it != layers.end(); it++)
{
LayerData& ld = it->second;
std::cout << ld.name << std::endl;
std::cout << "Connected:" << std::endl;
for (std::set<int>::iterator j = ld.inputLayersId.begin(); j != ld.inputLayersId.end(); j++)
std::cout << layers[*j].name << std::endl;
std::cout << std::endl;
}
}
void computeNetOutputs()
{
netOutputs.clear();
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
{
int lid = it->first;
LayerData &ld = it->second;
if (ld.requiredOutputs.size() == 0)
netOutputs.push_back(lid);
}
std::cout << "\nNet Outputs(" << netOutputs.size() << "):\n";
for (int i = 0; i < netOutputs.size(); i++)
std::cout << layers[netOutputs[i]].name << std::endl;
}
void allocateOutputBlobs()
{
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
{
LayerData &ld = it->second;
ld.outputBlobs.resize(ld.outputNames.size());
}
}
void allocateLayers()
{
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
{
int lid = it->first;
LayerData &ld = it->second;
if (ld.layerInstance == NULL && lid != 0)
{
ld.layerInstance = LayerRegister::createLayerInstance(ld.type, ld.params);
if (ld.layerInstance == NULL)
{
std::cerr << "Can't create layer \"" << ld.name << "\" of type \"" << ld.type << "\"" << std::endl;
}
}
}
}
void forwardLayer(int layerId, bool clearFlags = true)
{
if (clearFlags)
{
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
it->second.flag = 0;
}
LayerData &ld = layers[layerId];
for (set<int>::iterator i = ld.inputLayersId.begin(); i != ld.inputLayersId.end(); i++)
{
LayerData &ild = layers[*i];
if (!ild.flag)
{
if (ild.layerInstance)
ild.layerInstance->forward(ild.inputBlobs, ild.outputBlobs);
ild.flag = true;
}
}
}
void forwardAll()
{
MapIdToLayerData::iterator it;
for (it = layers.begin(); it != layers.end(); it++)
it->second.flag = 0;
for (it = layers.begin(); it != layers.end(); it++)
forwardLayer(it->first, false);
}
};
Net::Net() : impl(new Net::Impl)
@@ -76,11 +373,61 @@ Net::~Net()
}
int Net::addLayer(const String &name, const String &type, LayerParams &params)
{
if (impl->getLayerId(name) >= 0)
{
CV_Error(cv::Error::StsBadArg, "Layer \"" + name + "\" already into net");
return -1;
}
int id = ++impl->lastLayerId;
impl->layerNameToId.insert(std::make_pair(name, id));
impl->layers.insert(std::make_pair(id, LayerData(name, type, params)));
return id;
}
void Net::connect(BlobId input, BlobId output)
{
}
void Net::setOutputNames(LayerId layer, const std::vector<String> &outputNames)
{
LayerData &ld = impl->getLayerData(layer);
CV_Assert(ld.outputNames.size() == 0);
ld.outputNames.assign(outputNames.begin(), outputNames.end());
}
void Net::setLayerInputs(const std::vector<String> &outputs, LayerId layer)
{
LayerData &ld = impl->getLayerData(layer);
ld.inputNames.assign(outputs.begin(), outputs.end());
}
void Net::forward()
{
impl->allocateOutputBlobs();
impl->connectInputs();
impl->computeNetOutputs();
impl->allocateLayers();
impl->forwardAll();
}
void Net::forward(LayerId toLayer)
{
impl->forwardLayer(impl->getLayerId(toLayer));
}
Importer::~Importer()
{
}
//////////////////////////////////////////////////////////////////////////
#include <sstream>
template<typename T>
String toString(const T &v)
@@ -116,5 +463,48 @@ Layer::~Layer()
}
//////////////////////////////////////////////////////////////////////////
struct LayerRegister::Impl : public std::map<String, LayerRegister::Constuctor>
{
};
//allocates on load and cleans on exit
Ptr<LayerRegister::Impl> LayerRegister::impl(new LayerRegister::Impl());
void LayerRegister::registerLayer(const String &_type, Constuctor constructor)
{
String type = _type.toLowerCase();
Impl::iterator it = impl->find(type);
if (it != impl->end() && it->second != constructor)
{
CV_Error(cv::Error::StsBadArg, "Layer \"" + type + "\" already was registered");
}
impl->insert(std::make_pair(type, constructor));
}
void LayerRegister::unregisterLayer(const String &_type)
{
String type = _type.toLowerCase();
impl->erase(type);
}
Ptr<Layer> LayerRegister::createLayerInstance(const String &_type, LayerParams& params)
{
String type = _type.toLowerCase();
Impl::const_iterator it = LayerRegister::impl->find(type);
if (it != impl->end())
{
return it->second(params);
}
else
{
return Ptr<Layer>(); //NULL
}
}
}
}
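
Net::Impl above resolves each layer's input names against the output names of the other layers, including the Caffe-style in-place case where a layer's bottom and top share the same name. A hand-built sketch using the same public calls the Caffe importer makes; the layer names and the empty LayerParams are illustrative (the Convolution and ReLU constructors are registered in layers.cpp):

#include <opencv2/dnn.hpp>

cv::dnn::Net buildToyNet()
{
    using namespace cv::dnn;
    Net net;

    LayerParams convParams; // left empty here; a real net would set kernel size, stride, etc.
    int conv = net.addLayer("conv1", "Convolution", convParams);
    net.setOutputNames(conv, std::vector<cv::String>(1, "conv1"));
    // "data" matches one of the net-input aliases, so connectInputs() binds it
    // to the implicit "_input" layer.
    net.setLayerInputs(std::vector<cv::String>(1, "data"), conv);

    LayerParams reluParams;
    int relu = net.addLayer("relu1", "ReLU", reluParams);
    // Bottom and top share the name "conv1": connectInputs() detects the in-place
    // layer through its input/output name checks and wires it to conv1's real output blob.
    net.setOutputNames(relu, std::vector<cv::String>(1, "conv1"));
    net.setLayerInputs(std::vector<cv::String>(1, "conv1"), relu);

    net.forward(); // allocate output blobs, connect the DAG, create layer instances, run them
    return net;
}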

@@ -0,0 +1,131 @@
#include "precomp.hpp"
#include "layers.hpp"
#include <math.h>
namespace cv
{
namespace dnn
{
struct ReLUFunctor
{
float negative_slope;
ReLUFunctor(LayerParams &params)
{
if (params.has("negative_slope"))
negative_slope = params.get<float>("negative_slope");
else
negative_slope = 0.f;
}
inline float operator()(float x)
{
return (x >= 0) ? x : negative_slope * x;
}
};
struct TanHFunctor
{
TanHFunctor(LayerParams &params) {}
inline float operator()(float x)
{
return tanh(x);
}
};
REGISTER_LAYER_CLASS(ReLU, ElementWiseLayer<ReLUFunctor>)
REGISTER_LAYER_CLASS(TanH, ElementWiseLayer<TanHFunctor>)
REGISTER_LAYER_CLASS(Convolution, ConvolutionLayer)
REGISTER_LAYER_CLASS(Pooling, PoolingLayer)
REGISTER_LAYER_CLASS(InnerProduct, FullyConnectedLayer)
//////////////////////////////////////////////////////////////////////////
PoolingLayer::PoolingLayer(LayerParams &params)
{
}
void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
}
void PoolingLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
}
//////////////////////////////////////////////////////////////////////////
ConvolutionLayer::ConvolutionLayer(LayerParams &params)
{
}
void ConvolutionLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
}
template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
Dtype* data_col)
{
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int channels_col = channels * kernel_h * kernel_w;
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % kernel_w;
int h_offset = (c / kernel_w) % kernel_h;
int c_im = c / kernel_h / kernel_w;
for (int h = 0; h < height_col; ++h) {
for (int w = 0; w < width_col; ++w) {
int h_pad = h * stride_h - pad_h + h_offset;
int w_pad = w * stride_w - pad_w + w_offset;
if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
data_col[(c * height_col + h) * width_col + w] =
data_im[(c_im * height + h_pad) * width + w_pad];
else
data_col[(c * height_col + h) * width_col + w] = 0;
}
}
}
}
void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
{
}
}
//////////////////////////////////////////////////////////////////////////
FullyConnectedLayer::FullyConnectedLayer(LayerParams &params)
{
}
void FullyConnectedLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
}
void FullyConnectedLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
}
}
}
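
im2col_cpu above unrolls every kernel-sized patch of the input into one column of a buffer, so convolution can then be computed as a single matrix product. A small sketch of calling it, assuming im2col_cpu<float> is visible at the call site; the concrete sizes are illustrative:

#include <vector>

// Unroll a 1x5x5 input for a 3x3 kernel, stride 1, no padding.
void im2colExample(const float *image /* channels x height x width, row-major */)
{
    const int channels = 1, height = 5, width = 5;
    const int kernel_h = 3, kernel_w = 3;
    const int pad_h = 0, pad_w = 0, stride_h = 1, stride_w = 1;

    // Same output-size formulas as inside im2col_cpu.
    int height_col   = (height + 2 * pad_h - kernel_h) / stride_h + 1; // 3
    int width_col    = (width  + 2 * pad_w - kernel_w) / stride_w + 1; // 3
    int channels_col = channels * kernel_h * kernel_w;                 // 9

    std::vector<float> colBuffer(channels_col * height_col * width_col);
    im2col_cpu(image, channels, height, width, kernel_h, kernel_w,
               pad_h, pad_w, stride_h, stride_w, &colBuffer[0]);

    // colBuffer is now a 9 x 9 matrix: one row per kernel element, one column per
    // output pixel, so the convolution reduces to weights (numFilters x 9) * colBuffer.
}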

@@ -0,0 +1,76 @@
#ifndef __OPENCV_DNN_LAYERS_HPP__
#define __OPENCV_DNN_LAYERS_HPP__
#include <opencv2/dnn.hpp>
namespace cv
{
namespace dnn
{
template<typename Func>
class ElementWiseLayer : public Layer
{
Func func;
public:
ElementWiseLayer(LayerParams &_params) : func(_params) {}
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
outputs[0] = *inputs[0];
}
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1 && outputs.size() == 1);
CV_Assert(inputs[0]->getMatRef().data == outputs[0].getMatRef().data);
float *data = outputs[0].getMatRef().ptr<float>();
//Vec4i shape = outputs[0].shape();
//CV_Assert(pitch[i] == shape[i] * sizeof(float) );
for (size_t i = 0; i < outputs[0].total(); i++)
data[i] = func(data[i]);
}
};
class PoolingLayer : public Layer
{
int type;
int strideH, strideW;
int sizeH, sizeW;
public:
PoolingLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
class ConvolutionLayer : public Layer
{
int groups;
int strideH, strideW;
int sizeH, sizeW;
public:
ConvolutionLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
class FullyConnectedLayer : public Layer
{
int numOutputs;
public:
FullyConnectedLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
}
}
#endif

@@ -21,16 +21,18 @@ static std::string getTestFile(const char *filename)
TEST(ReadCaffePrototxt_gtsrb, Accuracy)
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("gtsrb.prototxt"), getTestFile("gtsrb_iter_36000.caffemodel") );
Ptr<Importer> importer = createCaffeImporter(getTestFile("gtsrb.prototxt"), getTestFile("gtsrb_iter_36000.caffemodel"));
Net net;
importer->populateNet(net);
net.forward();
}
TEST(ReadCaffePrototxt_GoogleNet, Accuracy)
{
Ptr<Importer> importer = createCaffeImporter(getOpenCVExtraDir() + "/dnn/googlenet_deploy.prototxt", "");
Net net;
importer->populateNet(net);
}
//TEST(ReadCaffePrototxt_GoogleNet, Accuracy)
//{
// Ptr<Importer> importer = createCaffeImporter(getOpenCVExtraDir() + "/dnn/googlenet_deploy.prototxt", "");
// Net net;
// importer->populateNet(net);
// net.forward();
//}
}

@@ -1,13 +1,12 @@
name: "gtsrb"
input: "data"
input: "input"
input_dim: 1
input_dim: 3
input_dim: 48
input_dim: 48
layers {
bottom: "data"
bottom: "input"
top: "layer1"
name: "layer1"
type: CONVOLUTION
