diff --git a/modules/dnn/src/tflite/tflite_importer.cpp b/modules/dnn/src/tflite/tflite_importer.cpp index 92bfeeef65..7e7f1d0503 100644 --- a/modules/dnn/src/tflite/tflite_importer.cpp +++ b/modules/dnn/src/tflite/tflite_importer.cpp @@ -271,7 +271,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap() dispatch["DEPTHWISE_CONV_2D"] = &TFLiteImporter::parseDWConvolution; dispatch["ADD"] = dispatch["MUL"] = &TFLiteImporter::parseEltwise; dispatch["RELU"] = dispatch["PRELU"] = dispatch["HARD_SWISH"] = - dispatch["LOGISTIC"] = &TFLiteImporter::parseActivation; + dispatch["LOGISTIC"] = dispatch["LEAKY_RELU"] = &TFLiteImporter::parseActivation; dispatch["MAX_POOL_2D"] = dispatch["AVERAGE_POOL_2D"] = &TFLiteImporter::parsePooling; dispatch["MaxPoolingWithArgmax2D"] = &TFLiteImporter::parsePoolingWithArgmax; dispatch["MaxUnpooling2D"] = &TFLiteImporter::parseUnpooling; @@ -1029,6 +1029,7 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco } void TFLiteImporter::parseActivation(const Operator& op, const std::string& opcode, LayerParams& activParams, bool isFused) { + float slope = 0.; if (opcode == "NONE") return; else if (opcode == "RELU6") @@ -1041,6 +1042,13 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco activParams.type = "HardSwish"; else if (opcode == "LOGISTIC") activParams.type = "Sigmoid"; + else if (opcode == "LEAKY_RELU") + { + activParams.type = "ReLU"; + auto options = reinterpret_cast(op.builtin_options()); + slope = options->alpha(); + activParams.set("negative_slope", slope); + } else CV_Error(Error::StsNotImplemented, "Unsupported activation " + opcode); @@ -1072,6 +1080,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco y = 1.0f / (1.0f + std::exp(-x)); else if (opcode == "HARD_SWISH") y = x * max(0.f, min(1.f, x / 6.f + 0.5f)); + else if (opcode == "LEAKY_RELU") + y = x >= 0.f ? x : slope*x; else CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode); diff --git a/modules/dnn/test/test_tflite_importer.cpp b/modules/dnn/test/test_tflite_importer.cpp index 8d374dc050..9773943fba 100644 --- a/modules/dnn/test/test_tflite_importer.cpp +++ b/modules/dnn/test/test_tflite_importer.cpp @@ -268,6 +268,10 @@ TEST_P(Test_TFLite, global_max_pooling_2d) { testLayer("global_max_pooling_2d"); } +TEST_P(Test_TFLite, leakyRelu) { + testLayer("leakyRelu"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets()); }} // namespace