@@ -58,6 +58,16 @@ namespace cv { namespace dnn {
         return Tensor<T>(std::begin(sizes), std::end(sizes));
     }
 
+    template <class T> inline
+    void copyMatToTensorImpl(const Mat& srcMat, const TensorSpan<T> destTensor, const Stream& stream) {
+        CV_Assert(srcMat.total() >= destTensor.size());
+
+        Mat temp = srcMat.isContinuous() ? srcMat : srcMat.clone();
+        CV_Assert(temp.isContinuous());
+
+        memcpy<T>(destTensor.get(), reinterpret_cast<T*>(temp.data), destTensor.size(), stream);
+    }
+
     /** @brief copies data from a cv::Mat to TensorType
      *
      * \tparam T    the type of the elements contained in TensorType object
@@ -81,8 +91,7 @@ namespace cv { namespace dnn {
     template <> inline
     void copyMatToTensor(const Mat& srcMat, const TensorSpan<half> destTensor, const Stream& stream) {
         /* should perhaps convert cv::Mat of different type to the required type and copy */
-        CV_Assert(srcMat.type() == CV_32F);
+        CV_CheckTypeEQ(srcMat.type(), CV_32F, "");
         CV_Assert(srcMat.total() >= destTensor.size());
 
         Mat temp;
@@ -94,14 +103,20 @@ namespace cv { namespace dnn {
     template <> inline
     void copyMatToTensor(const Mat& srcMat, const TensorSpan<float> destTensor, const Stream& stream) {
-        /* should perhaps convert cv::Mat of different type to the required type and copy */
-        CV_Assert(srcMat.type() == CV_32F);
-        CV_Assert(srcMat.total() >= destTensor.size());
+        CV_CheckTypeEQ(srcMat.type(), CV_32F, "");
+        copyMatToTensorImpl(srcMat, destTensor, stream);
+    }
 
-        Mat temp = srcMat.isContinuous() ? srcMat : srcMat.clone();
-        CV_Assert(temp.isContinuous());
+    template <> inline
+    void copyMatToTensor(const Mat& srcMat, const TensorSpan<int32_t> destTensor, const Stream& stream) {
+        CV_CheckTypeEQ(srcMat.type(), CV_32S, "");
+        copyMatToTensorImpl(srcMat, destTensor, stream);
+    }
 
-        memcpy<float>(destTensor.get(), reinterpret_cast<float*>(temp.data), destTensor.size(), stream);
+    template <> inline
+    void copyMatToTensor(const Mat& srcMat, const TensorSpan<int64_t> destTensor, const Stream& stream) {
+        CV_CheckTypeEQ(srcMat.type(), CV_64S, "");
+        copyMatToTensorImpl(srcMat, destTensor, stream);
     }
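For illustration, a minimal sketch of how one of the new integer specializations is driven from a host Mat (a hypothetical helper, not part of the patch; it assumes a live Stream and a destination tensor already sized to match the Mat):

    // Illustrative only: uploads a CV_64S Mat into an int64_t device tensor.
    // Overload resolution on TensorSpan<int64_t> selects the specialization above,
    // which checks the Mat depth and forwards to copyMatToTensorImpl.
    void uploadInt64(const Mat& host, TensorSpan<int64_t> device, const Stream& stream) {
        copyMatToTensor(host, device, stream);
    }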
 
     /** @brief copies data from a TensorType to a cv::Mat
@@ -126,7 +141,7 @@ namespace cv { namespace dnn {
     template <> inline
     void copyTensorToMat(TensorView<half> srcTensor, Mat& destMat, const Stream& stream) {
-        CV_Assert(destMat.type() == CV_32F);
+        CV_CheckTypeEQ(destMat.type(), CV_32F, "Unsupported type");
         CV_Assert(destMat.total() >= srcTensor.size());
 
         Mat temp(shape(destMat), CV_16F);
@@ -139,7 +154,7 @@ namespace cv { namespace dnn {
     template <> inline
     void copyTensorToMat(TensorView<float> srcTensor, Mat& destMat, const Stream& stream) {
-        CV_Assert(destMat.type() == CV_32F);
+        CV_CheckTypeEQ(destMat.type(), CV_32F, "Unsupported type");
         CV_Assert(destMat.total() >= srcTensor.size());
 
         Mat temp = destMat.isContinuous() ? destMat : destMat.clone();
@@ -200,6 +215,44 @@ namespace cv { namespace dnn {
         return Ptr<BackendNode>();
     }
 
+    template <template <class> class NodeType, class ...Args>
+    cv::Ptr<BackendNode> make_cuda_node_with_type(int targetId, int hostMatType, Args&& ...args) {
+        CV_CheckType(hostMatType, hostMatType == CV_32F || hostMatType == CV_32S || hostMatType == CV_64S, "");
+        if (hostMatType == CV_32S)
+            return Ptr<BackendNode>(new NodeType<int32_t>(std::forward<Args>(args)...));
+        else if (hostMatType == CV_64S)
+            return Ptr<BackendNode>(new NodeType<int64_t>(std::forward<Args>(args)...));
+        else if (hostMatType == CV_32F)
+        {
+            if (targetId == DNN_TARGET_CUDA_FP16)
+                return Ptr<BackendNode>(new NodeType<half>(std::forward<Args>(args)...));
+            else if (targetId == DNN_TARGET_CUDA)
+                return Ptr<BackendNode>(new NodeType<float>(std::forward<Args>(args)...));
+        }
+
+        CV_Error(Error::BadDepth, "Unsupported mat type");
+        return Ptr<BackendNode>();
+    }
+
+    template <template <class, class> class NodeType, class T_INDEX, class ...Args>
+    cv::Ptr<BackendNode> make_cuda_node_with_indices(int targetId, int hostMatType, Args&& ...args) {
+        CV_CheckType(hostMatType, hostMatType == CV_32F || hostMatType == CV_32S || hostMatType == CV_64S, "");
+        if (hostMatType == CV_32S)
+            return Ptr<BackendNode>(new NodeType<int32_t, T_INDEX>(std::forward<Args>(args)...));
+        else if (hostMatType == CV_64S)
+            return Ptr<BackendNode>(new NodeType<int64_t, T_INDEX>(std::forward<Args>(args)...));
+        else if (hostMatType == CV_32F)
+        {
+            if (targetId == DNN_TARGET_CUDA_FP16)
+                return Ptr<BackendNode>(new NodeType<half, T_INDEX>(std::forward<Args>(args)...));
+            else if (targetId == DNN_TARGET_CUDA)
+                return Ptr<BackendNode>(new NodeType<float, T_INDEX>(std::forward<Args>(args)...));
+        }
+
+        CV_Error(Error::BadDepth, "Unsupported mat type");
+        return Ptr<BackendNode>();
+    }
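For reference, a hedged sketch of how a caller could use the new dispatch helper; the node type below is a stand-in invented for the example, while real callers pass one of the templated cuda4dnn op classes plus its constructor arguments:

    // Illustrative only (not part of the patch): a trivial stand-in node template.
    template <class T>
    class DummyNode : public BackendNode {
    public:
        DummyNode() : BackendNode(DNN_BACKEND_CUDA) { }
    };

    // CV_32S/CV_64S host Mats select the integer instantiations directly, while a
    // CV_32F Mat still dispatches on the target (half for DNN_TARGET_CUDA_FP16,
    // float for DNN_TARGET_CUDA).
    inline Ptr<BackendNode> makeDummyNode(int targetId, int hostMatType) {
        return make_cuda_node_with_type<DummyNode>(targetId, hostMatType);
    }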
 
     /* base class for all CUDA backend/target wrappers */
     class CUDABackendWrapper : public BackendWrapper {
     public:
@@ -224,11 +277,11 @@ namespace cv { namespace dnn {
     namespace cuda4dnn { namespace detail {
 
-        template <class U>
-        void convert_D2H(const cv::Mat& mat, cuda4dnn::csl::View<U> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream);
+        template <class DEVICE_T, class HOST_T>
+        void convert_D2H(const cv::Mat& mat, cuda4dnn::csl::View<DEVICE_T> view, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream);
 
         template <> inline
-        void convert_D2H<half>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+        void convert_D2H<half, float>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
             if (device_temp.size() < view.size())
                 device_temp.reset(view.size());
 
             auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), view.size());
@@ -238,15 +291,25 @@ namespace cv { namespace dnn {
         }
 
         template <> inline
-        void convert_D2H<float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+        void convert_D2H<float, float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
             cuda4dnn::csl::memcpy<float>(reinterpret_cast<float*>(mat.data), view.data(), view.size(), stream);
         }
 
-        template <class U>
-        void convert_D2H_background(const cv::Mat& mat, cuda4dnn::csl::View<U> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event);
+        template <> inline
+        void convert_D2H<int32_t, int32_t>(const cv::Mat& mat, cuda4dnn::csl::View<int32_t> view, cuda4dnn::csl::ManagedPtr<int32_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            cuda4dnn::csl::memcpy<int32_t>(reinterpret_cast<int32_t*>(mat.data), view.data(), view.size(), stream);
+        }
+
+        template <> inline
+        void convert_D2H<int64_t, int64_t>(const cv::Mat& mat, cuda4dnn::csl::View<int64_t> view, cuda4dnn::csl::ManagedPtr<int64_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            cuda4dnn::csl::memcpy<int64_t>(reinterpret_cast<int64_t*>(mat.data), view.data(), view.size(), stream);
+        }
+
+        template <class DEVICE_T, class HOST_T>
+        void convert_D2H_background(const cv::Mat& mat, cuda4dnn::csl::View<DEVICE_T> view, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event);
 
         template <> inline
-        void convert_D2H_background<half>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
+        void convert_D2H_background<half, float>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
             if (device_temp.size() < view.size())
                 device_temp.reset(view.size());
 
             auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), view.size());
@@ -266,17 +329,31 @@ namespace cv { namespace dnn {
         }
 
         template <> inline
-        void convert_D2H_background<float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
+        void convert_D2H_background<float, float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
             d2h_event.record(stream);
             cuda4dnn::csl::StreamWaitOnEvent(d2h_stream, d2h_event);
             cuda4dnn::csl::memcpy<float>(reinterpret_cast<float*>(mat.data), view.data(), view.size(), d2h_stream);
         }
 
-        template <class U>
-        void convert_H2D(cuda4dnn::csl::Span<U> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream);
+        template <> inline
+        void convert_D2H_background<int32_t, int32_t>(const cv::Mat& mat, cuda4dnn::csl::View<int32_t> view, cuda4dnn::csl::ManagedPtr<int32_t>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
+            d2h_event.record(stream);
+            cuda4dnn::csl::StreamWaitOnEvent(d2h_stream, d2h_event);
+            cuda4dnn::csl::memcpy<int32_t>(reinterpret_cast<int32_t*>(mat.data), view.data(), view.size(), d2h_stream);
+        }
 
         template <> inline
-        void convert_H2D<half>(cuda4dnn::csl::Span<half> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+        void convert_D2H_background<int64_t, int64_t>(const cv::Mat& mat, cuda4dnn::csl::View<int64_t> view, cuda4dnn::csl::ManagedPtr<int64_t>& device_temp, const cuda4dnn::csl::Stream& stream, const cuda4dnn::csl::Stream& d2h_stream, cuda4dnn::csl::Event& d2h_event) {
+            d2h_event.record(stream);
+            cuda4dnn::csl::StreamWaitOnEvent(d2h_stream, d2h_event);
+            cuda4dnn::csl::memcpy<int64_t>(reinterpret_cast<int64_t*>(mat.data), view.data(), view.size(), d2h_stream);
+        }
+
+        template <class DEVICE_T, class HOST_T>
+        void convert_H2D(cuda4dnn::csl::Span<DEVICE_T> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream);
+
+        template <> inline
+        void convert_H2D<half, float>(cuda4dnn::csl::Span<half> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
             if (device_temp.size() < span.size())
                 device_temp.reset(span.size());
 
             auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), span.size());
@@ -286,15 +363,25 @@ namespace cv { namespace dnn {
         }
 
         template <> inline
-        void convert_H2D<float>(cuda4dnn::csl::Span<float> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+        void convert_H2D<float, float>(cuda4dnn::csl::Span<float> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
             cuda4dnn::csl::memcpy<float>(span.data(), reinterpret_cast<float*>(mat.data), span.size(), stream);
         }
 
+        template <> inline
+        void convert_H2D<int32_t, int32_t>(cuda4dnn::csl::Span<int32_t> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<int32_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            cuda4dnn::csl::memcpy<int32_t>(span.data(), reinterpret_cast<int32_t*>(mat.data), span.size(), stream);
+        }
+
+        template <> inline
+        void convert_H2D<int64_t, int64_t>(cuda4dnn::csl::Span<int64_t> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<int64_t>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            cuda4dnn::csl::memcpy<int64_t>(span.data(), reinterpret_cast<int64_t*>(mat.data), span.size(), stream);
+        }
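The device/host element-type pairings used by these helpers, summarized from the specializations above (a reading aid, not additional API):

        //  DEVICE_T   HOST_T     copy path
        //  half       float      converted through the float device_temp buffer
        //  float      float      direct memcpy
        //  int32_t    int32_t    direct memcpy
        //  int64_t    int64_t    direct memcpy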
 
     }} /* namespace cuda4dnn::detail */
 
-    template <class T, int TargetID>
+    template <class DEVICE_T, class HOST_T, int TargetID>
     class GenericCUDABackendWrapper final : public CUDABackendWrapper {
     public:
-        using value_type = T;
+        using value_type = DEVICE_T;
 
         using tensor_span_type = cuda4dnn::csl::TensorSpan<value_type>;
         using tensor_view_type = cuda4dnn::csl::TensorView<value_type>;
@@ -309,6 +396,7 @@ namespace cv { namespace dnn {
             : CUDABackendWrapper(TargetID)
         {
             shape = cv::dnn::shape(m);
+            hostMatDepth = m.depth();
             offset = 0;
 
             shared_block = std::make_shared<shared_block_type>();
@@ -324,7 +412,7 @@ namespace cv { namespace dnn {
                 /* we ignore the failure as this is just an optimization and not a requirement */
             }
 
-            shared_block->device = cuda4dnn::csl::ManagedPtr<T>(m.total());
+            shared_block->device = cuda4dnn::csl::ManagedPtr<DEVICE_T>(m.total());
         }
 
         GenericCUDABackendWrapper(const Ptr<BackendWrapper>& base_, const MatShape& shape_)
@@ -334,6 +422,7 @@ namespace cv { namespace dnn {
             CV_Assert(base);
 
             shape = shape_;
+            hostMatDepth = base_->getHostMatDepth();
             offset = 0;
 
             shared_block = base->shared_block;
@@ -377,9 +466,8 @@ namespace cv { namespace dnn {
                     auto& mat = shared_block->host;
                     CV_Assert(mat.isContinuous());
-                    CV_Assert(mat.type() == CV_32F);
 
-                    cuda4dnn::detail::convert_D2H<T>(mat, view, shared_block->device_temp, shared_block->stream);
+                    cuda4dnn::detail::convert_D2H<DEVICE_T, HOST_T>(mat, view, shared_block->device_temp, shared_block->stream);
                     shared_block->stream.synchronize();
                 } else if (shared_block->d2h_event && shared_block->d2h_event.busy()) {
                     /* wait for the background copy to finish */
@@ -401,7 +489,7 @@ namespace cv { namespace dnn {
                     if (!shared_block->d2h_event)
                         shared_block->d2h_event = cuda4dnn::csl::Event(true);
-                    cuda4dnn::detail::convert_D2H_background<T>(mat, view, shared_block->device_temp, shared_block->stream, shared_block->d2h_stream, shared_block->d2h_event);
+                    cuda4dnn::detail::convert_D2H_background<DEVICE_T, HOST_T>(mat, view, shared_block->device_temp, shared_block->stream, shared_block->d2h_stream, shared_block->d2h_event);
                     shared_block->d2h_event.record(shared_block->d2h_stream); // record position so that we can check status later
                 }
             }
@@ -422,9 +510,8 @@ namespace cv { namespace dnn {
                 auto& mat = shared_block->host;
                 CV_Assert(mat.isContinuous());
-                CV_Assert(mat.type() == CV_32F);
 
-                cuda4dnn::detail::convert_H2D<T>(span, mat, shared_block->device_temp, shared_block->stream);
+                cuda4dnn::detail::convert_H2D<DEVICE_T, HOST_T>(span, mat, shared_block->device_temp, shared_block->stream);
             }
         }
@@ -504,8 +591,8 @@ namespace cv { namespace dnn {
             cv::Mat host;
             cuda4dnn::csl::MemoryLockGuard memGuard; /* keeps host memory page-locked if possible */
-            cuda4dnn::csl::ManagedPtr<T> device;
-            cuda4dnn::csl::ManagedPtr<float> device_temp; /* use for conversions */
+            cuda4dnn::csl::ManagedPtr<DEVICE_T> device;
+            cuda4dnn::csl::ManagedPtr<HOST_T> device_temp; /* use for conversions */
             cuda4dnn::csl::Stream stream;
             cuda4dnn::csl::Event d2h_event;
@@ -515,12 +602,16 @@ namespace cv { namespace dnn {
         std::shared_ptr<shared_block_type> shared_block;
     };
 
-    using CUDABackendWrapperFP16 = GenericCUDABackendWrapper<half, DNN_TARGET_CUDA_FP16>;
-    using CUDABackendWrapperFP32 = GenericCUDABackendWrapper<float, DNN_TARGET_CUDA>;
+    using CUDABackendWrapperFP16 = GenericCUDABackendWrapper<half, float, DNN_TARGET_CUDA_FP16>;
+    using CUDABackendWrapperFP32 = GenericCUDABackendWrapper<float, float, DNN_TARGET_CUDA>;
+    using CUDABackendWrapperINT32 = GenericCUDABackendWrapper<int32_t, int32_t, DNN_TARGET_CUDA>;
+    using CUDABackendWrapperINT64 = GenericCUDABackendWrapper<int64_t, int64_t, DNN_TARGET_CUDA>;
 
     template <class T> struct GetCUDABackendWrapperType_ { };
     template <> struct GetCUDABackendWrapperType_<half> { typedef CUDABackendWrapperFP16 type; };
     template <> struct GetCUDABackendWrapperType_<float> { typedef CUDABackendWrapperFP32 type; };
+    template <> struct GetCUDABackendWrapperType_<int32_t> { typedef CUDABackendWrapperINT32 type; };
+    template <> struct GetCUDABackendWrapperType_<int64_t> { typedef CUDABackendWrapperINT64 type; };
 
     template <class T>
     using GetCUDABackendWrapperType = typename GetCUDABackendWrapperType_<T>::type;
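A small compile-time check of the extended trait, shown only as an illustration (assumes <type_traits> is available in this translation unit):

    static_assert(std::is_same<GetCUDABackendWrapperType<int32_t>, CUDABackendWrapperINT32>::value, "");
    static_assert(std::is_same<GetCUDABackendWrapperType<int64_t>, CUDABackendWrapperINT64>::value, "");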