initial support of GPU LBP classifier: added new style xml format loading

pull/2/head
Marina Kolpakova 13 years ago
parent 02170a0a58
commit 1365e28a54
  1. 1
      cmake/OpenCVDetectCUDA.cmake
  2. 2
      modules/core/src/cuda/matrix_operations.cu
  3. 38
      modules/gpu/include/opencv2/gpu/gpu.hpp
  4. 8
      modules/gpu/src/brute_force_matcher.cpp
  5. 196
      modules/gpu/src/cascadeclassifier.cpp
  6. 6
      modules/gpu/src/imgproc.cpp
  7. 63
      modules/gpu/src/nvidia/NCVHaarObjectDetection.cu
  8. 61
      modules/gpu/src/nvidia/NPP_staging/NPP_staging.cu
  9. 26
      modules/gpu/src/nvidia/NPP_staging/NPP_staging.hpp
  10. 11
      modules/gpu/src/nvidia/core/NCVColorConversion.hpp
  11. 46
      modules/gpu/src/nvidia/core/NCVPixelOperations.hpp
  12. 8
      modules/gpu/src/nvidia/core/NCVPyramid.cu
  13. 20
      modules/gpu/src/nvidia/core/NCVRuntimeTemplates.hpp
  14. 14
      modules/gpu/test/nvidia/NCVAutoTestLister.hpp
  15. 10
      modules/gpu/test/nvidia/NCVTest.hpp
  16. 14
      modules/gpu/test/nvidia/TestHaarCascadeApplication.cpp
  17. 4
      modules/gpu/test/nvidia/TestHypothesesFilter.cpp
  18. 4
      modules/gpu/test/nvidia/TestResize.cpp
  19. 2
      modules/objdetect/perf/perf_cascadeclassifier.cpp
  20. 68
      modules/objdetect/src/cascadedetect.hpp
  21. 24
      modules/objdetect/src/resizeimg.cpp
  22. 12
      modules/objdetect/src/routine.cpp

@ -88,6 +88,7 @@ if(CUDA_FOUND)
if(APPLE)
set (CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -Xcompiler -fno-finite-math-only)
endif()
string(REPLACE "-Wsign-promo" "" CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
# we remove -ggdb3 flag as it leads to preprocessor errors when compiling CUDA files (CUDA 4.1)
set(CMAKE_CXX_FLAGS_DEBUG_ ${CMAKE_CXX_FLAGS_DEBUG})

@ -96,7 +96,7 @@ namespace cv { namespace gpu { namespace device
__constant__ ushort scalar_16u[4];
__constant__ short scalar_16s[4];
__constant__ int scalar_32s[4];
__constant__ float scalar_32f[4];
__constant__ float scalar_32f[4];
__constant__ double scalar_64f[4];
template <typename T> __device__ __forceinline__ T readScalar(int i);

@ -1422,6 +1422,44 @@ private:
CascadeClassifierImpl* impl;
};
// The cascade classifier class for object detection.
class CV_EXPORTS CascadeClassifier_GPU_LBP
{
public:
enum stage { BOOST = 0 };
enum feature { LBP = 0 };
CascadeClassifier_GPU_LBP();
~CascadeClassifier_GPU_LBP();
bool empty() const;
bool load(const std::string& filename);
void release();
int detectMultiScale(const GpuMat& image, GpuMat& objectsBuf, double scaleFactor = 1.2, int minNeighbors = 4, Size minSize = Size());
bool findLargestObject;
bool visualizeInPlace;
Size getClassifierSize() const;
private:
bool read(const FileNode &root);
static const stage stageType = BOOST;
static const feature feature = LBP;
cv::Size NxM;
bool isStumps;
int ncategories;
struct Stage;
Stage* stages;
struct DTree;
// DTree* classifiers;
struct DTreeNode;
// DTreeNode* nodes;
};
////////////////////////////////// SURF //////////////////////////////////////////
class CV_EXPORTS SURF_GPU

@ -272,14 +272,14 @@ void cv::gpu::BFMatcher_GPU::matchConvert(const Mat& trainIdx, const Mat& distan
const float* distance_ptr = distance.ptr<float>();
for (int queryIdx = 0; queryIdx < nQuery; ++queryIdx, ++trainIdx_ptr, ++distance_ptr)
{
int trainIdx = *trainIdx_ptr;
int train_idx = *trainIdx_ptr;
if (trainIdx == -1)
if (train_idx == -1)
continue;
float distance = *distance_ptr;
float distance_local = *distance_ptr;
DMatch m(queryIdx, trainIdx, 0, distance);
DMatch m(queryIdx, train_idx, 0, distance_local);
matches.push_back(m);
}

@ -41,16 +41,40 @@
//M*/
#include "precomp.hpp"
#include <vector>
using namespace cv;
using namespace cv::gpu;
using namespace std;
#if !defined (HAVE_CUDA)
struct cv::gpu::CascadeClassifier_GPU_LBP::Stage
{
int first;
int ntrees;
float threshold;
Stage(int f = 0, int n = 0, float t = 0.f) : first(f), ntrees(n), threshold(t) {}
};
struct cv::gpu::CascadeClassifier_GPU_LBP::DTree
{
int nodeCount;
DTree(int n = 0) : nodeCount(n) {}
};
cv::gpu::CascadeClassifier_GPU::CascadeClassifier_GPU() { throw_nogpu(); }
struct cv::gpu::CascadeClassifier_GPU_LBP::DTreeNode
{
int featureIdx;
//float threshold; // for ordered features only
int left;
int right;
DTreeNode(int f = 0, int l = 0, int r = 0) : featureIdx(f), left(l), right(r) {}
};
#if !defined (HAVE_CUDA)
// ============ old fashioned haar cascade ==============================================//
cv::gpu::CascadeClassifier_GPU::CascadeClassifier_GPU() { throw_nogpu(); }
cv::gpu::CascadeClassifier_GPU::CascadeClassifier_GPU(const string&) { throw_nogpu(); }
cv::gpu::CascadeClassifier_GPU::~CascadeClassifier_GPU() { throw_nogpu(); }
cv::gpu::CascadeClassifier_GPU::~CascadeClassifier_GPU() { throw_nogpu(); }
bool cv::gpu::CascadeClassifier_GPU::empty() const { throw_nogpu(); return true; }
bool cv::gpu::CascadeClassifier_GPU::load(const string&) { throw_nogpu(); return true; }
@ -58,8 +82,174 @@ Size cv::gpu::CascadeClassifier_GPU::getClassifierSize() const { throw_nogpu();
int cv::gpu::CascadeClassifier_GPU::detectMultiScale( const GpuMat& , GpuMat& , double , int , Size) { throw_nogpu(); return 0; }
// ============ LBP cascade ==============================================//
cv::gpu::CascadeClassifier_GPU_LBP::CascadeClassifier_GPU_LBP() { throw_nogpu(); }
cv::gpu::CascadeClassifier_GPU_LBP::~CascadeClassifier_GPU_LBP() { throw_nogpu(); }
bool cv::gpu::CascadeClassifier_GPU_LBP::empty() const { throw_nogpu(); return true; }
bool cv::gpu::CascadeClassifier_GPU_LBP::load(const string&) { throw_nogpu(); return true; }
Size cv::gpu::CascadeClassifier_GPU_LBP::getClassifierSize() const { throw_nogpu(); return Size(); }
int cv::gpu::CascadeClassifier_GPU_LBP::detectMultiScale( const GpuMat& , GpuMat& , double , int , Size) { throw_nogpu(); return 0; }
#else
cv::gpu::CascadeClassifier_GPU_LBP::CascadeClassifier_GPU_LBP()
{
}
cv::gpu::CascadeClassifier_GPU_LBP::~CascadeClassifier_GPU_LBP()
{
}
bool cv::gpu::CascadeClassifier_GPU_LBP::empty() const { throw_nogpu(); return true; }
bool cv::gpu::CascadeClassifier_GPU_LBP::load(const string& classifierAsXml)
{
FileStorage fs(classifierAsXml, FileStorage::READ);
if (!fs.isOpened())
return false;
if (read(fs.getFirstTopLevelNode()))
return true;
return false;
}
#define GPU_CC_STAGE_TYPE "stageType"
#define GPU_CC_FEATURE_TYPE "featureType"
#define GPU_CC_BOOST "BOOST"
#define GPU_CC_LBP "LBP"
#define GPU_CC_MAX_CAT_COUNT "maxCatCount"
#define GPU_CC_HEIGHT "height"
#define GPU_CC_WIDTH "width"
#define GPU_CC_STAGE_PARAMS "stageParams"
#define GPU_CC_MAX_DEPTH "maxDepth"
#define GPU_CC_FEATURE_PARAMS "featureParams"
#define GPU_CC_STAGES "stages"
#define GPU_CC_STAGE_THRESHOLD "stageThreshold"
#define GPU_THRESHOLD_EPS 1e-5f
#define GPU_CC_WEAK_CLASSIFIERS "weakClassifiers"
#define GPU_CC_INTERNAL_NODES "internalNodes"
#define GPU_CC_LEAF_VALUES "leafValues"
bool CascadeClassifier_GPU_LBP::read(const FileNode &root)
{
string stageTypeStr = (string)root[GPU_CC_STAGE_TYPE];
CV_Assert(stageTypeStr == GPU_CC_BOOST);
string featureTypeStr = (string)root[GPU_CC_FEATURE_TYPE];
CV_Assert(featureTypeStr == GPU_CC_LBP);
NxM.width = (int)root[GPU_CC_WIDTH];
NxM.height = (int)root[GPU_CC_HEIGHT];
CV_Assert( NxM.height > 0 && NxM.width > 0 );
isStumps = ((int)(root[GPU_CC_STAGE_PARAMS][GPU_CC_MAX_DEPTH]) == 1) ? true : false;
// features
FileNode fn = root[GPU_CC_FEATURE_PARAMS];
if (fn.empty())
return false;
ncategories = fn[GPU_CC_MAX_CAT_COUNT];
int subsetSize = (ncategories + 31)/32, nodeStep = 3 + ( ncategories > 0 ? subsetSize : 1 );// ?
fn = root[GPU_CC_STAGES];
if (fn.empty())
return false;
delete[] stages;
// delete[] classifiers;
// delete[] nodes;
stages = new Stage[fn.size()];
std::vector<DTree> cl_trees;
std::vector<DTreeNode> cl_nodes;
std::vector<float> cl_leaves;
std::vector<int> subsets;
FileNodeIterator it = fn.begin(), it_end = fn.end();
size_t s_it = 0;
for (size_t si = 0; it != it_end; si++, ++it )
{
FileNode fns = *it;
fns = fns[GPU_CC_WEAK_CLASSIFIERS];
if (fns.empty())
return false;
stages[s_it++] = Stage((float)fns[GPU_CC_STAGE_THRESHOLD] - GPU_THRESHOLD_EPS,
(int)cl_trees.size(), (int)fns.size());
cl_trees.reserve(stages[si].first + stages[si].ntrees);
// weak trees
FileNodeIterator it1 = fns.begin(), it1_end = fns.end();
for ( ; it1 != it1_end; ++it1 )
{
FileNode fnw = *it1;
FileNode internalNodes = fnw[GPU_CC_INTERNAL_NODES];
FileNode leafValues = fnw[GPU_CC_LEAF_VALUES];
if ( internalNodes.empty() || leafValues.empty() )
return false;
DTree tree((int)internalNodes.size()/nodeStep );
cl_trees.push_back(tree);
cl_nodes.reserve(cl_nodes.size() + tree.nodeCount);
cl_leaves.reserve(cl_leaves.size() + leafValues.size());
if( subsetSize > 0 )
subsets.reserve(subsets.size() + tree.nodeCount * subsetSize);
// nodes
FileNodeIterator iIt = internalNodes.begin(), iEnd = internalNodes.end();
for( ; iIt != iEnd; )
{
DTreeNode node((int)*(iIt++), (int)*(iIt++), (int)*(iIt++));
cl_nodes.push_back(node);
if ( subsetSize > 0 )
{
for( int j = 0; j < subsetSize; j++, ++iIt )
subsets.push_back((int)*iIt); //????
}
}
iIt = leafValues.begin(), iEnd = leafValues.end();
// leaves
for( ; iIt != iEnd; ++iIt )
cl_leaves.push_back((float)*iIt);
}
}
return true;
}
#undef GPU_CC_STAGE_TYPE
#undef GPU_CC_BOOST
#undef GPU_CC_FEATURE_TYPE
#undef GPU_CC_LBP
#undef GPU_CC_MAX_CAT_COUNT
#undef GPU_CC_HEIGHT
#undef GPU_CC_WIDTH
#undef GPU_CC_STAGE_PARAMS
#undef GPU_CC_MAX_DEPTH
#undef GPU_CC_FEATURE_PARAMS
#undef GPU_CC_STAGES
#undef GPU_CC_STAGE_THRESHOLD
#undef GPU_THRESHOLD_EPS
#undef GPU_CC_WEAK_CLASSIFIERS
#undef GPU_CC_INTERNAL_NODES
#undef GPU_CC_LEAF_VALUES
Size cv::gpu::CascadeClassifier_GPU_LBP::getClassifierSize() const { throw_nogpu(); return Size(); }
int cv::gpu::CascadeClassifier_GPU_LBP::detectMultiScale( const GpuMat& , GpuMat& , double , int , Size) { throw_nogpu(); return 0; }
// ============ old fashioned haar cascade ==============================================//
struct cv::gpu::CascadeClassifier_GPU::CascadeClassifierImpl
{
CascadeClassifierImpl(const string& filename) : lastAllocatedFrameSize(-1, -1)

@ -357,6 +357,7 @@ namespace cv { namespace gpu { namespace device
void cv::gpu::buildWarpPlaneMaps(Size src_size, Rect dst_roi, const Mat &K, const Mat& R, const Mat &T,
float scale, GpuMat& map_x, GpuMat& map_y, Stream& stream)
{
(void)src_size;
using namespace ::cv::gpu::device::imgproc;
CV_Assert(K.size() == Size(3,3) && K.type() == CV_32F);
@ -390,6 +391,7 @@ namespace cv { namespace gpu { namespace device
void cv::gpu::buildWarpCylindricalMaps(Size src_size, Rect dst_roi, const Mat &K, const Mat& R, float scale,
GpuMat& map_x, GpuMat& map_y, Stream& stream)
{
(void)src_size;
using namespace ::cv::gpu::device::imgproc;
CV_Assert(K.size() == Size(3,3) && K.type() == CV_32F);
@ -422,6 +424,7 @@ namespace cv { namespace gpu { namespace device
void cv::gpu::buildWarpSphericalMaps(Size src_size, Rect dst_roi, const Mat &K, const Mat& R, float scale,
GpuMat& map_x, GpuMat& map_y, Stream& stream)
{
(void)src_size;
using namespace ::cv::gpu::device::imgproc;
CV_Assert(K.size() == Size(3,3) && K.type() == CV_32F);
@ -466,6 +469,7 @@ namespace
static void call(const GpuMat& src, GpuMat& dst, Size dsize, double angle, double xShift, double yShift, int interpolation, cudaStream_t stream)
{
(void)dsize;
static const int npp_inter[] = {NPPI_INTER_NN, NPPI_INTER_LINEAR, NPPI_INTER_CUBIC};
NppStreamHandler h(stream);
@ -1139,6 +1143,7 @@ namespace cv { namespace gpu { namespace device
void cv::gpu::mulSpectrums(const GpuMat& a, const GpuMat& b, GpuMat& c, int flags, bool conjB, Stream& stream)
{
(void)flags;
using namespace ::cv::gpu::device::imgproc;
typedef void (*Caller)(const PtrStep<cufftComplex>, const PtrStep<cufftComplex>, DevMem2D_<cufftComplex>, cudaStream_t stream);
@ -1169,6 +1174,7 @@ namespace cv { namespace gpu { namespace device
void cv::gpu::mulAndScaleSpectrums(const GpuMat& a, const GpuMat& b, GpuMat& c, int flags, float scale, bool conjB, Stream& stream)
{
(void)flags;
using namespace ::cv::gpu::device::imgproc;
typedef void (*Caller)(const PtrStep<cufftComplex>, const PtrStep<cufftComplex>, float scale, DevMem2D_<cufftComplex>, cudaStream_t stream);

@ -232,7 +232,7 @@ __device__ Ncv32u d_outMaskPosition;
__device__ void compactBlockWriteOutAnchorParallel(Ncv32u threadPassFlag, Ncv32u threadElem, Ncv32u *vectorOut)
{
#if __CUDA_ARCH__ && __CUDA_ARCH__ >= 110
__shared__ Ncv32u shmem[NUM_THREADS_ANCHORSPARALLEL * 2];
__shared__ Ncv32u numPassed;
__shared__ Ncv32u outMaskOffset;
@ -927,7 +927,7 @@ Ncv32u getStageNumWithNotLessThanNclassifiers(Ncv32u N, HaarClassifierCascadeDes
}
NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImage,
NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &integral,
NCVMatrix<Ncv32f> &d_weights,
NCVMatrixAlloc<Ncv32u> &d_pixelMask,
Ncv32u &numDetections,
@ -945,32 +945,41 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
cudaDeviceProp &devProp,
cudaStream_t cuStream)
{
ncvAssertReturn(d_integralImage.memType() == d_weights.memType() &&
d_integralImage.memType() == d_pixelMask.memType() &&
d_integralImage.memType() == gpuAllocator.memType() &&
(d_integralImage.memType() == NCVMemoryTypeDevice ||
d_integralImage.memType() == NCVMemoryTypeNone), NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(integral.memType() == d_weights.memType()&&
integral.memType() == d_pixelMask.memType() &&
integral.memType() == gpuAllocator.memType() &&
(integral.memType() == NCVMemoryTypeDevice ||
integral.memType() == NCVMemoryTypeNone), NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(d_HaarStages.memType() == d_HaarNodes.memType() &&
d_HaarStages.memType() == d_HaarFeatures.memType() &&
(d_HaarStages.memType() == NCVMemoryTypeDevice ||
d_HaarStages.memType() == NCVMemoryTypeNone), NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(h_HaarStages.memType() != NCVMemoryTypeDevice, NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(gpuAllocator.isInitialized() && cpuAllocator.isInitialized(), NCV_ALLOCATOR_NOT_INITIALIZED);
ncvAssertReturn((d_integralImage.ptr() != NULL && d_weights.ptr() != NULL && d_pixelMask.ptr() != NULL &&
ncvAssertReturn((integral.ptr() != NULL && d_weights.ptr() != NULL && d_pixelMask.ptr() != NULL &&
h_HaarStages.ptr() != NULL && d_HaarStages.ptr() != NULL && d_HaarNodes.ptr() != NULL &&
d_HaarFeatures.ptr() != NULL) || gpuAllocator.isCounting(), NCV_NULL_PTR);
ncvAssertReturn(anchorsRoi.width > 0 && anchorsRoi.height > 0 &&
d_pixelMask.width() >= anchorsRoi.width && d_pixelMask.height() >= anchorsRoi.height &&
d_weights.width() >= anchorsRoi.width && d_weights.height() >= anchorsRoi.height &&
d_integralImage.width() >= anchorsRoi.width + haar.ClassifierSize.width &&
d_integralImage.height() >= anchorsRoi.height + haar.ClassifierSize.height, NCV_DIMENSIONS_INVALID);
integral.width() >= anchorsRoi.width + haar.ClassifierSize.width &&
integral.height() >= anchorsRoi.height + haar.ClassifierSize.height, NCV_DIMENSIONS_INVALID);
ncvAssertReturn(scaleArea > 0, NCV_INVALID_SCALE);
ncvAssertReturn(d_HaarStages.length() >= haar.NumStages &&
d_HaarNodes.length() >= haar.NumClassifierTotalNodes &&
d_HaarFeatures.length() >= haar.NumFeatures &&
d_HaarStages.length() == h_HaarStages.length() &&
haar.NumClassifierRootNodes <= haar.NumClassifierTotalNodes, NCV_DIMENSIONS_INVALID);
ncvAssertReturn(haar.bNeedsTiltedII == false || gpuAllocator.isCounting(), NCV_NOIMPL_HAAR_TILTED_FEATURES);
ncvAssertReturn(pixelStep == 1 || pixelStep == 2, NCV_HAAR_INVALID_PIXEL_STEP);
NCV_SET_SKIP_COND(gpuAllocator.isCounting());
@ -979,7 +988,7 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
NCVStatus ncvStat;
NCVMatrixAlloc<Ncv32u> h_integralImage(cpuAllocator, d_integralImage.width, d_integralImage.height, d_integralImage.pitch);
NCVMatrixAlloc<Ncv32u> h_integralImage(cpuAllocator, integral.width, integral.height, integral.pitch);
ncvAssertReturn(h_integralImage.isMemAllocated(), NCV_ALLOCATOR_BAD_ALLOC);
NCVMatrixAlloc<Ncv32f> h_weights(cpuAllocator, d_weights.width, d_weights.height, d_weights.pitch);
ncvAssertReturn(h_weights.isMemAllocated(), NCV_ALLOCATOR_BAD_ALLOC);
@ -997,7 +1006,7 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
ncvStat = d_pixelMask.copySolid(h_pixelMask, 0);
ncvAssertReturnNcvStat(ncvStat);
ncvStat = d_integralImage.copySolid(h_integralImage, 0);
ncvStat = integral.copySolid(h_integralImage, 0);
ncvAssertReturnNcvStat(ncvStat);
ncvStat = d_weights.copySolid(h_weights, 0);
ncvAssertReturnNcvStat(ncvStat);
@ -1071,8 +1080,8 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
cfdTexIImage = cudaCreateChannelDesc<Ncv32u>();
size_t alignmentOffset;
ncvAssertCUDAReturn(cudaBindTexture(&alignmentOffset, texIImage, d_integralImage.ptr(), cfdTexIImage,
(anchorsRoi.height + haar.ClassifierSize.height) * d_integralImage.pitch()), NCV_CUDA_ERROR);
ncvAssertCUDAReturn(cudaBindTexture(&alignmentOffset, texIImage, integral.ptr(), cfdTexIImage,
(anchorsRoi.height + haar.ClassifierSize.height) * integral.pitch()), NCV_CUDA_ERROR);
ncvAssertReturn(alignmentOffset==0, NCV_TEXTURE_BIND_ERROR);
}
@ -1189,7 +1198,7 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
grid1,
block1,
cuStream,
d_integralImage.ptr(), d_integralImage.stride(),
integral.ptr(), integral.stride(),
d_weights.ptr(), d_weights.stride(),
d_HaarFeatures.ptr(), d_HaarNodes.ptr(), d_HaarStages.ptr(),
d_ptrNowData->ptr(),
@ -1259,7 +1268,7 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
grid2,
block2,
cuStream,
d_integralImage.ptr(), d_integralImage.stride(),
integral.ptr(), integral.stride(),
d_weights.ptr(), d_weights.stride(),
d_HaarFeatures.ptr(), d_HaarNodes.ptr(), d_HaarStages.ptr(),
d_ptrNowData->ptr(),
@ -1320,7 +1329,7 @@ NCVStatus ncvApplyHaarClassifierCascade_device(NCVMatrix<Ncv32u> &d_integralImag
grid3,
block3,
cuStream,
d_integralImage.ptr(), d_integralImage.stride(),
integral.ptr(), integral.stride(),
d_weights.ptr(), d_weights.stride(),
d_HaarFeatures.ptr(), d_HaarNodes.ptr(), d_HaarStages.ptr(),
d_ptrNowData->ptr(),
@ -1455,10 +1464,14 @@ NCVStatus ncvGrowDetectionsVector_device(NCVVector<Ncv32u> &pixelMask,
cudaStream_t cuStream)
{
ncvAssertReturn(pixelMask.ptr() != NULL && hypotheses.ptr() != NULL, NCV_NULL_PTR);
ncvAssertReturn(pixelMask.memType() == hypotheses.memType() &&
pixelMask.memType() == NCVMemoryTypeDevice, NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(rectWidth > 0 && rectHeight > 0 && curScale > 0, NCV_INVALID_ROI);
ncvAssertReturn(curScale > 0, NCV_INVALID_SCALE);
ncvAssertReturn(totalMaxDetections <= hypotheses.length() &&
numPixelMaskDetections <= pixelMask.length() &&
totalMaxDetections <= totalMaxDetections, NCV_INCONSISTENT_INPUT);
@ -1527,12 +1540,16 @@ NCVStatus ncvDetectObjectsMultiScale_device(NCVMatrix<Ncv8u> &d_srcImg,
d_srcImg.memType() == gpuAllocator.memType() &&
(d_srcImg.memType() == NCVMemoryTypeDevice ||
d_srcImg.memType() == NCVMemoryTypeNone), NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(d_HaarStages.memType() == d_HaarNodes.memType() &&
d_HaarStages.memType() == d_HaarFeatures.memType() &&
(d_HaarStages.memType() == NCVMemoryTypeDevice ||
d_HaarStages.memType() == NCVMemoryTypeNone), NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(h_HaarStages.memType() != NCVMemoryTypeDevice, NCV_MEM_RESIDENCE_ERROR);
ncvAssertReturn(gpuAllocator.isInitialized() && cpuAllocator.isInitialized(), NCV_ALLOCATOR_NOT_INITIALIZED);
ncvAssertReturn((d_srcImg.ptr() != NULL && d_dstRects.ptr() != NULL &&
h_HaarStages.ptr() != NULL && d_HaarStages.ptr() != NULL && d_HaarNodes.ptr() != NULL &&
d_HaarFeatures.ptr() != NULL) || gpuAllocator.isCounting(), NCV_NULL_PTR);
@ -1540,13 +1557,17 @@ NCVStatus ncvDetectObjectsMultiScale_device(NCVMatrix<Ncv8u> &d_srcImg,
d_srcImg.width() >= srcRoi.width && d_srcImg.height() >= srcRoi.height &&
srcRoi.width >= minObjSize.width && srcRoi.height >= minObjSize.height &&
d_dstRects.length() >= 1, NCV_DIMENSIONS_INVALID);
ncvAssertReturn(scaleStep > 1.0f, NCV_INVALID_SCALE);
ncvAssertReturn(d_HaarStages.length() >= haar.NumStages &&
d_HaarNodes.length() >= haar.NumClassifierTotalNodes &&
d_HaarFeatures.length() >= haar.NumFeatures &&
d_HaarStages.length() == h_HaarStages.length() &&
haar.NumClassifierRootNodes <= haar.NumClassifierTotalNodes, NCV_DIMENSIONS_INVALID);
ncvAssertReturn(haar.bNeedsTiltedII == false, NCV_NOIMPL_HAAR_TILTED_FEATURES);
ncvAssertReturn(pixelStep == 1 || pixelStep == 2, NCV_HAAR_INVALID_PIXEL_STEP);
//TODO: set NPP active stream to cuStream
@ -1557,8 +1578,8 @@ NCVStatus ncvDetectObjectsMultiScale_device(NCVMatrix<Ncv8u> &d_srcImg,
Ncv32u integralWidth = d_srcImg.width() + 1;
Ncv32u integralHeight = d_srcImg.height() + 1;
NCVMatrixAlloc<Ncv32u> d_integralImage(gpuAllocator, integralWidth, integralHeight);
ncvAssertReturn(d_integralImage.isMemAllocated(), NCV_ALLOCATOR_BAD_ALLOC);
NCVMatrixAlloc<Ncv32u> integral(gpuAllocator, integralWidth, integralHeight);
ncvAssertReturn(integral.isMemAllocated(), NCV_ALLOCATOR_BAD_ALLOC);
NCVMatrixAlloc<Ncv64u> d_sqIntegralImage(gpuAllocator, integralWidth, integralHeight);
ncvAssertReturn(d_sqIntegralImage.isMemAllocated(), NCV_ALLOCATOR_BAD_ALLOC);
@ -1589,7 +1610,7 @@ NCVStatus ncvDetectObjectsMultiScale_device(NCVMatrix<Ncv8u> &d_srcImg,
NCV_SKIP_COND_BEGIN
nppStat = nppiStIntegral_8u32u_C1R(d_srcImg.ptr(), d_srcImg.pitch(),
d_integralImage.ptr(), d_integralImage.pitch(),
integral.ptr(), integral.pitch(),
NcvSize32u(d_srcImg.width(), d_srcImg.height()),
d_tmpIIbuf.ptr(), szTmpBufIntegral, devProp);
ncvAssertReturnNcvStat(nppStat);
@ -1676,7 +1697,7 @@ NCVStatus ncvDetectObjectsMultiScale_device(NCVMatrix<Ncv8u> &d_srcImg,
NCV_SKIP_COND_BEGIN
nppStat = nppiStDecimate_32u_C1R(
d_integralImage.ptr(), d_integralImage.pitch(),
integral.ptr(), integral.pitch(),
d_scaledIntegralImage.ptr(), d_scaledIntegralImage.pitch(),
srcIIRoi, scale, true);
ncvAssertReturnNcvStat(nppStat);

@ -1,7 +1,7 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
@ -95,11 +95,6 @@ inline __device__ T warpScanInclusive(T idata, volatile T *s_Data)
pos += K_WARP_SIZE;
s_Data[pos] = idata;
//for(Ncv32u offset = 1; offset < K_WARP_SIZE; offset <<= 1)
//{
// s_Data[pos] += s_Data[pos - offset];
//}
s_Data[pos] += s_Data[pos - 1];
s_Data[pos] += s_Data[pos - 2];
s_Data[pos] += s_Data[pos - 4];
@ -315,7 +310,7 @@ NCVStatus scanRowsWrapperDevice(T_in *d_src, Ncv32u srcStride,
<T_in, T_out, tbDoSqr>
<<<roi.height, NUM_SCAN_THREADS, 0, nppStGetActiveCUDAstream()>>>
(d_src, (Ncv32u)alignmentOffset, roi.width, srcStride, d_dst, dstStride);
ncvAssertCUDALastErrorReturn(NPPST_CUDA_KERNEL_EXECUTION_ERROR);
return NPPST_SUCCESS;
@ -1447,14 +1442,14 @@ NCVStatus compactVector_32u_device(Ncv32u *d_src, Ncv32u srcLen,
//adjust hierarchical partial sums
for (Ncv32s i=(Ncv32s)partSumNums.size()-3; i>=0; i--)
{
dim3 grid(partSumNums[i+1]);
if (grid.x > 65535)
dim3 grid_local(partSumNums[i+1]);
if (grid_local.x > 65535)
{
grid.y = (grid.x + 65534) / 65535;
grid.x = 65535;
grid_local.y = (grid_local.x + 65534) / 65535;
grid_local.x = 65535;
}
removePass2Adjust
<<<grid, block, 0, nppStGetActiveCUDAstream()>>>
<<<grid_local, block, 0, nppStGetActiveCUDAstream()>>>
(d_hierSums.ptr() + partSumOffsets[i], partSumNums[i],
d_hierSums.ptr() + partSumOffsets[i+1]);
@ -1463,10 +1458,10 @@ NCVStatus compactVector_32u_device(Ncv32u *d_src, Ncv32u srcLen,
}
else
{
dim3 grid(partSumNums[1]);
dim3 grid_local(partSumNums[1]);
removePass1Scan
<true, false>
<<<grid, block, 0, nppStGetActiveCUDAstream()>>>
<<<grid_local, block, 0, nppStGetActiveCUDAstream()>>>
(d_src, srcLen,
d_hierSums.ptr(),
NULL, elemRemove);
@ -1651,7 +1646,7 @@ __forceinline__ __device__ float getValueMirrorColumn(const int offset,
__global__ void FilterRowBorderMirror_32f_C1R(Ncv32u srcStep,
Ncv32f *pDst,
Ncv32f *pDst,
NcvSize32u dstSize,
Ncv32u dstStep,
NcvRect32u roi,
@ -1677,7 +1672,7 @@ __global__ void FilterRowBorderMirror_32f_C1R(Ncv32u srcStep,
float sum = 0.0f;
for (int m = 0; m < nKernelSize; ++m)
{
sum += getValueMirrorRow (rowOffset, ix + m - p, roi.width)
sum += getValueMirrorRow (rowOffset, ix + m - p, roi.width)
* tex1Dfetch (texKernel, m);
}
@ -1709,7 +1704,7 @@ __global__ void FilterColumnBorderMirror_32f_C1R(Ncv32u srcStep,
float sum = 0.0f;
for (int m = 0; m < nKernelSize; ++m)
{
sum += getValueMirrorColumn (offset, srcStep, iy + m - p, roi.height)
sum += getValueMirrorColumn (offset, srcStep, iy + m - p, roi.height)
* tex1Dfetch (texKernel, m);
}
@ -1879,7 +1874,7 @@ texture<float, 2, cudaReadModeElementType> tex_src0;
__global__ void BlendFramesKernel(const float *u, const float *v, // forward flow
const float *ur, const float *vr, // backward flow
const float *o0, const float *o1, // coverage masks
int w, int h, int s,
int w, int h, int s,
float theta, float *out)
{
const int ix = threadIdx.x + blockDim.x * blockIdx.x;
@ -1903,7 +1898,7 @@ __global__ void BlendFramesKernel(const float *u, const float *v, // forward f
if (b0 && b1)
{
// pixel is visible on both frames
out[pos] = tex2D(tex_src0, x - _u * theta, y - _v * theta) * (1.0f - theta) +
out[pos] = tex2D(tex_src0, x - _u * theta, y - _v * theta) * (1.0f - theta) +
tex2D(tex_src1, x + _u * (1.0f - theta), y + _v * (1.0f - theta)) * theta;
}
else if (b0)
@ -2004,8 +1999,8 @@ NCVStatus nppiStInterpolateFrames(const NppStInterpolationState *pState)
Ncv32f *bwdV = pState->ppBuffers[5]; // backward v
// warp flow
ncvAssertReturnNcvStat (
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pFU,
pState->size,
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pFU,
pState->size,
pState->nStep,
pState->pFU,
pState->pFV,
@ -2014,8 +2009,8 @@ NCVStatus nppiStInterpolateFrames(const NppStInterpolationState *pState)
pState->pos,
fwdU) );
ncvAssertReturnNcvStat (
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pFV,
pState->size,
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pFV,
pState->size,
pState->nStep,
pState->pFU,
pState->pFV,
@ -2025,8 +2020,8 @@ NCVStatus nppiStInterpolateFrames(const NppStInterpolationState *pState)
fwdV) );
// warp backward flow
ncvAssertReturnNcvStat (
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pBU,
pState->size,
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pBU,
pState->size,
pState->nStep,
pState->pBU,
pState->pBV,
@ -2035,8 +2030,8 @@ NCVStatus nppiStInterpolateFrames(const NppStInterpolationState *pState)
1.0f - pState->pos,
bwdU) );
ncvAssertReturnNcvStat (
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pBV,
pState->size,
nppiStVectorWarp_PSF2x2_32f_C1 (pState->pBV,
pState->size,
pState->nStep,
pState->pBU,
pState->pBV,
@ -2252,7 +2247,7 @@ NCVStatus nppiStVectorWarp_PSF1x1_32f_C1(const Ncv32f *pSrc,
Ncv32f timeScale,
Ncv32f *pDst)
{
ncvAssertReturn (pSrc != NULL &&
ncvAssertReturn (pSrc != NULL &&
pU != NULL &&
pV != NULL &&
pDst != NULL, NPPST_NULL_POINTER_ERROR);
@ -2286,7 +2281,7 @@ NCVStatus nppiStVectorWarp_PSF2x2_32f_C1(const Ncv32f *pSrc,
Ncv32f timeScale,
Ncv32f *pDst)
{
ncvAssertReturn (pSrc != NULL &&
ncvAssertReturn (pSrc != NULL &&
pU != NULL &&
pV != NULL &&
pDst != NULL &&
@ -2375,7 +2370,7 @@ __global__ void resizeSuperSample_32f(NcvSize32u srcSize,
}
float rw = (float) srcROI.width;
float rh = (float) srcROI.height;
float rh = (float) srcROI.height;
// source position
float x = scaleX * (float) ix;
@ -2529,7 +2524,7 @@ NCVStatus nppiStResize_32f_C1R(const Ncv32f *pSrc,
ncvAssertReturn (pSrc != NULL && pDst != NULL, NPPST_NULL_POINTER_ERROR);
ncvAssertReturn (xFactor != 0.0 && yFactor != 0.0, NPPST_INVALID_SCALE);
ncvAssertReturn (nSrcStep >= sizeof (Ncv32f) * (Ncv32u) srcSize.width &&
ncvAssertReturn (nSrcStep >= sizeof (Ncv32f) * (Ncv32u) srcSize.width &&
nDstStep >= sizeof (Ncv32f) * (Ncv32f) dstSize.width,
NPPST_INVALID_STEP);
@ -2547,7 +2542,7 @@ NCVStatus nppiStResize_32f_C1R(const Ncv32f *pSrc,
dim3 gridSize ((dstROI.width + ctaSize.x - 1) / ctaSize.x,
(dstROI.height + ctaSize.y - 1) / ctaSize.y);
resizeSuperSample_32f <<<gridSize, ctaSize, 0, nppStGetActiveCUDAstream ()>>>
resizeSuperSample_32f <<<gridSize, ctaSize, 0, nppStGetActiveCUDAstream ()>>>
(srcSize, srcStep, srcROI, pDst, dstSize, dstStep, dstROI, 1.0f / xFactor, 1.0f / yFactor);
}
else if (interpolation == nppStBicubic)

@ -1,7 +1,7 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
@ -132,7 +132,7 @@ enum NppStInterpMode
/** Size of a buffer required for interpolation.
*
*
* Requires several such buffers. See \see NppStInterpolationState.
*
* \param srcSize [IN] Frame size (both frames must be of the same size)
@ -177,17 +177,17 @@ NCVStatus nppiStInterpolateFrames(const NppStInterpolationState *pState);
* \return NCV status code
*/
NCV_EXPORTS
NCVStatus nppiStFilterRowBorder_32f_C1R(const Ncv32f *pSrc,
NcvSize32u srcSize,
NCVStatus nppiStFilterRowBorder_32f_C1R(const Ncv32f *pSrc,
NcvSize32u srcSize,
Ncv32u nSrcStep,
Ncv32f *pDst,
NcvSize32u dstSize,
Ncv32f *pDst,
NcvSize32u dstSize,
Ncv32u nDstStep,
NcvRect32u oROI,
NcvRect32u oROI,
NppStBorderType borderType,
const Ncv32f *pKernel,
const Ncv32f *pKernel,
Ncv32s nKernelSize,
Ncv32s nAnchor,
Ncv32s nAnchor,
Ncv32f multiplier);
@ -225,14 +225,14 @@ NCVStatus nppiStFilterColumnBorder_32f_C1R(const Ncv32f *pSrc,
/** Size of buffer required for vector image warping.
*
*
* \param srcSize [IN] Source image size
* \param nStep [IN] Source image line step
* \param hpSize [OUT] Where to store computed size (host memory)
*
* \return NCV status code
*/
NCV_EXPORTS
NCV_EXPORTS
NCVStatus nppiStVectorWarpGetBufferSize(NcvSize32u srcSize,
Ncv32u nSrcStep,
Ncv32u *hpSize);
@ -316,7 +316,7 @@ NCVStatus nppiStVectorWarp_PSF2x2_32f_C1(const Ncv32f *pSrc,
* \param xFactor [IN] Row scale factor
* \param yFactor [IN] Column scale factor
* \param interpolation [IN] Interpolation type
*
*
* \return NCV status code
*/
NCV_EXPORTS

@ -1,7 +1,7 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
@ -39,11 +39,14 @@
//
//M*/
// this file does not contain any used code.
#ifndef _ncv_color_conversion_hpp_
#define _ncv_color_conversion_hpp_
#include "NCVPixelOperations.hpp"
#if 0
enum NCVColorSpace
{
NCVColorSpaceGray,
@ -71,8 +74,7 @@ static void _pixColorConv(const Tin &pixIn, Tout &pixOut)
}};
template<NCVColorSpace CSin, NCVColorSpace CSout, typename Tin, typename Tout>
static
NCVStatus _ncvColorConv_host(const NCVMatrix<Tin> &h_imgIn,
static NCVStatus _ncvColorConv_host(const NCVMatrix<Tin> &h_imgIn,
const NCVMatrix<Tout> &h_imgOut)
{
ncvAssertReturn(h_imgIn.size() == h_imgOut.size(), NCV_DIMENSIONS_INVALID);
@ -92,5 +94,6 @@ NCVStatus _ncvColorConv_host(const NCVMatrix<Tin> &h_imgIn,
NCV_SKIP_COND_END
return NCV_SUCCESS;
}
#endif
#endif //_ncv_color_conversion_hpp_

@ -1,7 +1,7 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
@ -47,38 +47,38 @@
#include "NCV.hpp"
template<typename TBase> inline __host__ __device__ TBase _pixMaxVal();
template<> static inline __host__ __device__ Ncv8u _pixMaxVal<Ncv8u>() {return UCHAR_MAX;}
template<> static inline __host__ __device__ Ncv8u _pixMaxVal<Ncv8u>() {return UCHAR_MAX;}
template<> static inline __host__ __device__ Ncv16u _pixMaxVal<Ncv16u>() {return USHRT_MAX;}
template<> static inline __host__ __device__ Ncv32u _pixMaxVal<Ncv32u>() {return UINT_MAX;}
template<> static inline __host__ __device__ Ncv8s _pixMaxVal<Ncv8s>() {return CHAR_MAX;}
template<> static inline __host__ __device__ Ncv16s _pixMaxVal<Ncv16s>() {return SHRT_MAX;}
template<> static inline __host__ __device__ Ncv32s _pixMaxVal<Ncv32s>() {return INT_MAX;}
template<> static inline __host__ __device__ Ncv32f _pixMaxVal<Ncv32f>() {return FLT_MAX;}
template<> static inline __host__ __device__ Ncv64f _pixMaxVal<Ncv64f>() {return DBL_MAX;}
template<> static inline __host__ __device__ Ncv32u _pixMaxVal<Ncv32u>() {return UINT_MAX;}
template<> static inline __host__ __device__ Ncv8s _pixMaxVal<Ncv8s>() {return CHAR_MAX;}
template<> static inline __host__ __device__ Ncv16s _pixMaxVal<Ncv16s>() {return SHRT_MAX;}
template<> static inline __host__ __device__ Ncv32s _pixMaxVal<Ncv32s>() {return INT_MAX;}
template<> static inline __host__ __device__ Ncv32f _pixMaxVal<Ncv32f>() {return FLT_MAX;}
template<> static inline __host__ __device__ Ncv64f _pixMaxVal<Ncv64f>() {return DBL_MAX;}
template<typename TBase> inline __host__ __device__ TBase _pixMinVal();
template<> static inline __host__ __device__ Ncv8u _pixMinVal<Ncv8u>() {return 0;}
template<> static inline __host__ __device__ Ncv8u _pixMinVal<Ncv8u>() {return 0;}
template<> static inline __host__ __device__ Ncv16u _pixMinVal<Ncv16u>() {return 0;}
template<> static inline __host__ __device__ Ncv32u _pixMinVal<Ncv32u>() {return 0;}
template<> static inline __host__ __device__ Ncv8s _pixMinVal<Ncv8s>() {return CHAR_MIN;}
template<> static inline __host__ __device__ Ncv8s _pixMinVal<Ncv8s>() {return CHAR_MIN;}
template<> static inline __host__ __device__ Ncv16s _pixMinVal<Ncv16s>() {return SHRT_MIN;}
template<> static inline __host__ __device__ Ncv32s _pixMinVal<Ncv32s>() {return INT_MIN;}
template<> static inline __host__ __device__ Ncv32f _pixMinVal<Ncv32f>() {return FLT_MIN;}
template<> static inline __host__ __device__ Ncv64f _pixMinVal<Ncv64f>() {return DBL_MIN;}
template<typename Tvec> struct TConvVec2Base;
template<> struct TConvVec2Base<uchar1> {typedef Ncv8u TBase;};
template<> struct TConvVec2Base<uchar3> {typedef Ncv8u TBase;};
template<> struct TConvVec2Base<uchar4> {typedef Ncv8u TBase;};
template<> struct TConvVec2Base<uchar1> {typedef Ncv8u TBase;};
template<> struct TConvVec2Base<uchar3> {typedef Ncv8u TBase;};
template<> struct TConvVec2Base<uchar4> {typedef Ncv8u TBase;};
template<> struct TConvVec2Base<ushort1> {typedef Ncv16u TBase;};
template<> struct TConvVec2Base<ushort3> {typedef Ncv16u TBase;};
template<> struct TConvVec2Base<ushort4> {typedef Ncv16u TBase;};
template<> struct TConvVec2Base<uint1> {typedef Ncv32u TBase;};
template<> struct TConvVec2Base<uint3> {typedef Ncv32u TBase;};
template<> struct TConvVec2Base<uint4> {typedef Ncv32u TBase;};
template<> struct TConvVec2Base<float1> {typedef Ncv32f TBase;};
template<> struct TConvVec2Base<float3> {typedef Ncv32f TBase;};
template<> struct TConvVec2Base<float4> {typedef Ncv32f TBase;};
template<> struct TConvVec2Base<uint1> {typedef Ncv32u TBase;};
template<> struct TConvVec2Base<uint3> {typedef Ncv32u TBase;};
template<> struct TConvVec2Base<uint4> {typedef Ncv32u TBase;};
template<> struct TConvVec2Base<float1> {typedef Ncv32f TBase;};
template<> struct TConvVec2Base<float3> {typedef Ncv32f TBase;};
template<> struct TConvVec2Base<float4> {typedef Ncv32f TBase;};
template<> struct TConvVec2Base<double1> {typedef Ncv64f TBase;};
template<> struct TConvVec2Base<double3> {typedef Ncv64f TBase;};
template<> struct TConvVec2Base<double4> {typedef Ncv64f TBase;};
@ -86,9 +86,9 @@ template<> struct TConvVec2Base<double4> {typedef Ncv64f TBase;};
#define NC(T) (sizeof(T) / sizeof(TConvVec2Base<T>::TBase))
template<typename TBase, Ncv32u NC> struct TConvBase2Vec;
template<> struct TConvBase2Vec<Ncv8u, 1> {typedef uchar1 TVec;};
template<> struct TConvBase2Vec<Ncv8u, 3> {typedef uchar3 TVec;};
template<> struct TConvBase2Vec<Ncv8u, 4> {typedef uchar4 TVec;};
template<> struct TConvBase2Vec<Ncv8u, 1> {typedef uchar1 TVec;};
template<> struct TConvBase2Vec<Ncv8u, 3> {typedef uchar3 TVec;};
template<> struct TConvBase2Vec<Ncv8u, 4> {typedef uchar4 TVec;};
template<> struct TConvBase2Vec<Ncv16u, 1> {typedef ushort1 TVec;};
template<> struct TConvBase2Vec<Ncv16u, 3> {typedef ushort3 TVec;};
template<> struct TConvBase2Vec<Ncv16u, 4> {typedef ushort4 TVec;};

@ -202,7 +202,7 @@ __global__ void kernelDownsampleX2(T *d_src,
}
}
namespace cv { namespace gpu { namespace device
namespace cv { namespace gpu { namespace device
{
namespace pyramid
{
@ -211,7 +211,7 @@ namespace cv { namespace gpu { namespace device
dim3 bDim(16, 8);
dim3 gDim(divUp(src.cols, bDim.x), divUp(src.rows, bDim.y));
kernelDownsampleX2<<<gDim, bDim, 0, stream>>>((T*)src.data, static_cast<Ncv32u>(src.step),
kernelDownsampleX2<<<gDim, bDim, 0, stream>>>((T*)src.data, static_cast<Ncv32u>(src.step),
(T*)dst.data, static_cast<Ncv32u>(dst.step), NcvSize32u(dst.cols, dst.rows));
cudaSafeCall( cudaGetLastError() );
@ -277,7 +277,7 @@ __global__ void kernelInterpolateFrom1(T *d_srcTop,
d_dst_line[j] = outPix;
}
}
namespace cv { namespace gpu { namespace device
namespace cv { namespace gpu { namespace device
{
namespace pyramid
{
@ -286,7 +286,7 @@ namespace cv { namespace gpu { namespace device
dim3 bDim(16, 8);
dim3 gDim(divUp(dst.cols, bDim.x), divUp(dst.rows, bDim.y));
kernelInterpolateFrom1<<<gDim, bDim, 0, stream>>>((T*) src.data, static_cast<Ncv32u>(src.step), NcvSize32u(src.cols, src.rows),
kernelInterpolateFrom1<<<gDim, bDim, 0, stream>>>((T*) src.data, static_cast<Ncv32u>(src.step), NcvSize32u(src.cols, src.rows),
(T*) dst.data, static_cast<Ncv32u>(dst.step), NcvSize32u(dst.cols, dst.rows));
cudaSafeCall( cudaGetLastError() );

@ -1,7 +1,7 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
@ -54,14 +54,14 @@
// The Loki Library
// Copyright (c) 2001 by Andrei Alexandrescu
// This code accompanies the book:
// Alexandrescu, Andrei. "Modern C++ Design: Generic Programming and Design
// Alexandrescu, Andrei. "Modern C++ Design: Generic Programming and Design
// Patterns Applied". Copyright (c) 2001. Addison-Wesley.
// Permission to use, copy, modify, distribute and sell this software for any
// purpose is hereby granted without fee, provided that the above copyright
// notice appear in all copies and that both that copyright notice and this
// Permission to use, copy, modify, distribute and sell this software for any
// purpose is hereby granted without fee, provided that the above copyright
// notice appear in all copies and that both that copyright notice and this
// permission notice appear in supporting documentation.
// The author or Addison-Welsey Longman make no representations about the
// suitability of this software for any purpose. It is provided "as is"
// The author or Addison-Welsey Longman make no representations about the
// suitability of this software for any purpose. It is provided "as is"
// without express or implied warranty.
// http://loki-lib.sourceforge.net/index.php?n=Main.License
////////////////////////////////////////////////////////////////////////////////
@ -71,7 +71,7 @@ namespace Loki
//==============================================================================
// class NullType
// Used as a placeholder for "no type here"
// Useful as an end marker in typelists
// Useful as an end marker in typelists
//==============================================================================
class NullType {};
@ -110,7 +110,7 @@ namespace Loki
//==============================================================================
// class template TypeAt
// Finds the type at a given index in a typelist
// Invocation (TList is a typelist and index is a compile-time integral
// Invocation (TList is a typelist and index is a compile-time integral
// constant):
// TypeAt<TList, index>::Result
// returns the type in position 'index' in TList

@ -1,11 +1,11 @@
/*
* Copyright 1993-2010 NVIDIA Corporation. All rights reserved.
*
* NVIDIA Corporation and its licensors retain all intellectual
* property and proprietary rights in and to this software and
* related documentation and any modifications thereto.
* Any use, reproduction, disclosure, or distribution of this
* software and related documentation without an express license
* NVIDIA Corporation and its licensors retain all intellectual
* property and proprietary rights in and to this software and
* related documentation and any modifications thereto.
* Any use, reproduction, disclosure, or distribution of this
* software and related documentation without an express license
* agreement from NVIDIA Corporation is strictly prohibited.
*/
#ifndef _ncvautotestlister_hpp_
@ -47,7 +47,7 @@ public:
if (outputLevel == OutputLevelCompact)
{
printf("Test suite '%s' with %d tests\n",
printf("Test suite '%s' with %d tests\n",
testSuiteName.c_str(),
(int)(this->tests.size()));
}
@ -109,7 +109,7 @@ public:
if (outputLevel != OutputLevelNone)
{
printf("Test suite '%s' complete: %d total, %d passed, %d memory errors, %d failed\n\n",
printf("Test suite '%s' complete: %d total, %d passed, %d memory errors, %d failed\n\n",
testSuiteName.c_str(),
(int)(this->tests.size()),
nPassed,

@ -1,11 +1,11 @@
/*
* Copyright 1993-2010 NVIDIA Corporation. All rights reserved.
*
* NVIDIA Corporation and its licensors retain all intellectual
* property and proprietary rights in and to this software and
* related documentation and any modifications thereto.
* Any use, reproduction, disclosure, or distribution of this
* software and related documentation without an express license
* NVIDIA Corporation and its licensors retain all intellectual
* property and proprietary rights in and to this software and
* related documentation and any modifications thereto.
* Any use, reproduction, disclosure, or distribution of this
* software and related documentation without an express license
* agreement from NVIDIA Corporation is strictly prohibited.
*/
#ifndef _ncvtest_hpp_

@ -1,11 +1,11 @@
/*
* Copyright 1993-2010 NVIDIA Corporation. All rights reserved.
*
* NVIDIA Corporation and its licensors retain all intellectual
* property and proprietary rights in and to this software and
* related documentation and any modifications thereto.
* Any use, reproduction, disclosure, or distribution of this
* software and related documentation without an express license
* NVIDIA Corporation and its licensors retain all intellectual
* property and proprietary rights in and to this software and
* related documentation and any modifications thereto.
* Any use, reproduction, disclosure, or distribution of this
* software and related documentation without an express license
* agreement from NVIDIA Corporation is strictly prohibited.
*/
@ -204,7 +204,7 @@ bool TestHaarCascadeApplication::process()
ncvAssertReturn(cudaSuccess == cudaStreamSynchronize(0), false);
#if !defined(__APPLE__)
#if defined(__GNUC__)
//http://www.christian-seiler.de/projekte/fpmath/
@ -239,7 +239,7 @@ bool TestHaarCascadeApplication::process()
_controlfp_s(&fpu_cw, fpu_oldcw, _MCW_PC);
#endif
#endif
#endif
NCV_SKIP_COND_END

@ -13,10 +13,10 @@
#include "NCVHaarObjectDetection.hpp"
TestHypothesesFilter::TestHypothesesFilter(std::string testName, NCVTestSourceProvider<Ncv32u> &src_,
TestHypothesesFilter::TestHypothesesFilter(std::string testName_, NCVTestSourceProvider<Ncv32u> &src_,
Ncv32u numDstRects_, Ncv32u minNeighbors_, Ncv32f eps_)
:
NCVTestProvider(testName),
NCVTestProvider(testName_),
src(src_),
numDstRects(numDstRects_),
minNeighbors(minNeighbors_),

@ -15,10 +15,10 @@
template <class T>
TestResize<T>::TestResize(std::string testName, NCVTestSourceProvider<T> &src_,
TestResize<T>::TestResize(std::string testName_, NCVTestSourceProvider<T> &src_,
Ncv32u width_, Ncv32u height_, Ncv32u scaleFactor_, NcvBool bTextureCache_)
:
NCVTestProvider(testName),
NCVTestProvider(testName_),
src(src_),
width(width_),
height(height_),

@ -34,7 +34,7 @@ PERF_TEST_P(ImageName_MinSize, CascadeClassifierLBPFrontalFace,
if (cc.empty())
FAIL() << "Can't load cascade file";
Mat img=imread(getDataPath(filename), 0);
Mat img = imread(getDataPath(filename), 0);
if (img.empty())
FAIL() << "Can't load source image";

@ -56,8 +56,8 @@ namespace cv
+ (step) * ((rect).y + (rect).width + (rect).height)
#define CALC_SUM_(p0, p1, p2, p3, offset) \
((p0)[offset] - (p1)[offset] - (p2)[offset] + (p3)[offset])
((p0)[offset] - (p1)[offset] - (p2)[offset] + (p3)[offset])
#define CALC_SUM(rect,offset) CALC_SUM_((rect)[0], (rect)[1], (rect)[2], (rect)[3], offset)
@ -68,24 +68,24 @@ public:
struct Feature
{
Feature();
float calc( int offset ) const;
void updatePtrs( const Mat& sum );
bool read( const FileNode& node );
bool tilted;
enum { RECT_NUM = 3 };
struct
{
Rect r;
float weight;
} rect[RECT_NUM];
const int* p[RECT_NUM][4];
};
HaarEvaluator();
virtual ~HaarEvaluator();
@ -109,13 +109,13 @@ protected:
Mat sum0, sqsum0, tilted0;
Mat sum, sqsum, tilted;
Rect normrect;
const int *p[4];
const double *pq[4];
int offset;
double varianceNormFactor;
double varianceNormFactor;
};
inline HaarEvaluator::Feature :: Feature()
@ -123,8 +123,8 @@ inline HaarEvaluator::Feature :: Feature()
tilted = false;
rect[0].r = rect[1].r = rect[2].r = Rect();
rect[0].weight = rect[1].weight = rect[2].weight = 0;
p[0][0] = p[0][1] = p[0][2] = p[0][3] =
p[1][0] = p[1][1] = p[1][2] = p[1][3] =
p[0][0] = p[0][1] = p[0][2] = p[0][3] =
p[1][0] = p[1][1] = p[1][2] = p[1][3] =
p[2][0] = p[2][1] = p[2][2] = p[2][3] = 0;
}
@ -134,7 +134,7 @@ inline float HaarEvaluator::Feature :: calc( int offset ) const
if( rect[2].weight != 0.0f )
ret += rect[2].weight * CALC_SUM(p[2], offset);
return ret;
}
@ -167,27 +167,27 @@ public:
struct Feature
{
Feature();
Feature( int x, int y, int _block_w, int _block_h ) :
Feature( int x, int y, int _block_w, int _block_h ) :
rect(x, y, _block_w, _block_h) {}
int calc( int offset ) const;
void updatePtrs( const Mat& sum );
bool read(const FileNode& node );
Rect rect; // weight and height for block
const int* p[16]; // fast
};
LBPEvaluator();
virtual ~LBPEvaluator();
virtual bool read( const FileNode& node );
virtual Ptr<FeatureEvaluator> clone() const;
virtual int getFeatureType() const { return FeatureEvaluator::LBP; }
virtual bool setImage(const Mat& image, Size _origWinSize);
virtual bool setWindow(Point pt);
int operator()(int featureIdx) const
{ return featuresPtr[featureIdx].calc(offset); }
virtual int calcCat(int featureIdx) const
@ -200,9 +200,9 @@ protected:
Rect normrect;
int offset;
};
};
inline LBPEvaluator::Feature :: Feature()
{
rect = Rect();
@ -213,7 +213,7 @@ inline LBPEvaluator::Feature :: Feature()
inline int LBPEvaluator::Feature :: calc( int offset ) const
{
int cval = CALC_SUM_( p[5], p[6], p[9], p[10], offset );
return (CALC_SUM_( p[0], p[1], p[4], p[5], offset ) >= cval ? 128 : 0) | // 0
(CALC_SUM_( p[1], p[2], p[5], p[6], offset ) >= cval ? 64 : 0) | // 1
(CALC_SUM_( p[2], p[3], p[6], p[7], offset ) >= cval ? 32 : 0) | // 2
@ -248,7 +248,7 @@ public:
Feature();
float calc( int offset ) const;
void updatePtrs( const vector<Mat>& _hist, const Mat &_normSum );
bool read( const FileNode& node );
bool read( const FileNode& node );
enum { CELL_NUM = 4, BIN_NUM = 9 };
@ -331,13 +331,13 @@ inline int predictOrdered( CascadeClassifier& cascade, Ptr<FeatureEvaluator> &_f
CascadeClassifier::Data::DTreeNode* cascadeNodes = &cascade.data.nodes[0];
CascadeClassifier::Data::DTree* cascadeWeaks = &cascade.data.classifiers[0];
CascadeClassifier::Data::Stage* cascadeStages = &cascade.data.stages[0];
for( int si = 0; si < nstages; si++ )
{
CascadeClassifier::Data::Stage& stage = cascadeStages[si];
int wi, ntrees = stage.ntrees;
sum = 0;
for( wi = 0; wi < ntrees; wi++ )
{
CascadeClassifier::Data::DTree& weak = cascadeWeaks[stage.first + wi];
@ -355,7 +355,7 @@ inline int predictOrdered( CascadeClassifier& cascade, Ptr<FeatureEvaluator> &_f
leafOfs += weak.nodeCount + 1;
}
if( sum < stage.threshold )
return -si;
return -si;
}
return 1;
}
@ -372,13 +372,13 @@ inline int predictCategorical( CascadeClassifier& cascade, Ptr<FeatureEvaluator>
CascadeClassifier::Data::DTreeNode* cascadeNodes = &cascade.data.nodes[0];
CascadeClassifier::Data::DTree* cascadeWeaks = &cascade.data.classifiers[0];
CascadeClassifier::Data::Stage* cascadeStages = &cascade.data.stages[0];
for(int si = 0; si < nstages; si++ )
{
CascadeClassifier::Data::Stage& stage = cascadeStages[si];
int wi, ntrees = stage.ntrees;
sum = 0;
for( wi = 0; wi < ntrees; wi++ )
{
CascadeClassifier::Data::DTree& weak = cascadeWeaks[stage.first + wi];
@ -396,7 +396,7 @@ inline int predictCategorical( CascadeClassifier& cascade, Ptr<FeatureEvaluator>
leafOfs += weak.nodeCount + 1;
}
if( sum < stage.threshold )
return -si;
return -si;
}
return 1;
}
@ -444,7 +444,7 @@ inline int predictCategoricalStump( CascadeClassifier& cascade, Ptr<FeatureEvalu
CascadeClassifier::Data::Stage* cascadeStages = &cascade.data.stages[0];
#ifdef HAVE_TEGRA_OPTIMIZATION
float tmp; // float accumulator -- float operations are quicker
float tmp; // float accumulator -- float operations are quicker
#endif
for( int si = 0; si < nstages; si++ )
{
@ -472,11 +472,11 @@ inline int predictCategoricalStump( CascadeClassifier& cascade, Ptr<FeatureEvalu
#ifdef HAVE_TEGRA_OPTIMIZATION
if( tmp < stage.threshold ) {
sum = (double)tmp;
return -si;
return -si;
}
#else
if( sum < stage.threshold )
return -si;
return -si;
#endif
}

@ -53,13 +53,13 @@ IplImage* resize_opencv(IplImage* img, float scale)
//}
//// resize along each column
//// result is transposed, so we can apply it twice for a complete resize
//void resize1dtran(float *src, int sheight, float *dst, int dheight,
//void resize1dtran(float *src, int sheight, float *dst, int dheight,
// int width, int chan) {
// alphainfo *ofs;
// float scale = (float)dheight/(float)sheight;
// float invscale = (float)sheight/(float)dheight;
//
// // we cache the interpolation values since they can be
//
// // we cache the interpolation values since they can be
// // shared among different columns
// int len = (int)ceilf(dheight*invscale) + 2*dheight;
// int k = 0;
@ -126,7 +126,7 @@ IplImage* resize_opencv(IplImage* img, float scale)
// int index;
// int widthStep;
// int tW, tH;
//
//
// W = (float)img->width;
// H = (float)img->height;
// channels = img->nChannels;
@ -149,16 +149,16 @@ IplImage* resize_opencv(IplImage* img, float scale)
// }
// }
// }
//
//
// imgTmp = cvCreateImage(cvSize(tW , tH), IPL_DEPTH_32F, channels);
//
// dst = (float *)malloc(sizeof(float) * (int)(tH * tW) * channels);
// tmp = (float *)malloc(sizeof(float) * (int)(tH * W) * channels);
//
// resize1dtran(src, (int)H, tmp, (int)tH, (int)W , 3);
//
//
// resize1dtran(tmp, (int)W, dst, (int)tW, (int)tH, 3);
//
//
// index = 0;
// //dataf = (float*)imgTmp->imageData;
// for (kk = 0; kk < channels; kk++)
@ -188,7 +188,7 @@ IplImage* resize_opencv(IplImage* img, float scale)
// int index;
// int widthStep;
// int tW, tH;
//
//
// W = (float)img->width;
// H = (float)img->height;
// channels = img->nChannels;
@ -210,16 +210,16 @@ IplImage* resize_opencv(IplImage* img, float scale)
// }
// }
// }
//
//
// imgTmp = cvCreateImage(cvSize(tW , tH), IPL_DEPTH_32F, channels);
//
// dst = (float *)malloc(sizeof(float) * (int)(tH * tW) * channels);
// tmp = (float *)malloc(sizeof(float) * (int)(tH * W) * channels);
//
// resize1dtran(src, (int)H, tmp, (int)tH, (int)W , 3);
//
//
// resize1dtran(tmp, (int)W, dst, (int)tW, (int)tH, 3);
//
//
// index = 0;
// for (kk = 0; kk < channels; kk++)
// {
@ -232,7 +232,7 @@ IplImage* resize_opencv(IplImage* img, float scale)
// }
// }
// }
//
//
// free(src);
// free(dst);
// free(tmp);

@ -2,7 +2,7 @@
#include "_lsvm_routine.h"
int allocFilterObject(CvLSVMFilterObject **obj, const int sizeX,
const int sizeY, const int numFeatures)
const int sizeY, const int numFeatures)
{
int i;
(*obj) = (CvLSVMFilterObject *)malloc(sizeof(CvLSVMFilterObject));
@ -16,7 +16,7 @@ int allocFilterObject(CvLSVMFilterObject **obj, const int sizeX,
(*obj)->V.x = 0;
(*obj)->V.y = 0;
(*obj)->V.l = 0;
(*obj)->H = (float *) malloc(sizeof (float) *
(*obj)->H = (float *) malloc(sizeof (float) *
(sizeX * sizeY * numFeatures));
for(i = 0; i < sizeX * sizeY * numFeatures; i++)
{
@ -33,7 +33,7 @@ int freeFilterObject (CvLSVMFilterObject **obj)
return LATENT_SVM_OK;
}
int allocFeatureMapObject(CvLSVMFeatureMap **obj, const int sizeX,
int allocFeatureMapObject(CvLSVMFeatureMap **obj, const int sizeX,
const int sizeY, const int numFeatures)
{
int i;
@ -41,7 +41,7 @@ int allocFeatureMapObject(CvLSVMFeatureMap **obj, const int sizeX,
(*obj)->sizeX = sizeX;
(*obj)->sizeY = sizeY;
(*obj)->numFeatures = numFeatures;
(*obj)->map = (float *) malloc(sizeof (float) *
(*obj)->map = (float *) malloc(sizeof (float) *
(sizeX * sizeY * numFeatures));
for(i = 0; i < sizeX * sizeY * numFeatures; i++)
{
@ -59,7 +59,7 @@ int freeFeatureMapObject (CvLSVMFeatureMap **obj)
}
int allocFeaturePyramidObject(CvLSVMFeaturePyramid **obj,
const int numLevels)
const int numLevels)
{
(*obj) = (CvLSVMFeaturePyramid *)malloc(sizeof(CvLSVMFeaturePyramid));
(*obj)->numLevels = numLevels;
@ -70,7 +70,7 @@ int allocFeaturePyramidObject(CvLSVMFeaturePyramid **obj,
int freeFeaturePyramidObject (CvLSVMFeaturePyramid **obj)
{
int i;
int i;
if(*obj == NULL) return LATENT_SVM_MEM_NULL;
for(i = 0; i < (*obj)->numLevels; i++)
{

Loading…
Cancel
Save