Leaky RELU support for TFLite.

pull/26132/head
Alexander Smorkalov 3 months ago
parent 79faf857d9
commit 209802c9f6
  1. 12
      modules/dnn/src/tflite/tflite_importer.cpp
  2. 4
      modules/dnn/test/test_tflite_importer.cpp

@ -271,7 +271,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
dispatch["DEPTHWISE_CONV_2D"] = &TFLiteImporter::parseDWConvolution;
dispatch["ADD"] = dispatch["MUL"] = &TFLiteImporter::parseEltwise;
dispatch["RELU"] = dispatch["PRELU"] = dispatch["HARD_SWISH"] =
dispatch["LOGISTIC"] = &TFLiteImporter::parseActivation;
dispatch["LOGISTIC"] = dispatch["LEAKY_RELU"] = &TFLiteImporter::parseActivation;
dispatch["MAX_POOL_2D"] = dispatch["AVERAGE_POOL_2D"] = &TFLiteImporter::parsePooling;
dispatch["MaxPoolingWithArgmax2D"] = &TFLiteImporter::parsePoolingWithArgmax;
dispatch["MaxUnpooling2D"] = &TFLiteImporter::parseUnpooling;
@ -1029,6 +1029,7 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
}
void TFLiteImporter::parseActivation(const Operator& op, const std::string& opcode, LayerParams& activParams, bool isFused) {
float slope = 0.;
if (opcode == "NONE")
return;
else if (opcode == "RELU6")
@ -1041,6 +1042,13 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
activParams.type = "HardSwish";
else if (opcode == "LOGISTIC")
activParams.type = "Sigmoid";
else if (opcode == "LEAKY_RELU")
{
activParams.type = "ReLU";
auto options = reinterpret_cast<const LeakyReluOptions*>(op.builtin_options());
slope = options->alpha();
activParams.set("negative_slope", slope);
}
else
CV_Error(Error::StsNotImplemented, "Unsupported activation " + opcode);
@ -1072,6 +1080,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
y = 1.0f / (1.0f + std::exp(-x));
else if (opcode == "HARD_SWISH")
y = x * max(0.f, min(1.f, x / 6.f + 0.5f));
else if (opcode == "LEAKY_RELU")
y = x >= 0.f ? x : slope*x;
else
CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode);

@ -268,6 +268,10 @@ TEST_P(Test_TFLite, global_max_pooling_2d) {
testLayer("global_max_pooling_2d");
}
TEST_P(Test_TFLite, leakyRelu) {
testLayer("leakyRelu");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
}} // namespace

Loading…
Cancel
Save