|
|
|
@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet) |
|
|
|
|
else if (layer_type == "Gather") |
|
|
|
|
{ |
|
|
|
|
CV_Assert(node_proto.input_size() == 2); |
|
|
|
|
Mat input = getBlob(node_proto, constBlobs, 0); |
|
|
|
|
Mat indexMat = getBlob(node_proto, constBlobs, 1); |
|
|
|
|
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1); |
|
|
|
|
int index = indexMat.at<int>(0); |
|
|
|
|
int axis = layerParams.get<int>("axis", 0); |
|
|
|
|
|
|
|
|
|
Mat out; |
|
|
|
|
if (layerParams.has("axis")) |
|
|
|
|
if ((constBlobs.find(node_proto.input(0)) != constBlobs.end())) |
|
|
|
|
{ |
|
|
|
|
int axis = layerParams.get<int>("axis"); |
|
|
|
|
|
|
|
|
|
Mat input = getBlob(node_proto, constBlobs, 0); |
|
|
|
|
Mat out; |
|
|
|
|
std::vector<cv::Range> ranges(input.dims, Range::all()); |
|
|
|
|
ranges[axis] = Range(index, index + 1); |
|
|
|
|
|
|
|
|
|
out = input(ranges); |
|
|
|
|
MatShape outShape = shape(out); |
|
|
|
|
if (outShape.size() > 1) |
|
|
|
|
{ |
|
|
|
|
outShape.erase(outShape.begin() + axis); |
|
|
|
|
out.reshape(0, outShape); |
|
|
|
|
} |
|
|
|
|
addConstant(layerParams.name, out, constBlobs, outShapes); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
CV_Assert(index < input.total()); |
|
|
|
|
const int dims = input.dims; |
|
|
|
|
input = input.reshape(1, 1); |
|
|
|
|
input.dims = 2; |
|
|
|
|
out = input.reshape(1, 1).colRange(index, index + 1); |
|
|
|
|
out.dims = dims; |
|
|
|
|
shapeIt = outShapes.find(node_proto.input(0)); |
|
|
|
|
CV_Assert(shapeIt != outShapes.end()); |
|
|
|
|
MatShape inpShape = shapeIt->second; |
|
|
|
|
|
|
|
|
|
LayerParams sliceLp; |
|
|
|
|
sliceLp.type = "Slice"; |
|
|
|
|
sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name; |
|
|
|
|
std::vector<int> begin(inpShape.size(), 0); |
|
|
|
|
std::vector<int> end(inpShape.size(), -1); |
|
|
|
|
begin[axis] = index; |
|
|
|
|
end[axis] = index + 1; |
|
|
|
|
|
|
|
|
|
cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size()); |
|
|
|
|
cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size()); |
|
|
|
|
sliceLp.set("begin", paramBegin); |
|
|
|
|
sliceLp.set("end", paramEnd); |
|
|
|
|
|
|
|
|
|
if (inpShape.size() > 1) |
|
|
|
|
{ |
|
|
|
|
opencv_onnx::NodeProto proto; |
|
|
|
|
proto.add_input(node_proto.input(0)); |
|
|
|
|
proto.add_output(sliceLp.name); |
|
|
|
|
addLayer(dstNet, sliceLp, proto, layer_id, outShapes); |
|
|
|
|
|
|
|
|
|
inpShape.erase(inpShape.begin() + axis); |
|
|
|
|
layerParams.type = "Reshape"; |
|
|
|
|
layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size())); |
|
|
|
|
node_proto.set_input(0, sliceLp.name); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
layerParams = sliceLp; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
addConstant(layerParams.name, out, constBlobs, outShapes); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
else if (layer_type == "Concat") |
|
|
|
|
{ |
|
|
|
|