From 660a7098402f745ccd47e355dbfcbd40c1c906b9 Mon Sep 17 00:00:00 2001 From: Liubov Batanina Date: Fri, 6 Dec 2019 11:27:59 +0300 Subject: [PATCH] Support Swish and Mish activations --- modules/dnn/src/layers/elementwise_layers.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/modules/dnn/src/layers/elementwise_layers.cpp b/modules/dnn/src/layers/elementwise_layers.cpp index 8a0ddcdd75..3459734a08 100644 --- a/modules/dnn/src/layers/elementwise_layers.cpp +++ b/modules/dnn/src/layers/elementwise_layers.cpp @@ -579,7 +579,7 @@ struct SwishFunctor bool supportBackend(int backendId, int) { return backendId == DNN_BACKEND_OPENCV || - backendId == DNN_BACKEND_HALIDE; + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;; } void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const @@ -640,7 +640,8 @@ struct SwishFunctor #ifdef HAVE_DNN_NGRAPH std::shared_ptr initNgraphAPI(const std::shared_ptr& node) { - CV_Error(Error::StsNotImplemented, ""); + auto sigmoid = std::make_shared(node); + return std::make_shared(node, sigmoid); } #endif // HAVE_DNN_NGRAPH @@ -659,7 +660,7 @@ struct MishFunctor bool supportBackend(int backendId, int) { return backendId == DNN_BACKEND_OPENCV || - backendId == DNN_BACKEND_HALIDE; + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH; } void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const @@ -720,7 +721,13 @@ struct MishFunctor #ifdef HAVE_DNN_NGRAPH std::shared_ptr initNgraphAPI(const std::shared_ptr& node) { - CV_Error(Error::StsNotImplemented, ""); + float one = 1.0f; + auto constant = std::make_shared(ngraph::element::f32, ngraph::Shape{1}, &one); + auto exp_node = std::make_shared(node); + auto sum = std::make_shared(constant, exp_node, ngraph::op::AutoBroadcastType::NUMPY); + auto log_node = std::make_shared(sum); + auto tanh_node = std::make_shared(log_node); + return std::make_shared(node, tanh_node); } #endif // HAVE_DNN_NGRAPH