|
|
|
@ -494,14 +494,17 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_) |
|
|
|
|
MatShape inpShape = outShapes[node_proto.input(0)]; |
|
|
|
|
DictValue axes = layerParams.get("axes"); |
|
|
|
|
bool keepdims = layerParams.get<int>("keepdims"); |
|
|
|
|
MatShape targetShape = inpShape; |
|
|
|
|
MatShape targetShape; |
|
|
|
|
std::vector<bool> shouldDelete(inpShape.size(), false); |
|
|
|
|
for (int i = 0; i < axes.size(); i++) { |
|
|
|
|
int axis = clamp(axes.get<int>(i), inpShape.size()); |
|
|
|
|
if (keepdims) { |
|
|
|
|
targetShape[axis] = 1; |
|
|
|
|
} else { |
|
|
|
|
targetShape.erase(targetShape.begin() + axis); |
|
|
|
|
} |
|
|
|
|
shouldDelete[axis] = true; |
|
|
|
|
} |
|
|
|
|
for (int axis = 0; axis < inpShape.size(); ++axis){ |
|
|
|
|
if (!shouldDelete[axis]) |
|
|
|
|
targetShape.push_back(inpShape[axis]); |
|
|
|
|
else if (keepdims) |
|
|
|
|
targetShape.push_back(1); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (inpShape.size() == 3 && axes.size() <= 2) |
|
|
|
|