Support global reduce ops

pull/18299/head
Liubov Batanina 5 years ago
parent 6b674709b8
commit b542a1804c
Changed files:
  1. modules/dnn/src/onnx/onnx_importer.cpp (49 changed lines)
  2. modules/dnn/test/test_onnx_importer.cpp (5 changed lines)

modules/dnn/src/onnx/onnx_importer.cpp

@@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
         }
         else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
-                layer_type == "ReduceMean" || layer_type == "ReduceSum")
+                layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
         {
             CV_Assert(node_proto.input_size() == 1);
             layerParams.type = "Pooling";
             String pool;
-            if (layer_type == "GlobalMaxPool")
+            if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
                 pool = "MAX";
             else if (layer_type == "ReduceSum")
                 pool = "SUM";
             else
                 pool = "AVE";
             layerParams.set("pool", pool);
-            layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
-            if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
+            layerParams.set("global_pooling", !layerParams.has("axes"));
+            if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
             {
-                if (!layerParams.has("axes"))
-                    CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
-
                 MatShape inpShape = outShapes[node_proto.input(0)];
                 DictValue axes = layerParams.get("axes");
                 bool keepdims = layerParams.get<int>("keepdims");
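With this hunk, ReduceMax joins ReduceMean and ReduceSum in the group of ONNX reduce ops that the importer maps onto the Pooling layer (pool = "MAX" here). A minimal usage sketch, not part of the commit, of what that enables; the model file name "reduce_max.onnx" and the 4x4 input are assumptions for illustration:

    #include <opencv2/core.hpp>
    #include <opencv2/dnn.hpp>

    int main()
    {
        // Hypothetical model containing a ReduceMax node.
        cv::dnn::Net net = cv::dnn::readNetFromONNX("reduce_max.onnx");

        cv::Mat img(4, 4, CV_8UC3);
        cv::randu(img, cv::Scalar::all(0), cv::Scalar::all(255));
        cv::Mat blob = cv::dnn::blobFromImage(img);   // 1x3x4x4 CV_32F NCHW blob

        net.setInput(blob);
        cv::Mat out = net.forward();                  // the ReduceMax node runs as a Pooling layer
        return 0;
    }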
@@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.type = "Reshape";
                 layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));

+                node_proto.set_input(0, node_proto.output(0));
+                node_proto.set_output(0, layerParams.name);
+            }
+            else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
+            {
+                CV_CheckEQ(layerParams.get<int>("keepdims"), 0, (layer_type + " layer only supports keepdims = false").c_str());
+
+                LayerParams reshapeLp;
+                reshapeLp.name = layerParams.name + "/reshape";
+                reshapeLp.type = "Reshape";
+                CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
+                int newShape[] = {1, 1, 1, -1};
+                reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
+
+                opencv_onnx::NodeProto proto;
+                proto.add_input(node_proto.input(0));
+                proto.add_output(reshapeLp.name);
+                addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
+
+                LayerParams poolLp = layerParams;
+                poolLp.name = layerParams.name + "/pool";
+                CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
+
+                node_proto.set_input(0, reshapeLp.name);
+                node_proto.set_output(0, poolLp.name);
+                addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
+
+                layerParams.type = "Reshape";
+                int targetShape[] = {1};
+                layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
                 node_proto.set_input(0, node_proto.output(0));
                 node_proto.set_output(0, layerParams.name);
             }
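The branch added above covers ReduceMean/ReduceSum/ReduceMax nodes without an "axes" attribute, i.e. a reduction over every element of the input. Because the Pooling layer expects a 4-D blob, the importer builds a small subgraph instead: a Reshape to 1x1x1x(-1), a global pooling named "<layer>/pool", and a final Reshape to a one-element blob (keepdims must be 0). A rough numeric equivalent in plain cv::Mat calls, assuming a continuous CV_32F blob; this is an illustration, not importer code:

    #include <opencv2/core.hpp>
    #include <cmath>
    #include <cstdio>

    int main()
    {
        const int sz[] = {2, 3, 4, 5};                  // arbitrary 4-D blob
        cv::Mat blob(4, sz, CV_32F);
        for (int i = 0; i < (int)blob.total(); i++)
            blob.ptr<float>()[i] = std::sin(float(i));  // any test data

        // Importer's subgraph for ReduceMax without "axes":
        //   Reshape to 1x1x1xN -> Pooling(pool=MAX, global_pooling) -> Reshape to {1}
        // which for a continuous blob is just the maximum over every element:
        cv::Mat flat(1, (int)blob.total(), CV_32F, blob.ptr<float>());  // flattened view
        double globalMax = 0;
        cv::minMaxIdx(flat, 0, &globalMax);
        std::printf("ReduceMax over all elements: %f\n", globalMax);
        return 0;
    }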
@@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet)
                     case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
                     default: type = blob.type();
                 }
-                blob.convertTo(blob, type);
-                addConstant(layerParams.name, blob, constBlobs, outShapes);
+                Mat dst;
+                blob.convertTo(dst, type);
+                dst.dims = blob.dims;
+                addConstant(layerParams.name, dst, constBlobs, outShapes);
                 continue;
             }
             else

@@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet)
             {
                 outShape.erase(outShape.begin() + axis);
                 out.reshape(0, outShape);
+            } else {
+                out.dims = 1;
             }
             addConstant(layerParams.name, out, constBlobs, outShapes);
             continue;
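The two hunks above both deal with the importer's handling of low-rank constants: cv::Mat never reports fewer than two dimensions, so the importer tracks 1-D ONNX tensors by writing the public Mat::dims field directly. That mark is lost when convertTo() writes into a fresh Mat, hence dst.dims = blob.dims, and it is set explicitly (out.dims = 1) when a folded constant ends up one-dimensional. A minimal sketch of the convention; my illustration, not code from the patch:

    #include <opencv2/core.hpp>

    int main()
    {
        // A 1x3 matrix is the closest cv::Mat gets to a rank-1 tensor by default.
        cv::Mat w = (cv::Mat_<float>(1, 3) << 1.f, 2.f, 3.f);
        CV_Assert(w.dims == 2);   // cv::Mat is always at least 2-D

        // The importer's marker for a 1-D ONNX tensor: write the public dims field.
        w.dims = 1;

        // A freshly created Mat (e.g. the convertTo() destination in the hunk above)
        // reports dims == 2 again, so the patch copies the marker back:
        //     dst.dims = blob.dims;
        return 0;
    }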

modules/dnn/test/test_onnx_importer.cpp

@@ -262,6 +262,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
     testONNXModels("reduce_sum");
 }

+TEST_P(Test_ONNX_layers, ReduceMaxGlobal)
+{
+    testONNXModels("reduce_max");
+}
+
 TEST_P(Test_ONNX_layers, ReduceMean3D)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
