Fixed average pooling error

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 5b0805322e
commit d0875b1c4c
  1. 67
      modules/dnn/src/layers/pooling_layer.cpp

@ -3,6 +3,7 @@
#include <float.h>
#include <algorithm>
using std::max;
using std::min;
namespace cv
{
@ -23,10 +24,10 @@ namespace dnn
int strideH, strideW;
int kernelH, kernelW;
int inH, inW;
int pooledH, pooledW;
int inpH, inpW;
int outH, outW;
void computeOutputShape(int inH, int inW);
void computeOutputShape(int inpH, int inpW);
void maxPooling(Blob &input, Blob &output);
void avePooling(Blob &input, Blob &output);
@ -66,15 +67,15 @@ namespace dnn
{
CV_Assert(inputs.size() > 0);
inW = inputs[0]->cols();
inH = inputs[0]->rows();
computeOutputShape(inH, inW);
inpW = inputs[0]->cols();
inpH = inputs[0]->rows();
computeOutputShape(inpH, inpW);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->rows() == inH && inputs[i]->cols() == inW);
outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), pooledH, pooledW));
CV_Assert(inputs[i]->rows() == inpH && inputs[i]->cols() == inpW);
outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), outH, outW));
}
}
@ -99,7 +100,7 @@ namespace dnn
void PoolingLayer::maxPooling(Blob &input, Blob &output)
{
CV_DbgAssert(output.rows() == pooledH && output.cols() == pooledW);
CV_DbgAssert(output.rows() == outH && output.cols() == outW);
for (int n = 0; n < input.num(); ++n)
{
@ -108,23 +109,23 @@ namespace dnn
float *srcData = input.ptrf(n, c);
float *dstData = output.ptrf(n, c);
for (int ph = 0; ph < pooledH; ++ph)
for (int ph = 0; ph < outH; ++ph)
{
for (int pw = 0; pw < pooledW; ++pw)
for (int pw = 0; pw < outW; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideW - padW;
int hend = min(hstart + kernelH, inH);
int wend = min(wstart + kernelW, inW);
int hend = min(hstart + kernelH, inpH);
int wend = min(wstart + kernelW, inpW);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
const int pool_index = ph * pooledW + pw;
const int pool_index = ph * outW + pw;
float max_val = -FLT_MAX;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
{
const int index = h * inW + w;
const int index = h * inpW + w;
if (srcData[index] > max_val)
max_val = srcData[index];
}
@ -145,27 +146,27 @@ namespace dnn
float *srcData = input.ptrf(n, c);
float *dstData = output.ptrf(n, c);
for (int ph = 0; ph < pooledH; ++ph)
for (int ph = 0; ph < outH; ++ph)
{
for (int pw = 0; pw < pooledW; ++pw)
for (int pw = 0; pw < outW; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideH - padH;
int hend = min(hstart + kernelH, inW + padH);
int wend = min(wstart + kernelW, inH + padW);
int wstart = pw * strideW - padW;
int hend = min(hstart + kernelH, inpH + padH);
int wend = min(wstart + kernelW, inpW + padW);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, inH);
wend = min(wend, inW);
hend = min(hend, inpH);
wend = min(wend, inpW);
dstData[ph * pooledH + pw] = 0.f;
dstData[ph * outW + pw] = 0.f;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
dstData[ph * pooledH + pw] += srcData[h * inW + w];
dstData[ph * outW + pw] += srcData[h * inpW + w];
dstData[ph * pooledH + pw] /= pool_size;
dstData[ph * outW + pw] /= pool_size;
}
}
}
@ -175,19 +176,19 @@ namespace dnn
void PoolingLayer::computeOutputShape(int inH, int inW)
{
//Yeah, something strange Caffe scheme-)
pooledH = static_cast<int>(ceil(static_cast<float>(inH + 2 * padH - kernelH) / strideH)) + 1;
pooledW = static_cast<int>(ceil(static_cast<float>(inW + 2 * padW - kernelW) / strideW)) + 1;
outH = static_cast<int>(ceil(static_cast<float>(inH + 2 * padH - kernelH) / strideH)) + 1;
outW = static_cast<int>(ceil(static_cast<float>(inW + 2 * padW - kernelW) / strideW)) + 1;
if (padH || padW)
{
// If we have padding, ensure that the last pooling starts strictly
// inside the image (instead of at the padding); otherwise clip the last.
if ((pooledH - 1) * strideH >= inH + padH)
--pooledH;
if ((pooledW - 1) * strideW >= inW + padW)
--pooledW;
CV_Assert((pooledH - 1) * strideH < inH + padH);
CV_Assert((pooledW - 1) * strideW < inW + padW);
if ((outH - 1) * strideH >= inH + padH)
--outH;
if ((outW - 1) * strideW >= inW + padW)
--outW;
CV_Assert((outH - 1) * strideH < inH + padH);
CV_Assert((outW - 1) * strideW < inW + padW);
}
}
}

Loading…
Cancel
Save