diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 153d026e4b..01feeeb7c7 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.set("ave_pool_padded_area", framework_name == "pytorch"); } else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || - layer_type == "ReduceMean" || layer_type == "ReduceSum") + layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax") { CV_Assert(node_proto.input_size() == 1); layerParams.type = "Pooling"; String pool; - if (layer_type == "GlobalMaxPool") + if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax") pool = "MAX"; else if (layer_type == "ReduceSum") pool = "SUM"; else pool = "AVE"; layerParams.set("pool", pool); - layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool"); - if (layer_type == "ReduceMean" || layer_type == "ReduceSum") + layerParams.set("global_pooling", !layerParams.has("axes")); + if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")) { - if (!layerParams.has("axes")) - CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation."); - MatShape inpShape = outShapes[node_proto.input(0)]; DictValue axes = layerParams.get("axes"); bool keepdims = layerParams.get("keepdims"); @@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.type = "Reshape"; layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size())); + node_proto.set_input(0, node_proto.output(0)); + node_proto.set_output(0, layerParams.name); + } + else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")) + { + CV_CheckEQ(layerParams.get("keepdims"), 0, (layer_type + " layer only supports keepdims = false").c_str()); + LayerParams reshapeLp; + reshapeLp.name = layerParams.name + "/reshape"; + reshapeLp.type = "Reshape"; + CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end()); + int newShape[] = {1, 1, 1, -1}; + reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4)); + + opencv_onnx::NodeProto proto; + proto.add_input(node_proto.input(0)); + proto.add_output(reshapeLp.name); + addLayer(dstNet, reshapeLp, proto, layer_id, outShapes); + + LayerParams poolLp = layerParams; + poolLp.name = layerParams.name + "/pool"; + CV_Assert(layer_id.find(poolLp.name) == layer_id.end()); + + node_proto.set_input(0, reshapeLp.name); + node_proto.set_output(0, poolLp.name); + addLayer(dstNet, poolLp, node_proto, layer_id, outShapes); + + layerParams.type = "Reshape"; + int targetShape[] = {1}; + layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1)); + node_proto.set_input(0, node_proto.output(0)); node_proto.set_output(0, layerParams.name); } @@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet) case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break; default: type = blob.type(); } - blob.convertTo(blob, type); - addConstant(layerParams.name, blob, constBlobs, outShapes); + Mat dst; + blob.convertTo(dst, type); + dst.dims = blob.dims; + addConstant(layerParams.name, dst, constBlobs, outShapes); continue; } else @@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet) { outShape.erase(outShape.begin() + axis); out.reshape(0, outShape); + } else { + out.dims = 1; } addConstant(layerParams.name, out, constBlobs, outShapes); continue; diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 1eb848897c..2a4555619f 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -262,6 +262,11 @@ TEST_P(Test_ONNX_layers, ReduceSum) testONNXModels("reduce_sum"); } +TEST_P(Test_ONNX_layers, ReduceMaxGlobal) +{ + testONNXModels("reduce_max"); +} + TEST_P(Test_ONNX_layers, ReduceMean3D) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)