Fixed Deconvolution layer. Added broader layer test coverage.

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent d0875b1c4c
commit f8119ea058
  1. modules/dnn/include/opencv2/dnn/dnn.hpp (3 changed lines)
  2. modules/dnn/src/layers/convolution_layer.cpp (22 changed lines)
  3. modules/dnn/src/layers/pooling_layer.cpp (8 changed lines)
  4. modules/dnn/test/npy_blob.hpp (4 changed lines)
  5. modules/dnn/test/test_layers.cpp (74 changed lines)
  6. modules/dnn/testdata/dnn/layers/convolution.prototxt (39 changed lines)
  7. modules/dnn/testdata/dnn/layers/deconvolution.prototxt (39 changed lines)
  8. modules/dnn/testdata/dnn/layers/inner_product.prototxt (32 changed lines)
  9. modules/dnn/testdata/dnn/layers/pooling_ave.prototxt (26 changed lines)
  10. modules/dnn/testdata/dnn/layers/pooling_max.prototxt (26 changed lines)
  11. modules/dnn/testdata/dnn/layers/run.py (48 changed lines)

@@ -1,10 +1,7 @@
 #ifndef __OPENCV_DNN_DNN_HPP__
 #define __OPENCV_DNN_DNN_HPP__

-#include <map>
-#include <vector>
-#include <iostream>
 #include <opencv2/core.hpp>
 #include <opencv2/dnn/dict.hpp>
 #include <opencv2/dnn/blob.hpp>

@@ -29,6 +29,7 @@ namespace dnn
         void im2col(Blob &inpBlob, int imNum, int cnGroup);
     public:
+        ConvolutionLayer() {}
         ConvolutionLayer(LayerParams &params);
         void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
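A likely motivation for the new default constructor (an assumption, since the class declarations are not part of this hunk): DeConvolutionLayer, which this commit fixes, shares this class's machinery and can now construct the base without a LayerParams object, along these lines:

    // Hypothetical sketch of the derived class relying on the default base ctor:
    class DeConvolutionLayer : public ConvolutionLayer
    {
    public:
        DeConvolutionLayer(LayerParams &params); // initializes the inherited fields itself
    };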
@@ -63,7 +64,7 @@ namespace dnn
         learnedParams.assign(params.learnedBlobs.begin(), params.learnedBlobs.begin() + (bias ? 2 : 1));

         const Blob &wgtBlob = learnedParams[0];
-        CV_Assert(wgtBlob.dims() == 4 && wgtBlob.cols() == kerW && wgtBlob.rows() == kerH && wgtBlob.num() == numOutput);
+        CV_Assert(wgtBlob.dims() == 4 && wgtBlob.cols() == kerW && wgtBlob.rows() == kerH);

         if (bias)
         {
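Dropping the num() check here is presumably what lets the deconvolution layer share this constructor. An assumption, based on Caffe's weight conventions (which this importer follows):

    // Caffe-style weight layouts (assumed, not shown in this diff):
    //   Convolution:   wgtBlob.num() == numOutput, wgtBlob.channels() == inpCn / group
    //   Deconvolution: wgtBlob.num() == inpCn,     wgtBlob.channels() == numOutput / group
    // Only the kernel extents are common to both, so the constructor keeps just the
    // spatial part of the assert and allocate() re-checks num()/channels() per layer.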
@@ -81,8 +82,7 @@ namespace dnn
         computeInpOutShape(inpBlob);

         CV_Assert(inpCn % group == 0 && outCn % group == 0);
-        CV_Assert(learnedParams[0].channels() == inpCn / group);
-        CV_Assert(learnedParams[0].num() == outCn);
+        CV_Assert(learnedParams[0].num() == outCn && learnedParams[0].channels() == inpCn / group);

         outGroupCn = outCn / group;
         inpGroupCn = inpCn / group;
@@ -165,7 +165,7 @@ namespace dnn
         outH = (inpH + 2 * padH - kerH) / strideH + 1;
         outW = (inpW + 2 * padW - kerW) / strideW + 1;
-        outCn = learnedParams[0].num();
+        outCn = numOutput;
         topH = outH; topW = outW; topCn = outCn;
     }
@@ -178,7 +178,7 @@ namespace dnn
         inpH = strideH * (outH - 1) + kerH - 2 * padH;
         inpW = strideW * (outW - 1) + kerW - 2 * padW;
-        inpCn = learnedParams[0].channels();
+        inpCn = numOutput;
         topH = inpH; topW = inpW; topCn = inpCn;
     }
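For concreteness, the shape arithmetic these two computeInpOutShape variants implement, worked through with the dimensions from the test prototxt files added below (2x6x75x113 input, 4x5 kernel, 2x3 stride, 0x1 padding, num_output 12). Note the deconvolution formula is the convolution one solved for the larger side:

    int kerH = 4, kerW = 5, strideH = 2, strideW = 3, padH = 0, padW = 1;

    // Convolution: 75x113 shrinks to 36x37 (12 output channels).
    int convH = (75 + 2*padH - kerH) / strideH + 1;    // (75 + 0 - 4)/2 + 1 = 36
    int convW = (113 + 2*padW - kerW) / strideW + 1;   // (113 + 2 - 5)/3 + 1 = 37

    // Deconvolution: 75x113 grows to 152x339 (12 output channels).
    int deconvH = strideH * (75 - 1) + kerH - 2*padH;  // 2*74 + 4 - 0 = 152
    int deconvW = strideW * (113 - 1) + kerW - 2*padW; // 3*112 + 5 - 2 = 339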
@@ -201,16 +201,16 @@ namespace dnn
         if (is1x1())
             colMat = dstMat;

-        Mat convMat(outGroupCn, outH*outW, convBlob.type(), convBlob.ptrRaw(n, g*inpGroupCn));
-        Mat wghtMat(outGroupCn, ksize, wghtBlob.type(), wghtBlob.ptrRaw(g*inpGroupCn));
+        Mat convMat(outGroupCn, outH*outW, convBlob.type(), convBlob.ptrRaw(n, g*outGroupCn));
+        Mat wghtMat(outGroupCn, ksize, wghtBlob.type(), wghtBlob.ptrRaw(g*outGroupCn));
         cv::gemm(wghtMat, convMat, 1, noArray(), 0, colMat, GEMM_1_T);
         col2im(dstMat);

         if (bias)
         {
-            float *biasPtr = learnedParams[1].ptrf() + g*outGroupCn;
-            Mat biasMat(outGroupCn, 1, CV_32F, biasPtr);
+            float *biasPtr = learnedParams[1].ptrf() + g*inpGroupCn;
+            Mat biasMat(inpGroupCn, 1, CV_32F, biasPtr);
             cv::gemm(biasMat, biasOnesMat, 1, dstMat, 1, dstMat);
         }
     }
@@ -223,9 +223,9 @@ namespace dnn
         if (is1x1()) return;

         if (dstMat.type() == CV_32F)
-            col2im_cpu((float*)colMat.ptr(), inpCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (float*)dstMat.ptr());
+            col2im_cpu((float*)colMat.ptr(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (float*)dstMat.ptr());
         if (dstMat.type() == CV_64F)
-            col2im_cpu((double*)colMat.ptr(), inpCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (double*)dstMat.ptr());
+            col2im_cpu((double*)colMat.ptr(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (double*)dstMat.ptr());
     }
 }
 }
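col2im_cpu itself is outside this diff; for context, a minimal Caffe-style implementation matching the call signature above might look as follows (a sketch under that assumption, not the repository's verbatim code). The switch from inpCn to inpGroupCn matters because colMat holds only one group's channels, so the scatter-add must be told the per-group channel count:

    #include <algorithm>

    // Sketch: scatter-add each column-buffer entry back into the
    // (channels x height x width) image it was unfolded from.
    template <typename Dtype>
    void col2im_cpu(const Dtype* colData, int channels, int height, int width,
                    int kernelH, int kernelW, int padH, int padW,
                    int strideH, int strideW, Dtype* imData)
    {
        std::fill(imData, imData + channels * height * width, Dtype(0));

        int outH = (height + 2 * padH - kernelH) / strideH + 1;
        int outW = (width + 2 * padW - kernelW) / strideW + 1;
        int channelsCol = channels * kernelH * kernelW;

        for (int c = 0; c < channelsCol; ++c)
        {
            int wOffset = c % kernelW;                  // position inside the kernel
            int hOffset = (c / kernelW) % kernelH;
            int cIm = c / kernelH / kernelW;            // source image channel

            for (int h = 0; h < outH; ++h)
                for (int w = 0; w < outW; ++w)
                {
                    int hPad = h * strideH - padH + hOffset;
                    int wPad = w * strideW - padW + wOffset;
                    if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width)
                        imData[(cIm * height + hPad) * width + wPad] +=
                            colData[(c * outH + h) * outW + w];
                }
        }
    }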

@@ -119,7 +119,7 @@ namespace dnn
                 int wend = min(wstart + kernelW, inpW);
                 hstart = max(hstart, 0);
                 wstart = max(wstart, 0);
-                const int pool_index = ph * outW + pw;
+                const int poolIndex = ph * outW + pw;
                 float max_val = -FLT_MAX;

                 for (int h = hstart; h < hend; ++h)
@@ -130,7 +130,7 @@ namespace dnn
                         max_val = srcData[index];
                     }

-                dstData[pool_index] = max_val;
+                dstData[poolIndex] = max_val;
             }
         }
     }
@@ -154,7 +154,7 @@ namespace dnn
                 int wstart = pw * strideW - padW;
                 int hend = min(hstart + kernelH, inpH + padH);
                 int wend = min(wstart + kernelW, inpW + padW);
-                int pool_size = (hend - hstart) * (wend - wstart);
+                int poolSize = (hend - hstart) * (wend - wstart);
                 hstart = max(hstart, 0);
                 wstart = max(wstart, 0);
                 hend = min(hend, inpH);
@@ -166,7 +166,7 @@ namespace dnn
                     for (int w = wstart; w < wend; ++w)
                         dstData[ph * outW + pw] += srcData[h * inpW + w];

-                dstData[ph * outW + pw] /= pool_size;
+                dstData[ph * outW + pw] /= poolSize;
             }
         }
     }
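One subtlety the rename leaves as-is: poolSize is computed before hstart/wstart/hend/wend are clamped to the image, so padded cells still count toward the divisor, matching Caffe's average pooling. With the kernel_h: 3, pad_h: 2, stride_h: 2 settings from pooling_ave.prototxt below, the first output row of a 75-row input works out like this:

    // ph == 0:
    int hstart = 0 * 2 - 2;                // -2: the window starts inside the padding
    int hend   = min(hstart + 3, 75 + 2);  //  1: hend is limited by inpH + padH
    int rows   = hend - hstart;            //  3 rows enter poolSize,
    hstart = max(hstart, 0);               //  but after clamping only row 0
    hend   = min(hend, 75);                //  actually contributes to the sum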

@@ -17,8 +17,8 @@ inline cv::dnn::Blob blobFromNPY(const cv::String &path)

 inline void saveBlobToNPY(cv::dnn::Blob &blob, const cv::String &path)
 {
-    cv::Vec4i shape = blob.shape4();
-    cnpy::npy_save(path.c_str(), blob.ptr<float>(), (unsigned*)&shape[0], 4);
+    cv::dnn::BlobShape shape = blob.shape();
+    cnpy::npy_save(path.c_str(), blob.ptr<float>(), (unsigned*)&shape[0], shape.dims());
 }

 #endif

@@ -16,30 +16,22 @@ static std::string getOpenCVExtraDir()
 }

 template<typename TStr>
-static std::string getTestFile(TStr filename)
+static String _tf(TStr filename)
 {
     return (getOpenCVExtraDir() + "/dnn/layers/") + filename;
 }

-template<typename T, int n>
-bool isEqual(const cv::Vec<T, n> &l, const cv::Vec<T, n> &r)
-{
-    for (int i = 0; i < n; i++)
-    {
-        if (l[i] != r[i])
-            return false;
-    }
-    return true;
-}
-
-static void testLayer(String proto, String caffemodel = String())
+static void testLayer(String basename, bool useCaffeModel = false)
 {
-    Blob inp = blobFromNPY(getTestFile("blob.npy"));
-    Blob ref = blobFromNPY(getTestFile(proto + ".caffe.npy"));
+    Blob inp = blobFromNPY(_tf("blob.npy"));
+    Blob ref = blobFromNPY(_tf(basename + ".npy"));
+
+    String prototxt = basename + ".prototxt";
+    String caffemodel = basename + ".caffemodel";

     Net net;
     {
-        Ptr<Importer> importer = createCaffeImporter(getTestFile(proto), caffemodel);
+        Ptr<Importer> importer = createCaffeImporter(_tf(prototxt), (useCaffeModel) ? _tf(caffemodel) : String());
         ASSERT_TRUE(importer != NULL);
         importer->populateNet(net);
     }
@@ -60,22 +52,47 @@ static void testLayer(String proto, String caffemodel = String())
     EXPECT_LE(normInf, 0.0001);
 }
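The hunk elides the middle of testLayer; judging from the surviving EXPECT_LE(normInf, 0.0001) and the calls visible elsewhere in this file, the body plausibly runs the net and compares against the Caffe reference along these lines (a sketch, not the file's verbatim code):

    net.setBlob(".input", inp);
    net.forward();
    Blob out = net.getBlob("output");
    // Assumed comparison; cvtest::norm is the ts-module helper:
    double normInf = cvtest::norm(ref.getMatRef(), out.getMatRef(), cv::NORM_INF);
    EXPECT_LE(normInf, 0.0001);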
-TEST(Layer_Softmax_Test, Accuracy)
+TEST(Layer_Test_Softmax, Accuracy)
 {
-    testLayer("softmax.prototxt");
+    testLayer("softmax");
 }

-TEST(Layer_LRN_spatial_Test, Accuracy)
+TEST(Layer_Test_LRN_spatial, Accuracy)
 {
-    testLayer("lrn_spatial.prototxt");
+    testLayer("lrn_spatial");
 }

-TEST(Layer_LRN_channels_Test, Accuracy)
+TEST(Layer_Test_LRN_channels, Accuracy)
 {
-    testLayer("lrn_channels.prototxt");
+    testLayer("lrn_channels");
 }

+TEST(Layer_Test_Convolution, Accuracy)
+{
+    testLayer("convolution", true);
+}
+
+TEST(Layer_Test_InnerProduct, Accuracy)
+{
+    testLayer("inner_product", true);
+}
+
+TEST(Layer_Test_Pooling_max, Accuracy)
+{
+    testLayer("pooling_max");
+}
+
+TEST(Layer_Test_Pooling_ave, Accuracy)
+{
+    testLayer("pooling_ave");
+}
+
+TEST(Layer_Test_DeConvolution, Accuracy)
+{
+    testLayer("deconvolution", true);
+}
+
-TEST(Layer_Reshape_squeeze, Accuracy)
+TEST(Layer_Test_Reshape, squeeze)
 {
     LayerParams params;
     params.set("axis", 2);
@@ -92,28 +109,23 @@ TEST(Layer_Reshape_squeeze, Accuracy)
     EXPECT_EQ(outVec[0].shape(), BlobShape(Vec3i(4, 3, 2)));
 }

-TEST(Layer_Reshape_Split_Slice_Test, Accuracy)
+TEST(Layer_Test_Reshape_Split_Slice, Accuracy)
 {
     Net net;
     {
-        Ptr<Importer> importer = createCaffeImporter(getTestFile("reshape_and_slice_routines.prototxt"));
+        Ptr<Importer> importer = createCaffeImporter(_tf("reshape_and_slice_routines.prototxt"));
         ASSERT_TRUE(importer != NULL);
         importer->populateNet(net);
     }

-    BlobShape shape = BlobShape(Vec2i(6, 12));
-    Mat1f inputMat(shape[0], shape[1]);
+    Blob input(BlobShape(Vec2i(6, 12)));
     RNG rng(0);
-    rng.fill(inputMat, RNG::UNIFORM, -1, 1);
+    rng.fill(input.getMatRef(), RNG::UNIFORM, -1, 1);

-    Blob input(inputMat);
-    input.reshape(shape);
     net.setBlob(".input", input);
     net.forward();
     Blob output = net.getBlob("output");

-    input.fill(shape, CV_32F, inputMat.data);
     normAssert(input, output);
 }

@@ -0,0 +1,39 @@
name: "test_Convolution"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113

layer {
  type: "Convolution"

  convolution_param
  {
    group: 3
    num_output: 12

    pad_h: 0
    pad_w: 1
    kernel_h: 4
    kernel_w: 5
    stride_h: 2
    stride_w: 3

    weight_filler {
      type: 'uniform'
      min: -1
      max: 1
    }
    bias_filler {
      type: 'uniform'
      min: -1
      max: 1
    }
  }

  name: "output"
  bottom: "input"
  top: "output"
}
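The group setting splits both channel sets; the per-group arithmetic for this configuration, worked through under standard Caffe grouped-convolution semantics:

    int inpCn = 6, outCn = 12, group = 3;
    int inpGroupCn = inpCn / group;   // each group sees 2 input channels
    int outGroupCn = outCn / group;   // and produces 4 output channels
    // => weight blob shape: 12 x 2 x 4 x 5 (num x channels x kerH x kerW)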

@@ -0,0 +1,39 @@
name: "test_Deconvolution"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113

layer {
  type: "Deconvolution"

  convolution_param
  {
    group: 3
    num_output: 12

    pad_h: 0
    pad_w: 1
    kernel_h: 4
    kernel_w: 5
    stride_h: 2
    stride_w: 3

    weight_filler {
      type: 'uniform'
      min: -1
      max: 1
    }
    bias_filler {
      type: 'uniform'
      min: -1
      max: 1
    }
  }

  name: "output"
  bottom: "input"
  top: "output"
}

@@ -0,0 +1,32 @@
name: "test_InnerProduct"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113

layer {
  type: "InnerProduct"

  inner_product_param
  {
    axis: 3
    num_output: 2

    weight_filler {
      type: 'uniform'
      min: -1
      max: 1
    }
    bias_filler {
      type: 'uniform'
      min: -1
      max: 1
    }
  }

  name: "output"
  bottom: "input"
  top: "output"
}
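With axis: 3, Caffe's InnerProduct treats every leading dimension as a batch dimension (assuming standard Caffe semantics here), so the test exercises the non-default flattening path:

    int rows = 2 * 6 * 75;   // flattened leading axes act as the batch
    int inputLen = 113;      // the axis that gets fully connected
    int numOutput = 2;
    // weight matrix: 2 x 113; expected output blob: 2 x 6 x 75 x 2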

@@ -0,0 +1,26 @@
name: "test_Pooling_ave"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113

layer {
  type: "Pooling"

  pooling_param
  {
    pool: AVE

    pad_h: 2
    pad_w: 1
    kernel_h: 3
    kernel_w: 5
    stride_h: 2
    stride_w: 1
  }

  name: "output"
  bottom: "input"
  top: "output"
}

@@ -0,0 +1,26 @@
name: "test_Pooling_max"
input: "input"
input_dim: 2
input_dim: 6
input_dim: 75
input_dim: 113

layer {
  type: "Pooling"

  pooling_param
  {
    pool: MAX

    pad_h: 2
    pad_w: 1
    kernel_h: 3
    kernel_w: 5
    stride_h: 2
    stride_w: 1
  }

  name: "output"
  bottom: "input"
  top: "output"
}

@@ -0,0 +1,48 @@
# coding: utf-8
import sys, os, glob
CAFFE_ROOT = "/home/vitaliy/opencv/caffe/"
sys.path.insert(0, CAFFE_ROOT + 'python')
CV2_DIR = "/home/vitaliy/opencv/build-opencv-qt/lib"
sys.path.insert(0, CV2_DIR)

import numpy as np
import caffe
import cv2

def get_cafe_output(inp_blob, proto_name, caffemodel_name):
    caffe.set_mode_cpu()
    net = caffe.Net(proto_name, caffe.TEST)

    net.blobs['input'].reshape(*inp_blob.shape)
    net.blobs['input'].data[...] = inp_blob

    net.forward()
    out_blob = net.blobs['output'].data[...]

    if net.params.get('output'):
        print "Params count:", len(net.params['output'])
        net.save(caffemodel_name)

    return out_blob

if __name__ == '__main__':
    proto_filenames = glob.glob("*.prototxt")

    inp_blob = np.load('blob.npy')
    print inp_blob.shape

    for proto_filename in proto_filenames:
        proto_filename = os.path.basename(proto_filename)
        proto_basename = os.path.splitext(proto_filename)[0]
        cfmod_basename = proto_basename + ".caffemodel"
        npy_filename = proto_basename + ".npy"

        print cfmod_basename
        out_blob = get_cafe_output(inp_blob, proto_filename, cfmod_basename)

        print out_blob.shape
        np.save(npy_filename, out_blob)
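Judging from the script itself, it is meant to be run (under Python 2, with pycaffe available) from modules/dnn/testdata/dnn/layers with blob.npy present and the CAFFE_ROOT/CV2_DIR paths adjusted: for every *.prototxt it saves Caffe's output as <name>.npy and, when the layer has parameters, the randomly initialized weights as <name>.caffemodel, which is exactly what testLayer(basename, true) loads in test_layers.cpp.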