parent
3aa37d2971
commit
172419ea1c
9 changed files with 174 additions and 35 deletions
@ -0,0 +1,137 @@ |
|||||||
|
#include "../precomp.hpp" |
||||||
|
#include "layers_common.hpp" |
||||||
|
|
||||||
|
namespace cv |
||||||
|
{ |
||||||
|
namespace dnn |
||||||
|
{ |
||||||
|
|
||||||
|
//TODO: Extend cv::Mat::reshape method
|
||||||
|
class ReshapeLayer : public Layer |
||||||
|
{
|
||||||
|
public: |
||||||
|
ReshapeLayer(LayerParams ¶ms); |
||||||
|
|
||||||
|
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs); |
||||||
|
|
||||||
|
void forward(std::vector<Blob*>&, std::vector<Blob>&) {} |
||||||
|
|
||||||
|
protected: |
||||||
|
BlobShape shapeDesc; |
||||||
|
int inAxis, inNumAxes, autoAxisIdx; |
||||||
|
|
||||||
|
void computeOutputShape(int startAxis, int endAxis, BlobShape &inpShape, BlobShape &outShape); |
||||||
|
}; |
||||||
|
|
||||||
|
ReshapeLayer::ReshapeLayer(LayerParams ¶ms) |
||||||
|
{ |
||||||
|
DictValue paramShape = params.get("dim"); |
||||||
|
shapeDesc = BlobShape(paramShape.size()); |
||||||
|
autoAxisIdx = -1; |
||||||
|
|
||||||
|
for (int i = 0; i < paramShape.size(); i++) |
||||||
|
{ |
||||||
|
int dim = paramShape.get<int>(i); |
||||||
|
CV_Assert(dim >= -1); |
||||||
|
|
||||||
|
if (dim == -1) |
||||||
|
{ |
||||||
|
if (autoAxisIdx != -1) |
||||||
|
CV_Error(Error::StsBadArg, "New shape contains multiple -1 dims"); |
||||||
|
autoAxisIdx = i; |
||||||
|
} |
||||||
|
|
||||||
|
shapeDesc[i] = dim; |
||||||
|
} |
||||||
|
|
||||||
|
inAxis = params.get<int>("axis", 0); |
||||||
|
inNumAxes = params.get<int>("num_axes", -1); |
||||||
|
CV_Assert(inNumAxes >= -1); |
||||||
|
} |
||||||
|
|
||||||
|
void ReshapeLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs) |
||||||
|
{ |
||||||
|
CV_Assert(inputs.size() == 1); |
||||||
|
outputs.resize(1); |
||||||
|
|
||||||
|
Blob &inpBlob = *inputs[0]; |
||||||
|
Blob &outBlob = outputs[0]; |
||||||
|
BlobShape inpShape = inpBlob.shape(); |
||||||
|
|
||||||
|
int startAxis = (inAxis >= 0) ? inAxis : inpShape.dims() + 1 + inAxis; |
||||||
|
int endAxis = (inNumAxes == -1) ? inpShape.dims() : startAxis + inNumAxes; |
||||||
|
CV_Assert(0 <= startAxis && startAxis <= inpShape.dims()); |
||||||
|
CV_Assert(0 <= endAxis && endAxis <= inpShape.dims()); |
||||||
|
|
||||||
|
int newDims = inpShape.dims() - (endAxis - startAxis) + shapeDesc.dims(); |
||||||
|
BlobShape outShape(newDims); |
||||||
|
|
||||||
|
computeOutputShape(startAxis, endAxis, inpShape, outShape); |
||||||
|
|
||||||
|
outBlob.shareFrom(inpBlob); |
||||||
|
outBlob.reshape(outShape); |
||||||
|
} |
||||||
|
|
||||||
|
void ReshapeLayer::computeOutputShape(int startAxis, int endAxis, BlobShape &inpShape, BlobShape &outShape) |
||||||
|
{ |
||||||
|
int idx = 0; |
||||||
|
for (int i = 0; i < startAxis; i++) |
||||||
|
outShape[idx++] = inpShape[i]; |
||||||
|
|
||||||
|
for (int i = 0; i < shapeDesc.dims(); i++) |
||||||
|
{ |
||||||
|
if (shapeDesc[i] == 0) |
||||||
|
{ |
||||||
|
int inpAxisIdx = startAxis + i; |
||||||
|
if (inpAxisIdx < 0 || inpShape.dims() <= inpAxisIdx) |
||||||
|
CV_Error(Error::StsOutOfRange, "new shape contains a 0, but there was no corresponding bottom axis to copy"); |
||||||
|
outShape[idx++] = inpShape[startAxis + i]; |
||||||
|
} |
||||||
|
else |
||||||
|
{ |
||||||
|
outShape[idx++] = (shapeDesc[i] > 0) ? shapeDesc[i] : 1; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
for (int i = endAxis; i < inpShape.dims(); i++) |
||||||
|
outShape[idx++] = inpShape[i]; |
||||||
|
|
||||||
|
if (autoAxisIdx >= 0) |
||||||
|
{ |
||||||
|
size_t total = inpShape.total(); |
||||||
|
size_t curTotal = 1; |
||||||
|
for (int i = 0; i < outShape.dims(); i++) |
||||||
|
{ |
||||||
|
if (i != startAxis + autoAxisIdx) |
||||||
|
curTotal *= outShape[i]; |
||||||
|
} |
||||||
|
|
||||||
|
CV_DbgAssert(curTotal <= total && total % curTotal == 0); |
||||||
|
|
||||||
|
outShape[startAxis + autoAxisIdx] = (int)(total / curTotal); |
||||||
|
} |
||||||
|
|
||||||
|
if (inpShape.total() != outShape.total()) |
||||||
|
{ |
||||||
|
CV_Error(Error::StsBadArg, "Mismatch between input and output blob elements count"); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
Ptr<Layer> createFlattenLayer(LayerParams&) |
||||||
|
{ |
||||||
|
LayerParams params; |
||||||
|
|
||||||
|
int shapeDesc[] = {0, -1}; |
||||||
|
params.set("dim", DictValue::arrayInt(shapeDesc, 2)); |
||||||
|
|
||||||
|
return Ptr<Layer>(new ReshapeLayer(params)); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
REGISTER_LAYER_CLASS(Reshape, ReshapeLayer) |
||||||
|
REGISTER_LAYER_FUNC(Flatten, createFlattenLayer) |
||||||
|
|
||||||
|
|
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue