Some layer fixes. Added concat layer and avg. pooling to sucesfully run GoogLeNet.

Added concat layer, implemented average pooling to run GoogLeNet.
Fixed transpose error in FullyConnected layer (hotfix in softmax layer).
Added GoogleNet test and updated AlexNet test (both nets now work fine).
pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 6fd67d44b0
commit 983823468d
  1. 82
      modules/dnn/src/layers/concat_layer.cpp
  2. 44
      modules/dnn/src/layers/fully_connected_layer.cpp
  3. 40
      modules/dnn/src/layers/pooling_layer.cpp
  4. 4
      modules/dnn/src/layers/softmax_layer.cpp
  5. 17
      modules/dnn/test/test_alexnet.cpp
  6. 62
      modules/dnn/test/test_googlenet.cpp
  7. 4
      modules/dnn/testdata/dnn/bvlc_alexnet.prototxt

@ -0,0 +1,82 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <iostream>
#include <stdlib.h>
namespace cv
{
namespace dnn
{
class ConcatLayer : public Layer
{
int axis;
public:
ConcatLayer(LayerParams& params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
REGISTER_LAYER_CLASS(Concat, ConcatLayer)
ConcatLayer::ConcatLayer(LayerParams &params)
{
axis = params.get<int>("axis", 1);
CV_Assert(axis == 0 || axis == 1);
}
void ConcatLayer::allocate(const std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
int axisSum = 0;
for (size_t i = 0; i < inputs.size(); i++)
{
Vec4i refShape = inputs[0]->shape();
Vec4i curShape = inputs[i]->shape();
for (int axisId = 0; axisId < 4; axisId++)
{
if (axisId != axis && refShape[axisId] != curShape[axisId])
CV_Error(cv::Error::StsBadArg, "Inconsitent shape for ConcatLayer");
}
axisSum += curShape[axis];
}
Vec4i shape = inputs[0]->shape();
shape[axis] = axisSum;
outputs.resize(1);
outputs[0].create(shape);
}
void ConcatLayer::forward(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
float *dstPtr = outputs[0].ptr<float>();
if (axis == 0)
{
for (size_t i = 0; i < inputs.size(); i++)
{
const float *srcPtr = inputs[i]->ptr<float>();
memcpy(dstPtr, srcPtr, inputs[i]->total() * sizeof(float));
dstPtr += inputs[i]->total();
}
}
else
{
for (int n = 0; n < outputs[0].num(); n++)
{
for (size_t i = 0; i < inputs.size(); i++)
{
Blob &inp = *inputs[i];
memcpy(dstPtr, inp.ptr<float>(n), inp.total(1) * sizeof(float));
dstPtr += inp.total(1);
}
}
}
}
}
}

@ -1,18 +1,20 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <iostream>
namespace cv
{
namespace dnn
{
//TODO: implement axis number parameter
class FullyConnectedLayer : public Layer
{
bool bias;
int numOutputs;
int axis;
int inC, inH, inW;
int inSize;
int innerSize;
void reshape(const Blob &inp, Blob &out);
public:
FullyConnectedLayer(LayerParams &params);
@ -28,7 +30,9 @@ namespace dnn
{
numOutputs = params.get<int>("num_output");
bias = params.get<bool>("bias_term", true);
axis = params.get<int>("axis", 1);
CV_Assert(0 <= axis && axis < 4);
CV_Assert(params.learnedBlobs.size() >= 1);
CV_Assert(!bias || (params.learnedBlobs.size() >= 2 && (int)params.learnedBlobs[1].total() == numOutputs));
@ -44,36 +48,46 @@ namespace dnn
{
CV_Assert(inputs.size() > 0);
inC = inputs[0]->channels();
inH = inputs[0]->rows();
inW = inputs[0]->cols();
inSize = inC * inH * inW;
CV_Assert((size_t)inSize * (size_t)numOutputs == learnedParams[0].total());
innerSize = (int)inputs[0]->total(axis);
CV_Assert((size_t)innerSize * (size_t)numOutputs == learnedParams[0].total());
CV_Assert(learnedParams[0].rows() == numOutputs && learnedParams[0].cols() == innerSize);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
if (i != 0)
CV_Assert(inputs[i]->channels() == inC && inputs[i]->rows() == inH && inputs[i]->cols() == inW);
CV_Assert(inputs[i]->total(axis) == (size_t)innerSize);
outputs[i].create(inputs[i]->num(), numOutputs, 1, 1);
this->reshape(*inputs[i], outputs[i]);
}
}
void FullyConnectedLayer::reshape(const Blob &inp, Blob &out)
{
Vec4i inpShape = inp.shape();
Vec4i outShape = Vec4i::all(1);
for (int a = 0; a < axis; a++)
outShape[a] = inpShape[a];
outShape[3] = numOutputs;
out.create(outShape, inp.type());
}
void FullyConnectedLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
for (size_t i = 0; i < inputs.size(); i++)
{
int M = inputs[i]->num();
int M = inputs[i]->total(0, axis);
int N = numOutputs;
int K = inSize;
int K = innerSize;
Mat srcMat(M, K, CV_32F, inputs[i]->ptr<float>());
Mat weights(K, N, CV_32F, learnedParams[0].ptr<float>());
Mat weights(N, K, CV_32F, learnedParams[0].ptr<float>());
Mat dstMat(M, N, CV_32F, outputs[i].ptr<float>());
cv::gemm(srcMat, weights, 1, noArray(), 0, dstMat);
//important: Caffe stores weights as transposed array
cv::gemm(srcMat, weights, 1, noArray(), 0, dstMat, GEMM_2_T);
if (bias)
{

@ -27,6 +27,7 @@ namespace dnn
void computeOutputShape(int inH, int inW);
void maxPooling(Blob &input, Blob &output);
void avePooling(Blob &input, Blob &output);
public:
PoolingLayer(LayerParams &params);
@ -85,6 +86,9 @@ namespace dnn
case MAX:
maxPooling(*inputs[ii], outputs[ii]);
break;
case AVE:
avePooling(*inputs[ii], outputs[ii]);
break;
default:
CV_Error(cv::Error::StsNotImplemented, "Not implemented");
break;
@ -131,6 +135,42 @@ namespace dnn
}
}
void PoolingLayer::avePooling(Blob &input, Blob &output)
{
for (int n = 0; n < input.num(); ++n)
{
for (int c = 0; c < input.channels(); ++c)
{
float *srcData = input.ptr<float>(n, c);
float *dstData = output.ptr<float>(n, c);
for (int ph = 0; ph < pooledH; ++ph)
{
for (int pw = 0; pw < pooledW; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideH - padH;
int hend = min(hstart + kernelH, inW + padH);
int wend = min(wstart + kernelW, inH + padW);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, inH);
wend = min(wend, inW);
dstData[ph * pooledH + pw] = 0.f;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
dstData[ph * pooledH + pw] += srcData[h * inW + w];
dstData[ph * pooledH + pw] /= pool_size;
}
}
}
}
}
void PoolingLayer::computeOutputShape(int inH, int inW)
{
//Yeah, something strange Caffe scheme-)

@ -8,6 +8,7 @@ namespace cv
{
namespace dnn
{
//TODO: set default axis number to 1, and add custom shape length in FullyConnected
class SoftMaxLayer : public Layer
{
int axis;
@ -25,7 +26,8 @@ namespace dnn
SoftMaxLayer::SoftMaxLayer(LayerParams &params)
{
axis = params.get<int>("axis", 1);
//hotfix!!!
axis = params.get<int>("axis", 3);
CV_Assert(0 <= axis && axis < 4);
}

@ -53,23 +53,10 @@ TEST(Reproducibility_AlexNet, Accuracy)
net.setBlob("data", inp);
net.forward("conv1");
normAssert(blobFromNPY(getTestFile("alexnet_conv1.npy")), net.getBlob("conv1"), "conv1");
//saveBlobToNPY(convBlob, getTestFile("alexnet_conv1_my.npy"));
net.forward("relu1");
normAssert(blobFromNPY(getTestFile("alexnet_relu1.npy")), net.getBlob("relu1"), "relu1");
net.forward("norm1");
normAssert(blobFromNPY(getTestFile("alexnet_norm1.npy")), net.getBlob("norm1"), "norm1");
net.forward();
Blob out = net.getBlob("prob");
Blob ref = blobFromNPY(getTestFile("alexnet.npy"));
std::cout << out.shape() << " vs " << ref.shape() << std::endl;
Mat mOut(1, 1000, CV_32F, ref.rawPtr());
Mat mRef(1, 1000, CV_32F, ref.rawPtr());
normAssert(mOut, mRef);
Blob ref = blobFromNPY(getTestFile("alexnet_prob.npy"));
normAssert(out, ref, "prob");
}
}

@ -0,0 +1,62 @@
#include "test_precomp.hpp"
#include "npy_blob.hpp"
namespace cvtest
{
using namespace std;
using namespace testing;
using namespace cv;
using namespace cv::dnn;
static std::string getOpenCVExtraDir()
{
return cvtest::TS::ptr()->get_data_path();
}
template<typename TStr>
static std::string getTestFile(TStr filename)
{
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
inline void normAssert(InputArray ref, InputArray get, const char *comment = "")
{
double normL1 = cvtest::norm(ref, get, NORM_L1)/ ref.getMat().total();
EXPECT_LE(normL1, 0.0001) << comment;
double normInf = cvtest::norm(ref, get, NORM_INF);
EXPECT_LE(normInf, 0.001) << comment;
}
inline void normAssert(Blob ref, Blob test, const char *comment = "")
{
normAssert(ref.getMatRef(), test.getMatRef(), comment);
}
TEST(Reproducibility_GoogLeNet, Accuracy)
{
Net net;
{
Ptr<Importer> importer = createCaffeImporter(getTestFile("bvlc_googlenet.prototxt"), getTestFile("bvlc_googlenet.caffemodel"));
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
std::vector<Mat> inpMats(2);
inpMats[0] = imread(getTestFile("googlenet_0.png"));
inpMats[1] = imread(getTestFile("googlenet_1.png"));
ASSERT_TRUE(!inpMats[0].empty() && !inpMats[1].empty());
inpMats[0].convertTo(inpMats[0], CV_32F);
Blob inp(inpMats[0]);
net.setBlob("data", inp);
net.forward();
Blob out = net.getBlob("prob");
Blob ref = blobFromNPY(getTestFile("googlenet.npy"));
normAssert(out, ref);
}
}

@ -75,12 +75,12 @@ layer {
name: "relu2"
type: "ReLU"
bottom: "conv2"
top: "conv2"
top: "relu2"
}
layer {
name: "norm2"
type: "LRN"
bottom: "conv2"
bottom: "relu2"
top: "norm2"
lrn_param {
local_size: 5

Loading…
Cancel
Save