Merge pull request #18845 from joegeisbauer:fix_reduce_mean_index_error

pull/18914/head
Alexander Alekhin 5 years ago
commit 0401d5920c
  1. 15
      modules/dnn/src/onnx/onnx_importer.cpp

@ -500,14 +500,17 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
MatShape inpShape = outShapes[node_proto.input(0)]; MatShape inpShape = outShapes[node_proto.input(0)];
DictValue axes = layerParams.get("axes"); DictValue axes = layerParams.get("axes");
bool keepdims = layerParams.get<int>("keepdims"); bool keepdims = layerParams.get<int>("keepdims");
MatShape targetShape = inpShape; MatShape targetShape;
std::vector<bool> shouldDelete(inpShape.size(), false);
for (int i = 0; i < axes.size(); i++) { for (int i = 0; i < axes.size(); i++) {
int axis = clamp(axes.get<int>(i), inpShape.size()); int axis = clamp(axes.get<int>(i), inpShape.size());
if (keepdims) { shouldDelete[axis] = true;
targetShape[axis] = 1; }
} else { for (int axis = 0; axis < inpShape.size(); ++axis){
targetShape.erase(targetShape.begin() + axis); if (!shouldDelete[axis])
} targetShape.push_back(inpShape[axis]);
else if (keepdims)
targetShape.push_back(1);
} }
if (inpShape.size() == 3 && axes.size() <= 2) if (inpShape.size() == 3 && axes.size() <= 2)

Loading…
Cancel
Save