diff --git a/modules/dnn/src/tflite/tflite_importer.cpp b/modules/dnn/src/tflite/tflite_importer.cpp index 1c048ad9d0..92bfeeef65 100644 --- a/modules/dnn/src/tflite/tflite_importer.cpp +++ b/modules/dnn/src/tflite/tflite_importer.cpp @@ -71,6 +71,7 @@ private: void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams); + void parseGlobalPooling(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseFusedActivation(const Operator& op, ActivationFunctionType activ); void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused); @@ -78,6 +79,8 @@ private: int addPermuteLayer(const std::vector& order, const std::string& permName, const std::pair& inpId, int dtype); int addReshapeLayer(const std::vector& shape, int axis, int num_axes, const std::string& name, const std::pair& inpId, int dtype); + int addFlattenLayer(int axis, int end_axis, const std::string& name, const std::pair& inpId, int dtype); + inline bool isInt8(const Operator& op); inline void getQuantParams(const Operator& op, float& inpScale, int& inpZero, float& outScale, int& outZero); }; @@ -286,6 +289,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap() dispatch["CAST"] = &TFLiteImporter::parseCast; dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess; dispatch["TRANSPOSE"] = &TFLiteImporter::parseTranspose; + dispatch["MEAN"] = dispatch["REDUCE_MAX"] = &TFLiteImporter::parseGlobalPooling; return dispatch; } @@ -764,6 +768,37 @@ void TFLiteImporter::parseTranspose(const Operator& op, const std::string& opcod addLayer(layerParams, op); } +void TFLiteImporter::parseGlobalPooling(const Operator& op, const std::string& opcode, LayerParams& layerParams) +{ + layerParams.type = "Pooling"; + if(opcode == "MEAN") { + layerParams.set("pool", "ave"); + } + else if (opcode == "REDUCE_MAX") { + layerParams.set("pool", "max"); + } + else { + CV_Error(Error::StsNotImplemented, "Unsupported pooling " + opcode); + } + layerParams.set("global_pooling", true); + auto options = op.builtin_options_as_ReducerOptions(); + bool keep_dims = options->keep_dims(); + + if (!keep_dims) { + const auto name = layerParams.name; + layerParams.name += "/global_pooling"; + addLayer(layerParams, op); + + int out = op.outputs()->Get(0); + auto outId = layerIds[out]; + int flattenId = addFlattenLayer(1, -1, name, outId, isInt8(op) ? CV_8S : CV_32F); + layerIds[out] = std::make_pair(flattenId, 0); + } + else { + addLayer(layerParams, op); + } +} + int TFLiteImporter::addPermuteLayer(const std::vector& order, const std::string& permName, const std::pair& inpId, int dtype) { @@ -786,6 +821,16 @@ int TFLiteImporter::addReshapeLayer(const std::vector& shape, int axis, int return id; } +int TFLiteImporter::addFlattenLayer(int axis, int end_axis, const std::string& name, const std::pair& inpId, int dtype) +{ + LayerParams lp; + lp.set("axis", axis); + lp.set("end_axis", end_axis); + int id = dstNet.addLayer(name, "Flatten", dtype, lp); + dstNet.connect(inpId.first, inpId.second, id, 0); + return id; +} + void TFLiteImporter::parseDeconvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams) { layerParams.type = "Deconvolution"; diff --git a/modules/dnn/test/test_tflite_importer.cpp b/modules/dnn/test/test_tflite_importer.cpp index 7621b44ff5..8d374dc050 100644 --- a/modules/dnn/test/test_tflite_importer.cpp +++ b/modules/dnn/test/test_tflite_importer.cpp @@ -260,6 +260,14 @@ TEST_P(Test_TFLite, permute) { testLayer("permutation_4d_0231"); } +TEST_P(Test_TFLite, global_average_pooling_2d) { + testLayer("global_average_pooling_2d"); +} + +TEST_P(Test_TFLite, global_max_pooling_2d) { + testLayer("global_max_pooling_2d"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets()); }} // namespace