diff --git a/modules/cudaarithm/src/arithm.cpp b/modules/cudaarithm/src/arithm.cpp
index 381580cff..30cf225e1 100644
--- a/modules/cudaarithm/src/arithm.cpp
+++ b/modules/cudaarithm/src/arithm.cpp
@@ -439,7 +439,9 @@ namespace
     class ConvolutionImpl : public Convolution
     {
     public:
-        explicit ConvolutionImpl(Size user_block_size_) : user_block_size(user_block_size_) {}
+        explicit ConvolutionImpl(Size user_block_size_) : user_block_size(user_block_size_), planR2C(0), planC2R(0) {}
+        // Releases any cuFFT plans still held by this object.
+        ~ConvolutionImpl();
 
         void convolve(InputArray image, InputArray templ, OutputArray result, bool ccorr = false, Stream& stream = Stream::Null());
 
@@ -452,6 +454,12 @@ namespace
         Size user_block_size;
         Size dft_size;
 
+        // cuFFT plans kept across calls and rebuilt only when dft_size changes;
+        // a handle value of 0 means no plan is currently held.
+        cufftHandle planR2C, planC2R;
+        // DFT size the held plans were created for.
+        Size plan_size;
+
         GpuMat image_spect, templ_spect, result_spect;
         GpuMat image_block, templ_block, result_data;
     };
@@ -491,6 +499,51 @@ namespace
         // Use maximum result matrix block size for the estimated DFT block size
         block_size.width = std::min(dft_size.width - templ_size.width + 1, result_size.width);
         block_size.height = std::min(dft_size.height - templ_size.height + 1, result_size.height);
+
+        // Rebuild the cuFFT plans only when the DFT size differs from the one
+        // the currently held plans were created for.
+        if (dft_size != plan_size)
+        {
+            // Invalidate the recorded size first, so a failure below cannot leave
+            // plan_size claiming that matching plans exist.
+            plan_size = Size();
+
+            // Clear each handle before destroying it: cufftSafeCall reports failures
+            // (presumably by throwing), and a stale non-zero handle would later be
+            // destroyed a second time.
+            if (planR2C != 0)
+            {
+                const cufftHandle stale = planR2C;
+                planR2C = 0;
+                cufftSafeCall( cufftDestroy(stale) );
+            }
+            if (planC2R != 0)
+            {
+                const cufftHandle stale = planC2R;
+                planC2R = 0;
+                cufftSafeCall( cufftDestroy(stale) );
+            }
+
+            // Create into locals so a member is only overwritten on success.
+            cufftHandle freshC2R = 0;
+            cufftSafeCall( cufftPlan2d(&freshC2R, dft_size.height, dft_size.width, CUFFT_C2R) );
+            planC2R = freshC2R;
+
+            cufftHandle freshR2C = 0;
+            cufftSafeCall( cufftPlan2d(&freshR2C, dft_size.height, dft_size.width, CUFFT_R2C) );
+            planR2C = freshR2C;
+
+            plan_size = dft_size;
+        }
+    }
+
+    // Releases any cuFFT plans still held. The cufftDestroy status is ignored
+    // on purpose: a destructor must not throw.
+    ConvolutionImpl::~ConvolutionImpl()
+    {
+        if (planR2C != 0)
+            cufftDestroy(planR2C);
+        if (planC2R != 0)
+            cufftDestroy(planC2R);
     }
 
     Size ConvolutionImpl::estimateBlockSize(Size result_size)
@@ -516,10 +570,6 @@ namespace
 
         cudaStream_t stream = StreamAccessor::getStream(_stream);
 
-        cufftHandle planR2C, planC2R;
-        cufftSafeCall( cufftPlan2d(&planC2R, dft_size.height, dft_size.width, CUFFT_C2R) );
-        cufftSafeCall( cufftPlan2d(&planR2C, dft_size.height, dft_size.width, CUFFT_R2C) );
-
         cufftSafeCall( cufftSetStream(planR2C, stream) );
         cufftSafeCall( cufftSetStream(planC2R, stream) );
 
@@ -559,9 +609,6 @@ namespace
             }
         }
 
-        cufftSafeCall( cufftDestroy(planR2C) );
-        cufftSafeCall( cufftDestroy(planC2R) );
-
         syncOutput(result, _result, _stream);
     }
 }