You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.7 KiB
60 lines
1.7 KiB
#include "../precomp.hpp" |
|
#include "layers_common.hpp" |
|
#include "concat_layer.hpp" |
|
|
|
namespace cv |
|
{ |
|
namespace dnn |
|
{ |
|
ConcatLayer::ConcatLayer(LayerParams ¶ms) : Layer(params) |
|
{ |
|
axis = params.get<int>("axis", 1); |
|
CV_Assert(axis >= 0); |
|
} |
|
|
|
void ConcatLayer::allocate(const std::vector<Blob *> &inputs, std::vector<Blob> &outputs) |
|
{ |
|
CV_Assert(inputs.size() > 0); |
|
|
|
int refType = inputs[0]->type(); |
|
BlobShape refShape = inputs[0]->shape(); |
|
CV_Assert(axis < refShape.dims()); |
|
|
|
int axisSum = 0; |
|
for (size_t i = 0; i < inputs.size(); i++) |
|
{ |
|
BlobShape curShape = inputs[i]->shape(); |
|
|
|
CV_Assert(curShape.dims() == refShape.dims() && inputs[i]->type() == refType); |
|
for (int axisId = 0; axisId < refShape.dims(); axisId++) |
|
{ |
|
if (axisId != axis && refShape[axisId] != curShape[axisId]) |
|
CV_Error(Error::StsBadSize, "Inconsitent shape for ConcatLayer"); |
|
} |
|
|
|
axisSum += curShape[axis]; |
|
} |
|
|
|
refShape[axis] = axisSum; |
|
outputs.resize(1); |
|
outputs[0].create(refShape); |
|
} |
|
|
|
void ConcatLayer::forward(std::vector<Blob *> &inputs, std::vector<Blob> &outputs) |
|
{ |
|
const Mat& outMat = outputs[0].getMatRef(); |
|
std::vector<Range> ranges(outputs[0].dims(), Range::all()); |
|
int sizeStart = 0; |
|
for (size_t i = 0; i < inputs.size(); i++) |
|
{ |
|
int sizeEnd = sizeStart + inputs[i]->size(axis); |
|
ranges[axis] = Range(sizeStart, sizeEnd); |
|
|
|
Mat outSubMat = outMat(&ranges[0]); |
|
inputs[i]->getMatRef().copyTo(outSubMat); |
|
|
|
sizeStart = sizeEnd; |
|
} |
|
} |
|
} |
|
}
|
|
|