mirror of https://github.com/opencv/opencv.git
parent
9ec3d76b21
commit
17a35587e1
4 changed files with 281 additions and 57 deletions
@ -0,0 +1,80 @@ |
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||||
|
// of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP |
||||||
|
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP |
||||||
|
|
||||||
|
#include <cudnn.h> |
||||||
|
|
||||||
|
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { |
||||||
|
|
||||||
|
class ActivationDescriptor { |
||||||
|
public: |
||||||
|
enum class ActivationType { |
||||||
|
IDENTITY, |
||||||
|
RELU, |
||||||
|
CLIPPED_RELU, |
||||||
|
TANH, |
||||||
|
SIGMOID, |
||||||
|
ELU |
||||||
|
}; |
||||||
|
|
||||||
|
ActivationDescriptor() noexcept : descriptor{ nullptr } { } |
||||||
|
ActivationDescriptor(const ActivationDescriptor&) = delete; |
||||||
|
ActivationDescriptor(ActivationDescriptor&& other) noexcept |
||||||
|
: descriptor{ other.descriptor } { |
||||||
|
other.descriptor = nullptr; |
||||||
|
} |
||||||
|
|
||||||
|
/* `relu_ceiling_or_elu_alpha`:
|
||||||
|
* - `alpha` coefficient in ELU activation |
||||||
|
* - `ceiling` for CLIPPED_RELU activation |
||||||
|
*/ |
||||||
|
ActivationDescriptor(ActivationType type, double relu_ceiling_or_elu_alpha = 0.0) { |
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnCreateActivationDescriptor(&descriptor)); |
||||||
|
try { |
||||||
|
const auto mode = [type] { |
||||||
|
switch(type) { |
||||||
|
case ActivationType::IDENTITY: return CUDNN_ACTIVATION_IDENTITY; |
||||||
|
case ActivationType::RELU: return CUDNN_ACTIVATION_RELU; |
||||||
|
case ActivationType::CLIPPED_RELU: return CUDNN_ACTIVATION_CLIPPED_RELU; |
||||||
|
case ActivationType::SIGMOID: return CUDNN_ACTIVATION_SIGMOID; |
||||||
|
case ActivationType::TANH: return CUDNN_ACTIVATION_TANH; |
||||||
|
case ActivationType::ELU: return CUDNN_ACTIVATION_ELU; |
||||||
|
} |
||||||
|
CV_Assert(0); |
||||||
|
return CUDNN_ACTIVATION_IDENTITY; |
||||||
|
} (); |
||||||
|
|
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnSetActivationDescriptor(descriptor, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling_or_elu_alpha)); |
||||||
|
} catch(...) { |
||||||
|
/* cudnnDestroyActivationDescriptor will not fail for a valid descriptor object */ |
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor)); |
||||||
|
throw; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
~ActivationDescriptor() noexcept { |
||||||
|
if (descriptor != nullptr) { |
||||||
|
/* cudnnDestroyActivationDescriptor will not fail */ |
||||||
|
CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor)); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
ActivationDescriptor& operator=(const ActivationDescriptor&) = delete; |
||||||
|
ActivationDescriptor& operator=(ActivationDescriptor&& other) noexcept { |
||||||
|
descriptor = other.descriptor; |
||||||
|
other.descriptor = nullptr; |
||||||
|
return *this; |
||||||
|
}; |
||||||
|
|
||||||
|
cudnnActivationDescriptor_t get() const noexcept { return descriptor; } |
||||||
|
|
||||||
|
private: |
||||||
|
cudnnActivationDescriptor_t descriptor; |
||||||
|
}; |
||||||
|
|
||||||
|
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ |
||||||
|
|
||||||
|
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP */ |
Loading…
Reference in new issue