|
|
|
@ -271,7 +271,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap() |
|
|
|
|
dispatch["DEPTHWISE_CONV_2D"] = &TFLiteImporter::parseDWConvolution; |
|
|
|
|
dispatch["ADD"] = dispatch["MUL"] = &TFLiteImporter::parseEltwise; |
|
|
|
|
dispatch["RELU"] = dispatch["PRELU"] = dispatch["HARD_SWISH"] = |
|
|
|
|
dispatch["LOGISTIC"] = &TFLiteImporter::parseActivation; |
|
|
|
|
dispatch["LOGISTIC"] = dispatch["LEAKY_RELU"] = &TFLiteImporter::parseActivation; |
|
|
|
|
dispatch["MAX_POOL_2D"] = dispatch["AVERAGE_POOL_2D"] = &TFLiteImporter::parsePooling; |
|
|
|
|
dispatch["MaxPoolingWithArgmax2D"] = &TFLiteImporter::parsePoolingWithArgmax; |
|
|
|
|
dispatch["MaxUnpooling2D"] = &TFLiteImporter::parseUnpooling; |
|
|
|
@ -1029,6 +1029,7 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void TFLiteImporter::parseActivation(const Operator& op, const std::string& opcode, LayerParams& activParams, bool isFused) { |
|
|
|
|
float slope = 0.; |
|
|
|
|
if (opcode == "NONE") |
|
|
|
|
return; |
|
|
|
|
else if (opcode == "RELU6") |
|
|
|
@ -1041,6 +1042,13 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco |
|
|
|
|
activParams.type = "HardSwish"; |
|
|
|
|
else if (opcode == "LOGISTIC") |
|
|
|
|
activParams.type = "Sigmoid"; |
|
|
|
|
else if (opcode == "LEAKY_RELU") |
|
|
|
|
{ |
|
|
|
|
activParams.type = "ReLU"; |
|
|
|
|
auto options = reinterpret_cast<const LeakyReluOptions*>(op.builtin_options()); |
|
|
|
|
slope = options->alpha(); |
|
|
|
|
activParams.set("negative_slope", slope); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Unsupported activation " + opcode); |
|
|
|
|
|
|
|
|
@ -1072,6 +1080,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco |
|
|
|
|
y = 1.0f / (1.0f + std::exp(-x)); |
|
|
|
|
else if (opcode == "HARD_SWISH") |
|
|
|
|
y = x * max(0.f, min(1.f, x / 6.f + 0.5f)); |
|
|
|
|
else if (opcode == "LEAKY_RELU") |
|
|
|
|
y = x >= 0.f ? x : slope*x; |
|
|
|
|
else |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode); |
|
|
|
|
|
|
|
|
|