|
|
|
@ -740,6 +740,9 @@ const char* const SigmoidFunctor::BaseDefaultFunctor<SigmoidFunctor>::ocl_kernel |
|
|
|
|
struct ELUFunctor : public BaseDefaultFunctor<ELUFunctor> |
|
|
|
|
{ |
|
|
|
|
typedef ELULayer Layer; |
|
|
|
|
float alpha; |
|
|
|
|
|
|
|
|
|
explicit ELUFunctor(float alpha_ = 1.f) : alpha(alpha_) {} |
|
|
|
|
|
|
|
|
|
bool supportBackend(int backendId, int) |
|
|
|
|
{ |
|
|
|
@ -749,14 +752,19 @@ struct ELUFunctor : public BaseDefaultFunctor<ELUFunctor> |
|
|
|
|
|
|
|
|
|
inline float calculate(float x) const |
|
|
|
|
{ |
|
|
|
|
return x >= 0.f ? x : exp(x) - 1.f; |
|
|
|
|
return x >= 0.f ? x : alpha * (exp(x) - 1.f); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
inline void setKernelParams(ocl::Kernel& kernel) const |
|
|
|
|
{ |
|
|
|
|
kernel.set(3, alpha); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#ifdef HAVE_HALIDE |
|
|
|
|
void attachHalide(const Halide::Expr& input, Halide::Func& top) |
|
|
|
|
{ |
|
|
|
|
Halide::Var x("x"), y("y"), c("c"), n("n"); |
|
|
|
|
top(x, y, c, n) = select(input >= 0.0f, input, exp(input) - 1); |
|
|
|
|
top(x, y, c, n) = select(input >= 0.0f, input, alpha * (exp(input) - 1)); |
|
|
|
|
} |
|
|
|
|
#endif // HAVE_HALIDE
|
|
|
|
|
|
|
|
|
@ -770,7 +778,7 @@ struct ELUFunctor : public BaseDefaultFunctor<ELUFunctor> |
|
|
|
|
#ifdef HAVE_DNN_NGRAPH |
|
|
|
|
std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node) |
|
|
|
|
{ |
|
|
|
|
return std::make_shared<ngraph::op::Elu>(node, 1.0); |
|
|
|
|
return std::make_shared<ngraph::op::Elu>(node, alpha); |
|
|
|
|
} |
|
|
|
|
#endif // HAVE_DNN_NGRAPH
|
|
|
|
|
|
|
|
|
@ -1263,8 +1271,10 @@ Ptr<SigmoidLayer> SigmoidLayer::create(const LayerParams& params) |
|
|
|
|
|
|
|
|
|
Ptr<ELULayer> ELULayer::create(const LayerParams& params) |
|
|
|
|
{ |
|
|
|
|
Ptr<ELULayer> l(new ElementWiseLayer<ELUFunctor>(ELUFunctor())); |
|
|
|
|
float alpha = params.get<float>("alpha", 1.0f); |
|
|
|
|
Ptr<ELULayer> l(new ElementWiseLayer<ELUFunctor>(ELUFunctor(alpha))); |
|
|
|
|
l->setParamsFrom(params); |
|
|
|
|
l->alpha = alpha; |
|
|
|
|
|
|
|
|
|
return l; |
|
|
|
|
} |
|
|
|
|